-
Notifications
You must be signed in to change notification settings - Fork 59
Expand file tree
/
Copy pathtensor.h
More file actions
3532 lines (3154 loc) · 138 KB
/
tensor.h
File metadata and controls
3532 lines (3154 loc) · 138 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* This file is a part of TiledArray.
* Copyright (C) 2013 Virginia Tech
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
#ifndef TILEDARRAY_TENSOR_TENSOR_H__INCLUDED
#define TILEDARRAY_TENSOR_TENSOR_H__INCLUDED
#include "TiledArray/config.h"
#include "TiledArray/host/env.h"
#include "TiledArray/platform.h"
#include "TiledArray/math/blas.h"
#include "TiledArray/math/gemm_helper.h"
#include "TiledArray/tensor/arena_kernels.h"
#include "TiledArray/tensor/complex.h"
#include "TiledArray/tensor/kernels.h"
#include "TiledArray/tile_interface/clone.h"
#include "TiledArray/tile_interface/permute.h"
#include "TiledArray/tile_interface/trace.h"
#include "TiledArray/util/logger.h"
#include "TiledArray/util/ptr_registry.h"
#include <umpire_cxx_allocator.hpp>
namespace TiledArray {
namespace detail {

/// Signals that we can take the trace of a Tensor<T, A> (for numeric \c T)
template <typename T, typename A>
struct TraceIsDefined<Tensor<T, A>, enable_if_numeric_t<T>> : std::true_type {};

/// Converts \p f to the (nested) tensor type \c To
///
/// The strategies are tried in this order:
/// 1. \c To is the same type as \c From : `std::forward<From>(f).clone()`;
/// 2. \c From is convertible to \c To (per `detail::is_convertible_v`):
///    `static_cast<To>`;
/// 3. both satisfy `detail::is_range_v`: a \c To is constructed from
///    `f.range()` and the elements are copied, through raw data pointers on
///    whichever side is a contiguous tensor;
/// 4. otherwise compilation fails with a diagnostic.
/// \tparam To the target tensor type
/// \tparam From the source type (forwarding reference)
/// \param f the object to convert
/// \return a \c To holding the contents of \p f
template <typename To, typename From,
          typename = std::enable_if_t<
              detail::is_nested_tensor_v<To, detail::remove_cvr_t<From>>>>
To clone_or_cast(From&& f) {
  if constexpr (std::is_same_v<To, detail::remove_cvr_t<From>>)
    return std::forward<From>(f).clone();
  else if constexpr (detail::is_convertible_v<From, To>) {
    return static_cast<To>(std::forward<From>(f));
  } else if constexpr (detail::is_range_v<To> &&
                       detail::is_range_v<detail::remove_cvr_t<From>>) {
    using std::begin;
    using std::data;
    using std::end;
    To t(f.range());
    if constexpr (detail::is_contiguous_tensor_v<detail::remove_cvr_t<From>>) {
      const auto n = f.range().volume();
      if constexpr (detail::is_contiguous_tensor_v<To>) {
        std::copy(data(f), data(f) + n, data(t));
      } else {
        std::copy(data(f), data(f) + n, begin(t));
      }
    } else {
      if constexpr (detail::is_contiguous_tensor_v<To>) {
        std::copy(begin(f), end(f), data(t));
      } else
        std::copy(begin(f), end(f), begin(t));
    }
    return t;
  } else {
    // The condition must be false for every `To` that reaches this branch,
    // yet depend on `To` so that it is only evaluated on instantiation.
    // (`!std::is_void_v<To>` is true for every tensor `To`, so such an
    // assertion would never fire and the function would fall off the end
    // without returning a value.)
    static_assert(
        !std::is_same_v<To, To>,
        "clone_or_cast<To,From>: could not figure out how to convert From to "
        "To, either overload of a member function of Tensor is missing or From "
        "need to provide a conversion operator to To");
  }
}
} // namespace detail
/// An N-dimensional tensor object
/// A contiguous row-major tensor with __shallow-copy__ semantics.
/// As of TiledArray 1.1 Tensor represents a batch of tensors with same Range
/// (the default batch size = 1).
/// \tparam T The value type of this tensor
/// \tparam A The allocator type for the data; only default-constructible
/// allocators are supported to save space
template <typename T, typename Allocator>
class Tensor {
// Fail early, with a readable message, when elements cannot be assigned
// (e.g. const T); otherwise GCC emits an obscure error, see
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=48101
static_assert(std::is_assignable_v<std::add_lvalue_reference_t<T>, T>,
              "Tensor<T,Allocator>: T must be an assignable type (e.g. "
              "cannot be const)");
// Allocators are created on demand rather than stored, which keeps a default
// Tensor small (important for the many null elements of Tensor<Tensor<T>>);
// that requires a default-constructible Allocator.
static_assert(
    std::is_default_constructible_v<Allocator>,
    "Tensor<T,Allocator>: only default-constructible Allocator is supported");
#ifdef TA_TENSOR_MEM_TRACE
/// Streams every argument, left to right, into one string; used only to
/// build memory-trace registry messages.
template <typename... Ts>
std::string make_string(Ts&&... ts) {
  std::ostringstream message;
  (message << ... << ts);
  return message.str();
}
#endif
public:
typedef Range range_type; ///< Tensor range type
typedef typename range_type::index1_type index1_type; ///< 1-index type
typedef typename range_type::ordinal_type ordinal_type; ///< Ordinal type
typedef typename range_type::ordinal_type
size_type; ///< Size type (to meet the container concept)
typedef Allocator allocator_type; ///< Allocator type
typedef typename std::allocator_traits<allocator_type>::value_type
value_type; ///< Array element type
typedef std::add_lvalue_reference_t<value_type>
reference; ///< Element (lvalue) reference type
typedef std::add_lvalue_reference_t<std::add_const_t<value_type>>
const_reference; ///< Element (const lvalue) reference type
typedef typename std::allocator_traits<allocator_type>::pointer
pointer; ///< Element pointer type
typedef typename std::allocator_traits<allocator_type>::const_pointer
const_pointer; ///< Element const pointer type
typedef typename std::allocator_traits<allocator_type>::difference_type
difference_type; ///< Difference type
typedef pointer iterator; ///< Element iterator type
typedef const_pointer const_iterator; ///< Element const iterator type
typedef typename TiledArray::detail::numeric_type<T>::type
numeric_type; ///< the numeric type that supports T
typedef typename TiledArray::detail::scalar_type<T>::type
scalar_type; ///< the scalar type that supports T
private:
/// shorthand for the nested \c value_type of \c X
template <typename X>
using value_t = typename X::value_type;
/// shorthand for `TiledArray::detail::numeric_type<X>::type`
template <typename X>
using numeric_t = typename TiledArray::detail::numeric_type<X>::type;
/// true if either `detail::is_tensor<Ts...>` or
/// `detail::is_tensor_of_tensor<Ts...>` holds, i.e. the SFINAE predicate used
/// by the constructors below to accept both plain and nested tensors
template <typename... Ts>
struct is_tensor {
static constexpr bool value = detail::is_tensor<Ts...>::value ||
detail::is_tensor_of_tensor<Ts...>::value;
};
public:
/// compute type of Tensor with different element type
/// \tparam U the new element type
/// \tparam OtherAllocator allocator for \c U ; by default \c Allocator
///         rebound to \c U
template <typename U,
typename OtherAllocator = typename std::allocator_traits<
Allocator>::template rebind_alloc<U>>
using rebind_t = Tensor<U, OtherAllocator>;
/// computes the Tensor type obtained by replacing the innermost numeric type
/// with \c U ; specialized below for nested and non-nested \c V
template <typename U, typename V = value_type, typename = void>
struct rebind_numeric;
/// nested case: rebind the numeric type of the inner tensor type \c V , then
/// use it as the element type (with the allocator rebound accordingly)
template <typename U, typename V>
struct rebind_numeric<U, V, std::enable_if_t<is_tensor<V>::value>> {
using VU = typename V::template rebind_numeric<U>::type;
using type = Tensor<VU, typename std::allocator_traits<
Allocator>::template rebind_alloc<VU>>;
};
/// non-nested case: \c U simply becomes the element type
template <typename U, typename V>
struct rebind_numeric<U, V, std::enable_if_t<!is_tensor<V>::value>> {
using type = Tensor<
U, typename std::allocator_traits<Allocator>::template rebind_alloc<U>>;
};
/// compute type of Tensor with different numeric type
template <typename U>
using rebind_numeric_t = typename rebind_numeric<U, value_type>::type;
private:
/// Alias naming the `bool` flag of the private allocating constructors; it is
/// spelled `default_construct{true}` / `default_construct{false}` at call
/// sites for readability.
using default_construct = bool;

/// Allocates storage for `range.volume() * nbatch` elements and hands it to
/// \c data_ together with a deleter that destroys all of them and returns the
/// memory to the allocator that produced it.
/// \param range the range of each batch
/// \param nbatch the number of batches
/// \param default_construct if true (and `T` is not trivially
///        default-constructible) the elements are default-constructed here;
///        otherwise the caller must construct every element, because the
///        deleter destroys all of them
/// \note if default construction throws, the raw storage is deallocated
///       before the exception propagates (it is not yet owned by \c data_ )
Tensor(const range_type& range, size_t nbatch, bool default_construct)
    : range_(range), nbatch_(nbatch) {
  size_t size = range_.volume() * nbatch;
  allocator_type allocator;
  auto* ptr = allocator.allocate(size);
  // default construct elements of data only if can have any effect ...
  if constexpr (!std::is_trivially_default_constructible_v<T>) {
    // .. and requested
    if (default_construct) {
      try {
        std::uninitialized_default_construct_n(ptr, size);
      } catch (...) {
        // uninitialized_default_construct_n has already destroyed the
        // elements it built; only the raw storage is left to release
        allocator.deallocate(ptr, size);
        throw;
      }
    }
  }
  auto deleter = [
#ifdef TA_TENSOR_MEM_TRACE
                     this,
#endif
                     allocator = std::move(allocator),
                     size](auto&& ptr) mutable {
    std::destroy_n(ptr, size);
    // N.B. deregister ptr *before* deallocating to avoid possible race
    // between reallocation and deregistering
#ifdef TA_TENSOR_MEM_TRACE
    const auto nbytes = size * sizeof(T);
    if (nbytes >= trace_if_larger_than_) {
      ptr_registry()->erase(ptr, nbytes,
                            make_string("created by TA::Tensor*=", this));
    }
#endif
    allocator.deallocate(ptr, size);
  };
  this->data_ = std::shared_ptr<value_type[]>(ptr, std::move(deleter));
#ifdef TA_TENSOR_MEM_TRACE
  if (nbytes() >= trace_if_larger_than_) {
    ptr_registry()->insert(
        this, make_string("TA::Tensor::data_.get()=", data_.get()));
    ptr_registry()->insert(data_.get(), nbytes(),
                           make_string("created by TA::Tensor*=", this));
  }
#endif
}
/// Same as the `const range_type&` allocating constructor, but takes
/// ownership of \p range instead of copying it.
/// \param range the range of each batch (moved into \c range_ )
/// \param nbatch the number of batches
/// \param default_construct if true (and `T` is not trivially
///        default-constructible) the elements are default-constructed here;
///        otherwise the caller must construct every element, because the
///        deleter destroys all of them
/// \note if default construction throws, the raw storage is deallocated
///       before the exception propagates (it is not yet owned by \c data_ )
Tensor(range_type&& range, size_t nbatch, bool default_construct)
    : range_(std::move(range)), nbatch_(nbatch) {
  size_t size = range_.volume() * nbatch;
  allocator_type allocator;
  auto* ptr = allocator.allocate(size);
  // default construct elements of data only if can have any effect ...
  if constexpr (!std::is_trivially_default_constructible_v<T>) {
    // .. and requested
    if (default_construct) {
      try {
        std::uninitialized_default_construct_n(ptr, size);
      } catch (...) {
        // uninitialized_default_construct_n has already destroyed the
        // elements it built; only the raw storage is left to release
        allocator.deallocate(ptr, size);
        throw;
      }
    }
  }
  auto deleter = [
#ifdef TA_TENSOR_MEM_TRACE
                     this,
#endif
                     allocator = std::move(allocator),
                     size](auto&& ptr) mutable {
    std::destroy_n(ptr, size);
    // N.B. deregister ptr *before* deallocating to avoid possible race
    // between reallocation and deregistering
#ifdef TA_TENSOR_MEM_TRACE
    const auto nbytes = size * sizeof(T);
    if (nbytes >= trace_if_larger_than_) {
      ptr_registry()->erase(ptr, nbytes,
                            make_string("created by TA::Tensor*=", this));
    }
#endif
    allocator.deallocate(ptr, size);
  };
  this->data_ = std::shared_ptr<value_type[]>(ptr, std::move(deleter));
#ifdef TA_TENSOR_MEM_TRACE
  if (nbytes() >= trace_if_larger_than_) {
    ptr_registry()->insert(
        this, make_string("TA::Tensor::data_.get()=", data_.get()));
    ptr_registry()->insert(data_.get(), nbytes(),
                           make_string("created by TA::Tensor*=", this));
  }
#endif
}
template <typename T_>
static decltype(auto) value_converter(const T_& arg) {
using arg_type = detail::remove_cvr_t<decltype(arg)>;
if constexpr (detail::is_tensor_v<arg_type> &&
!is_tensor_view_v<arg_type>) // clone owning nested tensors
return arg.clone();
else if constexpr (!std::is_same_v<arg_type, value_type>) { // convert
if constexpr (std::is_convertible_v<arg_type, value_type>)
return static_cast<value_type>(arg);
else
return conversions::to<value_type, arg_type>()(arg);
} else
return arg; // identity (for views, copy = rebind, no deep clone)
};
range_type range_; ///< Range of each batch (every batch has the same range)
/// Number of `range_`-sized blocks in `data_`
/// \note this is not used for (in)equality comparison
size_t nbatch_ = 1;
/// Shared pointer to the `range_.volume() * nbatch_` elements; shallow copies
/// of this Tensor share it, and it is null for a default-constructed Tensor
std::shared_ptr<value_type[]> data_;
public:
/// constructs an empty (null) Tensor: default-constructed range, `nbatch()==1`
/// and no data
/// \post `this->empty()`
Tensor() = default;
/// copy constructor
/// \param[in] other an object to copy data from
/// \post `*this` is a shallow copy of \p other ,
/// i.e. `*this == other && this->data()==other.data()`
/// \note only the shared_ptr to the data is copied (reference count is
/// incremented); no elements are copied
Tensor(const Tensor& other)
: range_(other.range_), nbatch_(other.nbatch_), data_(other.data_) {
#ifdef TA_TENSOR_MEM_TRACE
// register this new owner of the (shared) data with the memory tracer
if (nbytes() >= trace_if_larger_than_) {
ptr_registry()->insert(
this, make_string("TA::Tensor(const Tensor& other)::data_.get()=",
data_.get()));
}
#endif
}
/// move constructor
/// \param[in,out] other an object to move data from;
/// on return \p other is in empty (null) but not
/// necessarily default state
/// \post `other.empty()`
/// \note `nbatch_` is a plain integer, so \p other keeps its batch count
Tensor(Tensor&& other)
: range_(std::move(other.range_)),
nbatch_(std::move(other.nbatch_)),
data_(std::move(other.data_)) {
#ifdef TA_TENSOR_MEM_TRACE
// transfer the tracer entry from `other` to `this` (data_ now holds what
// other.data_ held before the move)
if (nbytes() >= trace_if_larger_than_) {
ptr_registry()->erase(
&other,
make_string("TA::Tensor(Tensor&& other)::data_.get()=", data_.get()));
ptr_registry()->insert(
this,
make_string("TA::Tensor(Tensor&& other)::data_.get()=", data_.get()));
}
#endif
}
/// destructor; the data is released by `data_`'s deleter once the last
/// shallow copy sharing it is destroyed
~Tensor() {
#ifdef TA_TENSOR_MEM_TRACE
// drop this object's tracer entry (the data's own entry is removed by the
// deleter installed by the allocating constructors)
if (nbytes() >= trace_if_larger_than_) {
ptr_registry()->erase(
this, make_string("TA::~Tensor()::data_.get()=", data_.get()));
}
#endif
}
/// Batch-count wrapper taken by the public range constructor; it is
/// implicitly constructible from any integral type (which is what makes the
/// default argument `nbatches nb = 1` work).
struct nbatches {
  template <typename Int,
            typename = std::enable_if_t<std::is_integral_v<Int>>>
  nbatches(Int n) : n(n) {}
  template <typename Int,
            typename = std::enable_if_t<std::is_integral_v<Int>>>
  nbatches& operator=(Int n) {
    this->n = n;
    // previously missing: flowing off the end of a non-void function is UB
    return *this;
  }
  size_type n = 1;  ///< the number of batches
};
/// Construct a tensor with a range equal to \c range. The data is
/// default-initialized (which, for `T` with trivial default constructor,
/// means data is uninitialized).
/// \param range The range of each batch of the tensor
/// \param nb The number of batches (default is 1)
explicit Tensor(const range_type& range, nbatches nb = 1)
: Tensor(range, nb.n, default_construct{true}) {}
/// Construct a tensor of tensor values, setting all elements to the same
/// value
/// \param range An array with the size of of each dimension
/// \param value The value of the tensor elements
template <
typename Value,
typename std::enable_if<std::is_same<Value, value_type>::value &&
detail::is_tensor<Value>::value>::type* = nullptr>
Tensor(const range_type& range, const Value& value)
: Tensor(range, 1, default_construct{false}) {
const auto n = this->size();
pointer MADNESS_RESTRICT const data = this->data();
if constexpr (is_tensor_view_v<Value>) {
// Views are rebind-on-copy and lack member `clone`; just copy each.
for (size_type i = 0ul; i < n; ++i) new (data + i) value_type(value);
} else {
Clone<Value, Value> cloner;
for (size_type i = 0ul; i < n; ++i)
new (data + i) value_type(cloner(value));
}
}
/// Construct a tensor of scalars whose elements all equal \p value
/// \tparam Value a non-tensor type convertible to \c value_type
/// \param range the range of the result
/// \param value the value given to every element
template <typename Value,
          typename std::enable_if<std::is_convertible_v<Value, value_type> &&
                                  !detail::is_tensor<Value>::value>::type* =
              nullptr>
Tensor(const range_type& range, const Value& value)
    : Tensor(range, 1, default_construct{false}) {
  // every element is produced by the same nullary generator
  detail::tensor_init([value]() -> Value { return value; }, *this);
}
/// Construct a tensor whose elements are computed from their indices
/// \tparam ElementIndexOp callable of signature
///         `value_type(const Range::index_type&)`
/// \param range the range of the result
/// \param element_idx_op invoked once for every index of \p range ; its
///        result is constructed in place at that index's ordinal
template <typename ElementIndexOp,
          typename = std::enable_if_t<std::is_invocable_r_v<
              value_type, ElementIndexOp, const Range::index_type&>>>
Tensor(const range_type& range, const ElementIndexOp& element_idx_op)
    : Tensor(range, 1, default_construct{false}) {
  pointer MADNESS_RESTRICT const base = this->data();
  for (auto&& idx : range)
    new (base + range.ordinal(idx)) value_type(element_idx_op(idx));
}
/// Construct an evaluated tensor
template <typename InIter,
typename std::enable_if<
TiledArray::detail::is_input_iterator<InIter>::value &&
!std::is_pointer<InIter>::value>::type* = nullptr>
Tensor(const range_type& range, InIter it)
: Tensor(range, 1, default_construct{false}) {
auto n = range.volume();
pointer MADNESS_RESTRICT data = this->data();
for (size_type i = 0ul; i < n; ++i, ++it, ++data)
new (data) value_type(*it);
}
/// Construct a tensor by copying `range.volume()` contiguous elements
/// \tparam U the source element type
/// \param range the range of the result
/// \param u pointer to the first of at least `range.volume()` elements
template <typename U>
Tensor(const Range& range, const U* u)
    : Tensor(range, 1, default_construct{false}) {
  const auto count = range.volume();
  math::uninitialized_copy_vector(count, u, this->data());
}
explicit Tensor(const Range& range, std::initializer_list<T> il)
: Tensor(range, il.begin()) {}
/// Construct a copy of a tensor interface object
/// \tparam T1 a tensor type other than this one
/// \param other the tensor to be copied; its range is cloned and each element
///        is converted with `value_converter`
/// \note this constructor is disabled if \p T1 already has a conversion
///       operator to this type
/// \warning if `T1` is a tensor of tensors its elements are _cloned_ rather
///          than copied, to keep the semantics consistent between tensors of
///          scalars and tensors of tensors: a copy of a tensor of scalars is
///          independent of \p other, so inner tensors of a nested \p other
///          are cloned to behave the same way
template <
    typename T1,
    typename std::enable_if<
        is_tensor<T1>::value && !std::is_same<T1, Tensor>::value &&
        !detail::has_conversion_operator_v<T1, Tensor>>::type* = nullptr>
explicit Tensor(const T1& other)
    : Tensor(detail::clone_range(other), 1, default_construct{false}) {
  using source_element_type = typename T1::value_type;
  detail::tensor_init(value_converter<source_element_type>, *this, other);
}
/// Construct a permuted tensor copy
/// \tparam T1 A tensor type
/// \tparam Perm A permutation type
/// \param other The tensor to be copied
/// \param perm The permutation that will be applied to the copy; its outer
/// part permutes the modes of \p other and, for a bipartite \p perm acting
/// on a tensor of tensors, its inner part (if non-empty) permutes every
/// non-null inner tensor
/// \note the result has `other.nbatch()` batches
/// \note if the inner tensors are views and a non-empty inner permutation is
/// requested, an error is raised via TA_EXCEPTION
/// \warning if `T1` is a tensor of tensors its elements are _cloned_ rather
/// than copied to make the semantics of this to be consistent
/// between tensors of scalars and tensors of tensors; specifically,
/// if `T1` is a tensor of scalars the constructed tensor is
/// independent of \p other, thus should apply clone to inner
/// tensor nests to behave similarly for nested tensors
template <
typename T1, typename Perm,
typename std::enable_if<detail::is_nested_tensor_v<T1> &&
detail::is_permutation_v<Perm>>::type* = nullptr>
Tensor(const T1& other, const Perm& perm)
: Tensor(outer(perm) * other.range(), other.nbatch(),
default_construct{false}) {
const auto outer_perm = outer(perm);
// a null (identity) outer permutation degenerates to a plain element copy
if (outer_perm) {
detail::tensor_init(value_converter<typename T1::value_type>, outer_perm,
*this, other);
} else {
detail::tensor_init(value_converter<typename T1::value_type>, *this,
other);
}
// If we actually have a ToT the inner permutation was not applied above so
// we do that now
constexpr bool is_tot = detail::is_tensor_of_tensor_v<Tensor>;
constexpr bool is_bperm = detail::is_bipartite_permutation_v<Perm>;
constexpr bool is_view = is_tensor_view_v<value_type>;
// tile ops pass bipartite permutations here even if this is a plain tensor.
// For view inners, the cell has fixed layout that can't be permuted in
// place -- skip the inner-permute pass and rely on callers to arrange
// canonical inner indexing (regime-A einsum's `do_perm.{A,B,C}` bailout
// guarantees no inner permutation is needed for our paths).
if constexpr (is_tot && is_bperm && !is_view) {
if (inner_size(perm) != 0) {
const auto inner_perm = inner(perm);
Permute<value_type, value_type> p;
// walk every element of every batch; null inner tensors are left as-is
auto volume = total_size();
for (decltype(volume) i = 0; i < volume; ++i) {
auto& el = *(data() + i);
if (!el.empty()) el = p(el, inner_perm);
}
}
} else if constexpr (is_tot && is_bperm && is_view) {
if (inner_size(perm) != 0) {
TA_EXCEPTION(
"Tensor<View>: inner permutation requested but view "
"cells cannot be permuted in place");
}
}
}
/// "Element-wise" unary transform of \c other
/// \tparam T1 a tensor type
/// \tparam Op a unary callable (anything but a permutation)
/// \param other the tensor argument; the result has a clone of its range
/// \param op unary operation applied to each element of \p other ; if it
///        cannot be applied to an element directly it is "threaded" over
///        \p other via `tensor_op`
template <typename T1, typename Op,
          typename std::enable_if_t<
              is_tensor<T1>::value &&
              !detail::is_permutation_v<std::decay_t<Op>>>* = nullptr>
Tensor(const T1& other, Op&& op)
    : Tensor(detail::clone_range(other), 1, default_construct{false}) {
  // op is handed over as an lvalue; the result elements are constructed
  // from op(other[i])
  detail::tensor_init(op, *this, other);
}
/// "Element-wise" unary transform of \c other fused with permutation
/// equivalent, but more efficient, than `Tensor(other, op).permute(perm)`
/// \tparam T1 A tensor type
/// \tparam Op A unary callable
/// \tparam Perm A permutation type
/// \param other The tensor argument
/// \param op Unary operation that can be invoked as `op(other[i])`;
///           if it is not, it will be "threaded" over \p other via
///           `tensor_op`
/// \param perm the permutation applied to the result; for a tensor of
///        tensors and a bipartite \p perm, the inner part (if non-empty) is
///        applied to every non-null inner tensor
template <
    typename T1, typename Op, typename Perm,
    typename std::enable_if_t<is_tensor<T1>::value &&
                              detail::is_permutation_v<Perm>>* = nullptr>
Tensor(const T1& other, Op&& op, const Perm& perm)
    : Tensor(outer(perm) * other.range(), 1, default_construct{false}) {
  detail::tensor_init(op, outer(perm), *this, other);
  // If we actually have a ToT the inner permutation was not applied above so
  // we do that now
  constexpr bool is_tot = detail::is_tensor_of_tensor_v<Tensor>;
  // tile ops pass bipartite permutations here even if this is a plain tensor
  constexpr bool is_bperm = detail::is_bipartite_permutation_v<Perm>;
  if constexpr (is_tot && is_bperm) {
    if (inner_size(perm) != 0) {
      const auto inner_perm = inner(perm);
      Permute<value_type, value_type> p;
      for (auto& x : *this) {
        // null inner tensors have nothing to permute; skip them, as the
        // permuting copy constructor does
        if (!x.empty()) x = p(x, inner_perm);
      }
    }
  }
}
/// "Element-wise" binary transform of \c {left,right}
/// \tparam T1 a tensor type
/// \tparam T2 a tensor type
/// \tparam Op a binary callable
/// \param left the left-hand tensor argument; the result has a clone of its
///        range
/// \param right the right-hand tensor argument
/// \param op binary operation invoked as `op(left[i], right[i])`; if it
///        cannot be, it is "threaded" over the arguments via `tensor_op`
template <typename T1, typename T2, typename Op,
          typename = std::enable_if_t<detail::is_nested_tensor_v<T1, T2>>>
Tensor(const T1& left, const T2& right, Op&& op)
    : Tensor(detail::clone_range(left), 1, default_construct{false}) {
  // result elements are constructed from op(left[i], right[i])
  detail::tensor_init(op, *this, left, right);
}
/// "Element-wise" binary transform of \c {left,right} fused with permutation
/// \tparam T1 A tensor type
/// \tparam T2 A tensor type
/// \tparam Op A binary callable
/// \tparam Perm A permutation type
/// \param left The left-hand tensor argument
/// \param right The right-hand tensor argument
/// \param op Binary operation that can be invoked as `op(left[i],right[i])`;
///           if it is not, it will be "threaded" over the arguments via
///           `tensor_op`
/// \param perm The permutation that will be applied to the result; for a
///        tensor of tensors and a bipartite \p perm, the inner part (if
///        non-empty) is applied to every non-null inner tensor
template <
    typename T1, typename T2, typename Op, typename Perm,
    typename std::enable_if<detail::is_nested_tensor<T1, T2>::value &&
                            detail::is_permutation_v<Perm>>::type* = nullptr>
Tensor(const T1& left, const T2& right, Op&& op, const Perm& perm)
    : Tensor(outer(perm) * left.range(), 1, default_construct{false}) {
  detail::tensor_init(op, outer(perm), *this, left, right);
  // If we actually have a ToT the inner permutation was not applied above so
  // we do that now
  constexpr bool is_tot = detail::is_tensor_of_tensor_v<Tensor>;
  // tile ops pass bipartite permutations here even if this is a plain tensor
  constexpr bool is_bperm = detail::is_bipartite_permutation_v<Perm>;
  if constexpr (is_tot && is_bperm) {
    if (inner_size(perm) != 0) {
      const auto inner_perm = inner(perm);
      Permute<value_type, value_type> p;
      for (auto& x : *this) {
        // null inner tensors have nothing to permute; skip them, as the
        // permuting copy constructor does
        if (!x.empty()) x = p(x, inner_perm);
      }
    }
  }
}
/// Construct a tensor with a range equal to \c range using existing data
/// \param range The range of the tensor
/// \param nbatch The number of batches
/// \param data shared pointer to the data; it is adopted as-is (no copy), so
/// it should point to at least `range.volume() * nbatch` live elements
Tensor(const range_type& range, size_t nbatch,
std::shared_ptr<value_type[]> data)
: range_(range), nbatch_(nbatch), data_(std::move(data)) {
#ifdef TA_TENSOR_MEM_TRACE
if (nbytes() >= trace_if_larger_than_) {
ptr_registry()->insert(
this, make_string("TA::Tensor(range, nbatch, data)::data_.get()=",
data_.get()));
}
#endif
}
/// Construct a tensor with a range equal to \c range using existing data,
/// assuming unit batch size
/// \param range The range of the tensor
/// \param data shared pointer to the data; it is adopted as-is (no copy), so
/// it should point to at least `range.volume()` live elements
Tensor(const range_type& range, std::shared_ptr<value_type[]> data)
: range_(range), nbatch_(1), data_(std::move(data)) {
#ifdef TA_TENSOR_MEM_TRACE
if (nbytes() >= trace_if_larger_than_) {
ptr_registry()->insert(
this,
make_string("TA::Tensor(range, data)::data_.get()=", data_.get()));
}
#endif
}
/// The batch size accessor
/// @return the size of tensor batch represented by `*this`
size_t nbatch() const { return this->nbatch_; }
/// Access one tensor of the batch without copying
/// @param[in] idx the batch index
/// @pre `idx < this->nbatch()`
/// @return (plain, i.e. nbatch=1) Tensor representing element \p idx of
///         the batch; it shares ownership of `*this`'s data
Tensor batch(size_t idx) const {
  TA_ASSERT(idx < this->nbatch());
  value_type* const first = this->data_.get() + idx * this->size();
  // aliasing constructor: shares ownership with data_ but points at `first`
  return Tensor(this->range(), 1,
                std::shared_ptr<value_type[]>(this->data_, first));
}
/// Reinterpret the data of `*this` with another range and batch size
/// @param[in] range the Range of the result
/// @param[in] nbatch the number of batches of the result
/// @pre the total element count is unchanged
/// @return Tensor object sharing `this->data()` but using @p range and
///         @p nbatch
auto reshape(const range_type& range, size_t nbatch = 1) const {
  // the new shape must cover exactly the same number of elements
  TA_ASSERT(range.volume() * nbatch ==
            this->range().volume() * this->nbatch());
  return Tensor(range, nbatch, this->data_);
}
/// @return a deep copy of `*this`; a null tensor with a non-null (hence
///         zero-volume) range yields a tensor with the same range
Tensor clone() const& {
  Tensor result;
  if (!data_) {
    if (range_) {
      // corner case: data_ = null implies range_.volume() == 0
      TA_ASSERT(range_.volume() == 0);
      result = Tensor(range_);
    }
    return result;
  }
  // nested TA tensors and arena tensors are copied via the arena kernel
  constexpr bool use_arena = (detail::is_tensor_of_tensor_v<Tensor> &&
                              detail::is_ta_tensor_v<value_type>) ||
                             is_arena_tensor_v<value_type>;
  if constexpr (use_arena) {
    auto fill = [](typename value_type::value_type* dst,
                   const typename value_type::value_type* src,
                   std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];
    };
    result = detail::arena_trivial_unary<Tensor>(*this, fill);
  } else {
    result = detail::tensor_op<Tensor>(
        [](const numeric_type value) -> numeric_type { return value; },
        *this);
  }
  return result;
}
/// Cloning an rvalue hands over its contents instead of copying them
/// @return a Tensor move-constructed from `*this`
/// @post `*this` is in a moved-from state
Tensor clone() && { return Tensor(std::move(*this)); }
/// Assign from another tensor-like object
/// \tparam T1 a tensor type
/// \param other the source; its range is cloned into fresh storage and its
///        elements are assigned (not cloned) element by element
/// \return `*this`
template <typename T1,
          typename std::enable_if<is_tensor<T1>::value>::type* = nullptr>
Tensor& operator=(const T1& other) {
  // The elements are *assigned* below, so they must be live objects: request
  // default construction. This is a no-op for trivially default-constructible
  // T (tensors of scalars), but avoids assigning into never-constructed
  // objects (UB) for non-trivial T such as inner tensors.
  *this = Tensor(detail::clone_range(other), 1, default_construct{true});
  detail::inplace_tensor_op(
      [](reference MADNESS_RESTRICT tr,
         typename T1::const_reference MADNESS_RESTRICT t1) { tr = t1; },
      *this, other);
  return *this;
}
/// copy assignment operator
/// \param[in] other an object to copy data from
/// \post `*this` is a shallow copy of \p other ,
/// i.e. `*this == other && this->data()==other.data()`
/// \return `*this`
/// \note the previously held data is released only if no other shallow copy
/// still shares it
Tensor& operator=(const Tensor& other) {
#ifdef TA_TENSOR_MEM_TRACE
// drop the tracer entry for the data this object is about to let go of
if (nbytes() >= trace_if_larger_than_) {
ptr_registry()->erase(
this,
make_string("TA::Tensor::operator=(const Tensor&)::data_.get()=",
data_.get()));
}
#endif
range_ = other.range_;
nbatch_ = other.nbatch_;
data_ = other.data_;
#ifdef TA_TENSOR_MEM_TRACE
// register this object as an owner of the newly shared data
if (nbytes() >= trace_if_larger_than_) {
ptr_registry()->insert(
this,
make_string("TA::Tensor::operator=(const Tensor&)::data_.get()=",
data_.get()));
}
#endif
return *this;
}
/// move assignment operator
/// \param[in,out] other an object to move data from;
///                on return \p other is in empty (null) but not
///                necessarily default state
/// \post `other.empty()`
/// \return `*this`
Tensor& operator=(Tensor&& other) {
#ifdef TA_TENSOR_MEM_TRACE
  if (nbytes() >= trace_if_larger_than_) {
    ptr_registry()->erase(
        this, make_string("TA::Tensor::operator=(Tensor&&)::data_.get()=",
                          data_.get()));
  }
  if (other.nbytes() >= trace_if_larger_than_) {
    // describe `other`'s entry by *its* data pointer; this->data_ has not
    // been replaced yet and still refers to a different buffer
    ptr_registry()->erase(
        &other, make_string("TA::Tensor::operator=(Tensor&&)::data_.get()=",
                            other.data_.get()));
  }
#endif
  range_ = std::move(other.range_);
  nbatch_ = std::move(other.nbatch_);
  data_ = std::move(other.data_);
#ifdef TA_TENSOR_MEM_TRACE
  if (nbytes() >= trace_if_larger_than_) {
    ptr_registry()->insert(
        this, make_string("TA::Tensor::operator=(Tensor&&)::data_.get()=",
                          data_.get()));
  }
#endif
  return *this;
}
/// Tensor range object accessor
/// \return The tensor range object
const range_type& range() const { return range_; }
/// Tensor dimension size accessor
/// \return The number of elements in the tensor
ordinal_type size() const { return (this->range().volume()); }
/// \return The number of elements in the tensor by summing up the sizes of
/// the batches.
ordinal_type total_size() const { return size() * nbatch(); }
/// Tensor data size (in bytes) accessor
/// \return The number of bytes occupied by this tensor's data
/// \warning this only returns valid value if this is a tensor of scalars
std::size_t nbytes() const {
return this->range().volume() * this->nbatch_ * sizeof(T);
}
/// Const element accessor by ordinal
/// \tparam Ordinal an integer type that represents an ordinal
/// \param[in] ord an ordinal index
/// \return Const reference to the element at position \c ord .
/// \note This asserts (using TA_ASSERT) that this is not empty, that the
///       range is not rank-1 (use at_ordinal() there), that \p ord is
///       included in the range, and that `nbatch()==1`
template <typename Ordinal,
          std::enable_if_t<std::is_integral<Ordinal>::value>* = nullptr>
const_reference operator[](const Ordinal ord) const {
  TA_ASSERT(!this->empty());
  // for rank-1 tensors an integer argument is ambiguous between an index and
  // an ordinal, so require the explicit at_ordinal() there
  TA_ASSERT(this->range_.rank() != 1 &&
            "use Tensor::operator[](index) or "
            "Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
  TA_ASSERT(this->nbatch() == 1);
  TA_ASSERT(this->range_.includes_ordinal(ord));
  const_pointer const first = this->data();
  return first[ord];
}
/// Mutable element access by ordinal
/// \tparam Ordinal an integral type holding an ordinal index
/// \param[in] ord ordinal position of the element within the range
/// \return mutable reference to the element at ordinal \p ord
/// \note Asserts (via TA_ASSERT) that this is not empty, that the range's
///       rank is not 1 (use at_ordinal() for rank-1 tensors), that
///       `nbatch()==1`, and that \p ord is included in the range.
template <typename Ordinal,
          std::enable_if_t<std::is_integral<Ordinal>::value>* = nullptr>
reference operator[](const Ordinal ord) {
  TA_ASSERT(!this->empty());
  // for rank-1 tensors a single integer is ambiguous (ordinal vs. 1-index),
  // so that case is rejected and callers are pointed to at_ordinal()
  TA_ASSERT(this->range_.rank() != 1 &&
            "use Tensor::operator[](index) or "
            "Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
  TA_ASSERT(this->nbatch() == 1);
  TA_ASSERT(this->range_.includes_ordinal(ord));
  auto elements = this->data();
  return elements[ord];
}
/// Const element access by ordinal, usable for tensors of any rank
/// (including rank 1, which operator[](ordinal) rejects)
/// \tparam Ordinal an integral type holding an ordinal index
/// \param[in] ord ordinal position of the element within the range
/// \return const reference to the element at ordinal \p ord
/// \note Asserts (via TA_ASSERT) that this is not empty, that
///       `nbatch()==1`, and that \p ord is included in the range.
template <typename Ordinal,
          std::enable_if_t<std::is_integral<Ordinal>::value>* = nullptr>
const_reference at_ordinal(const Ordinal ord) const {
  TA_ASSERT(!this->empty());
  TA_ASSERT(this->nbatch() == 1);
  TA_ASSERT(this->range_.includes_ordinal(ord));
  const auto elements = this->data();
  return elements[ord];
}
/// Mutable element access by ordinal, usable for tensors of any rank
/// (including rank 1, which operator[](ordinal) rejects)
/// \tparam Ordinal an integral type holding an ordinal index
/// \param[in] ord ordinal position of the element within the range
/// \return mutable reference to the element at ordinal \p ord
/// \note Asserts (via TA_ASSERT) that this is not empty, that
///       `nbatch()==1`, and that \p ord is included in the range.
template <typename Ordinal,
          std::enable_if_t<std::is_integral<Ordinal>::value>* = nullptr>
reference at_ordinal(const Ordinal ord) {
  TA_ASSERT(!this->empty());
  TA_ASSERT(this->nbatch() == 1);
  TA_ASSERT(this->range_.includes_ordinal(ord));
  auto elements = this->data();
  return elements[ord];
}
/// Const element access by multi-dimensional index
/// \tparam Index a range of integers (one entry per mode)
/// \param[in] i the element's index
/// \return const reference to the element at index \p i
/// \note Asserts (via TA_ASSERT) that this is not empty, that
///       `nbatch()==1`, and that the ordinal of \p i is included in the range.
template <typename Index,
          std::enable_if_t<detail::is_integral_range_v<Index>>* = nullptr>
const_reference operator[](const Index& i) const {
  TA_ASSERT(!this->empty());
  TA_ASSERT(this->nbatch() == 1);
  // translate the index into a position within the data array
  const auto iord = this->range_.ordinal(i);
  // NOTE(review): only the computed ordinal is bounds-checked here; an index
  // that is out of bounds in one mode could still map to an in-range ordinal
  // unless Range::ordinal() validates the index itself -- confirm.
  TA_ASSERT(this->range_.includes_ordinal(iord));
  const auto elements = this->data();
  return elements[iord];
}
/// Mutable element access by multi-dimensional index
/// \tparam Index a range of integers (one entry per mode)
/// \param[in] i the element's index
/// \return mutable reference to the element at index \p i
/// \note Asserts (via TA_ASSERT) that this is not empty, that
///       `nbatch()==1`, and that the ordinal of \p i is included in the range.
template <typename Index,
          std::enable_if_t<detail::is_integral_range_v<Index>>* = nullptr>
reference operator[](const Index& i) {
  TA_ASSERT(!this->empty());
  TA_ASSERT(this->nbatch() == 1);
  // translate the index into a position within the data array
  const auto iord = this->range_.ordinal(i);
  // NOTE(review): only the computed ordinal is bounds-checked here; an index
  // that is out of bounds in one mode could still map to an in-range ordinal
  // unless Range::ordinal() validates the index itself -- confirm.
  TA_ASSERT(this->range_.includes_ordinal(iord));
  auto elements = this->data();
  return elements[iord];
}
/// Const element access by a brace-enclosed index, e.g. `t[{0, 1}]`
/// \tparam Integer an integral type
/// \param[in] i the element's index, one entry per mode
/// \return const reference to the element at index \p i
/// \note Asserts (via TA_ASSERT) that this is not empty, that
///       `nbatch()==1`, and that the ordinal of \p i is included in the range.
template <typename Integer,
          std::enable_if_t<std::is_integral_v<Integer>>* = nullptr>
const_reference operator[](const std::initializer_list<Integer>& i) const {
  TA_ASSERT(!this->empty());
  TA_ASSERT(this->nbatch() == 1);
  // translate the index into a position within the data array
  const auto iord = this->range_.ordinal(i);
  TA_ASSERT(this->range_.includes_ordinal(iord));
  const auto elements = this->data();
  return elements[iord];
}
/// Mutable element access by a brace-enclosed index, e.g. `t[{0, 1}]`
/// \tparam Integer an integral type
/// \param[in] i the element's index, one entry per mode
/// \return mutable reference to the element at index \p i
/// \note Asserts (via TA_ASSERT) that this is not empty, that
///       `nbatch()==1`, and that the ordinal of \p i is included in the range.
template <typename Integer,
          std::enable_if_t<std::is_integral_v<Integer>>* = nullptr>
reference operator[](const std::initializer_list<Integer>& i) {
  TA_ASSERT(!this->empty());
  TA_ASSERT(this->nbatch() == 1);
  // translate the index into a position within the data array
  const auto iord = this->range_.ordinal(i);
  TA_ASSERT(this->range_.includes_ordinal(iord));
  auto elements = this->data();
  return elements[iord];
}
/// Const element accessor
/// \tparam Ordinal an integer type that represents an ordinal
/// \param[in] ord an ordinal index
/// \return Const reference to the element at position \c ord .
/// \note This asserts (using TA_ASSERT) that this is not empty,
/// `nbatch()==1`, the range's rank is not 1 (use at_ordinal() for rank-1
/// tensors), and \p ord is included in the range
template <typename Ordinal,
          std::enable_if_t<std::is_integral_v<Ordinal>>* = nullptr>
const_reference operator()(const Ordinal& ord) const {
  TA_ASSERT(!this->empty());
  TA_ASSERT(this->nbatch() == 1);
  // can't distinguish between operator()(Index...) and operator()(ordinal)
  // for rank-1 tensors, thus insist on at_ordinal() if this->rank()==1
  TA_ASSERT(this->range_.rank() != 1 &&
            "use Tensor::operator()(index) or "
            "Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
  TA_ASSERT(this->range_.includes_ordinal(ord));
  return this->data()[ord];
}
/// Element accessor
/// \tparam Ordinal an integer type that represents an ordinal
/// \param[in] ord an ordinal index
/// \return Reference to the element at position \c ord .
/// \note This asserts (using TA_ASSERT) that this is not empty,
/// `nbatch()==1`, the range's rank is not 1 (use at_ordinal() for rank-1
/// tensors), and \p ord is included in the range
template <typename Ordinal,
          std::enable_if_t<std::is_integral_v<Ordinal>>* = nullptr>
reference operator()(const Ordinal& ord) {
  TA_ASSERT(!this->empty());
  TA_ASSERT(this->nbatch() == 1);
  // can't distinguish between operator()(Index...) and operator()(ordinal)
  // for rank-1 tensors, thus insist on at_ordinal() if this->rank()==1
  TA_ASSERT(this->range_.rank() != 1 &&
            "use Tensor::operator()(index) or "
            "Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
  TA_ASSERT(this->range_.includes_ordinal(ord));
  return this->data()[ord];
}
/// Const element access by multi-dimensional index (call syntax)
/// \tparam Index a range of integers (one entry per mode)
/// \param[in] i the element's index
/// \return const reference to the element at index \p i
/// \note Asserts (via TA_ASSERT) that this is not empty, that
///       `nbatch()==1`, and that the ordinal of \p i is included in the range.
template <typename Index,
          std::enable_if_t<detail::is_integral_range_v<Index>>* = nullptr>
const_reference operator()(const Index& i) const {
  TA_ASSERT(!this->empty());
  TA_ASSERT(this->nbatch() == 1);
  // translate the index into a position within the data array
  const auto iord = this->range_.ordinal(i);
  // NOTE(review): only the computed ordinal is bounds-checked here; an index
  // that is out of bounds in one mode could still map to an in-range ordinal
  // unless Range::ordinal() validates the index itself -- confirm.
  TA_ASSERT(this->range_.includes_ordinal(iord));
  const auto elements = this->data();
  return elements[iord];
}