61 explicit Tensor(
int i)
63 resizeI<std::vector<int>>({i});
68 resizeI<std::vector<int>>({i, j});
71 Tensor(
int i,
int j,
int k)
73 resizeI<std::vector<int>>({i, j, k});
76 Tensor(
int i,
int j,
int k,
int l)
78 resizeI<std::vector<int>>({i, j, k, l});
81 template <
typename Sizes>
82 void resizeI(
const Sizes& sizes)
84 if (sizes.size() == 1)
85 dims_ = {(int)sizes[0]};
86 if (sizes.size() == 2)
87 dims_ = {(int)sizes[0], (
int)sizes[1]};
88 if (sizes.size() == 3)
89 dims_ = {(int)sizes[0], (
int)sizes[1], (int)sizes[2]};
90 if (sizes.size() == 4)
91 dims_ = {(int)sizes[0], (
int)sizes[1], (int)sizes[2], (
int)sizes[3]};
93 data_.resize(std::accumulate(begin(dims_), end(dims_), 1.0, std::multiplies<>()));
// Fragment: the enclosing signature and the remainder of this body are not
// visible in this view.
// Rejects a rank-0 tensor through OPM_ERROR_IF, then multiplies every extent
// in dims_ into `elements`, i.e. the total element count.
// NOTE(review): how `elements` is used is outside this view. Presumably the
// shape is collapsed to a single dimension; confirm.
// NOTE(review): `elements` is an int, so the product can overflow for very
// large tensors.
98 OPM_ERROR_IF(dims_.size() == 0,
"Invalid tensor");
100 int elements = dims_[0];
101 for (
unsigned int i = 1; i < dims_.size(); i++) {
102 elements *= dims_[i];
// Fragment of a single-index accessor (presumably operator()(int i); its
// signature and return are not visible in this view).
// Requires a rank-1 tensor and 0 <= i < dims_[0]. Violations are reported
// through OPM_ERROR_IF, with a fmt::format message for the index check.
109 OPM_ERROR_IF(dims_.size() != 1,
"Invalid indexing for tensor");
111 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
112 fmt::format(
" Invalid i: "
// Mutable access to element (i, j) of a rank-2 tensor.
// OPM_ERROR_IF reports an error (presumably by throwing; confirm against the
// macro) when the tensor is not rank 2, when i is outside [0, dims_[0]), or
// when j is outside [0, dims_[1]).
// NOTE(review): these checks are duplicated verbatim in the const overload
// and could be shared through a helper.
122 T& operator()(
int i,
int j)
124 OPM_ERROR_IF(dims_.size() != 2,
"Invalid indexing for tensor");
125 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
126 fmt::format(
" Invalid i: "
132 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
133 fmt::format(
" Invalid j: "
// Row-major layout: rows of length dims_[1] are stored back to back.
140 return data_[dims_[1] * i + j];
// Read-only access to element (i, j) of a rank-2 tensor.
// OPM_ERROR_IF reports an error (presumably by throwing; confirm against the
// macro) when the tensor is not rank 2, when i is outside [0, dims_[0]), or
// when j is outside [0, dims_[1]).
143 const T& operator()(
int i,
int j)
const
145 OPM_ERROR_IF(dims_.size() != 2,
"Invalid indexing for tensor");
146 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
147 fmt::format(
" Invalid i: "
153 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
154 fmt::format(
" Invalid j: "
// Row-major layout: rows of length dims_[1] are stored back to back.
160 return data_[dims_[1] * i + j];
// Mutable access to element (i, j, k) of a rank-3 tensor.
// OPM_ERROR_IF reports an error (presumably by throwing; confirm against the
// macro) when the tensor is not rank 3 or when any index is outside
// [0, extent) for its dimension.
163 T& operator()(
int i,
int j,
int k)
165 OPM_ERROR_IF(dims_.size() != 3,
"Invalid indexing for tensor");
166 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
167 fmt::format(
" Invalid i: "
173 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
174 fmt::format(
" Invalid j: "
180 OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
181 fmt::format(
" Invalid k: "
// Row-major flat offset: k varies fastest, then j, then i.
188 return data_[dims_[2] * (dims_[1] * i + j) + k];
// Read-only access to element (i, j, k) of a rank-3 tensor.
// OPM_ERROR_IF reports an error (presumably by throwing; confirm against the
// macro) when the tensor is not rank 3 or when any index is outside
// [0, extent) for its dimension.
191 const T& operator()(
int i,
int j,
int k)
const
193 OPM_ERROR_IF(dims_.size() != 3,
"Invalid indexing for tensor");
194 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
195 fmt::format(
" Invalid i: "
201 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
202 fmt::format(
" Invalid j: "
208 OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
209 fmt::format(
" Invalid k: "
// Row-major flat offset: k varies fastest, then j, then i.
216 return data_[dims_[2] * (dims_[1] * i + j) + k];
// Mutable access to element (i, j, k, l) of a rank-4 tensor.
// OPM_ERROR_IF reports an error (presumably by throwing; confirm against the
// macro) when the tensor is not rank 4 or when any index is outside
// [0, extent) for its dimension.
219 T& operator()(
int i,
int j,
int k,
int l)
221 OPM_ERROR_IF(dims_.size() != 4,
"Invalid indexing for tensor");
222 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
223 fmt::format(
" Invalid i: "
229 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
230 fmt::format(
" Invalid j: "
236 OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
237 fmt::format(
" Invalid k: "
243 OPM_ERROR_IF(!(l < dims_[3] && l >= 0),
244 fmt::format(
" Invalid l: "
// Row-major flat offset: l varies fastest, then k, j, i.
251 return data_[dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l];
// Read-only access to element (i, j, k, l) of a rank-4 tensor.
// OPM_ERROR_IF reports an error (presumably by throwing; confirm against the
// macro) when the tensor is not rank 4 or when any index is outside
// [0, extent) for its dimension.
254 const T& operator()(
int i,
int j,
int k,
int l)
const
256 OPM_ERROR_IF(dims_.size() != 4,
"Invalid indexing for tensor");
257 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
258 fmt::format(
" Invalid i: "
264 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
265 fmt::format(
" Invalid j: "
271 OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
272 fmt::format(
" Invalid k: "
278 OPM_ERROR_IF(!(l < dims_[3] && l >= 0),
279 fmt::format(
" Invalid l: "
// Row-major flat offset: l varies fastest, then k, j, i.
286 return data_[dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l];
289 void fill(
const T& value)
291 std::fill(data_.begin(), data_.end(), value);
295 Tensor operator+(
const Tensor& other)
297 OPM_ERROR_IF(dims_.size() != other.dims_.size(),
298 "Cannot add tensors with different dimensions");
300 result.dims_ = dims_;
301 result.data_.resize(data_.size());
303 std::transform(data_.begin(),
306 result.data_.begin(),
307 [](
const T& x,
const T& y) { return x + y; });
313 Tensor multiply(
const Tensor& other)
315 OPM_ERROR_IF(dims_.size() != other.dims_.size(),
316 "Cannot multiply elements with different dimensions");
319 result.dims_ = dims_;
320 result.data_.resize(data_.size());
322 std::transform(data_.begin(),
325 result.data_.begin(),
326 [](
const T& x,
const T& y) { return x * y; });
332 Tensor dot(
const Tensor& other)
334 OPM_ERROR_IF(dims_.size() != 2,
"Invalid tensor dimensions");
335 OPM_ERROR_IF(other.dims_.size() != 2,
"Invalid tensor dimensions");
337 OPM_ERROR_IF(dims_[1] != other.dims_[0],
338 "Cannot multiply with different inner dimensions");
340 Tensor tmp(dims_[0], other.dims_[1]);
342 for (
int i = 0; i < dims_[0]; i++) {
343 for (
int j = 0; j < other.dims_[1]; j++) {
344 for (
int k = 0; k < dims_[1]; k++) {
345 tmp(i, j) += (*this)(i, k) * other(k, j);
353 void swap(Tensor& other)
355 dims_.swap(other.dims_);
356 data_.swap(other.data_);
// Extent of each dimension, outermost first. Elements are laid out row-major
// (last index fastest), as the offset arithmetic in operator() shows.
359 std::vector<int> dims_;
// Flat element storage; resizeI sizes it to the product of the extents in dims_.
360 std::vector<T> data_;