Skip to content

Commit c70fa07

Browse files
authored
Merge pull request #556 from ValeevGroup/evaleev/fix/arena-tot-binary-null-inner
arena: handle mismatched null inner cells in trivial binary ToT ops
2 parents 75573cf + 52f5239 commit c70fa07

3 files changed

Lines changed: 222 additions & 13 deletions

File tree

src/TiledArray/tensor/arena_kernels.h

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,21 +316,44 @@ OuterTensor arena_trivial_binary(const LeftTensor& left,
316316
using inner_range_t = typename OuterTensor::value_type::range_type;
317317
TA_ASSERT(left.range().volume() == right.range().volume());
318318
TA_ASSERT(left.nbatch() == right.nbatch());
319-
auto range_fn = [&left](std::size_t ord) -> inner_range_t {
319+
// Union sparsity: a result cell is present if *either* operand cell is.
320+
// ToT arrays with the same outer shape can still differ in which inner cells
321+
// are populated within an outer tile (e.g. occ_tile_size>1 aggregates several
322+
// pairs, some screened to null). A cell present in only one operand is
323+
// combined against an implicit zero slab below -- correct for the linear ops
324+
// (add: l+0 / 0+r; subt: l-0 / 0-r) and numerically correct for mult (l*0=0,
325+
// emitted as an explicit zero tile). Without this, a lone-left cell would
326+
// read a null right slab (segfault) and a lone-right cell would be silently
327+
// dropped, losing that addend.
328+
auto range_fn = [&left, &right](std::size_t ord) -> inner_range_t {
320329
const auto& l = left.data()[ord];
321-
return l.empty() ? inner_range_t{} : l.range();
330+
if (!l.empty()) return l.range();
331+
const auto& r = right.data()[ord];
332+
return r.empty() ? inner_range_t{} : r.range();
322333
};
323334
OuterTensor result = arena_outer_init<OuterTensor>(
324335
left.range(), left.nbatch(), range_fn, alignof(elem_t),
325336
/*zero_init=*/false);
326337
const std::size_t N_cells = left.range().volume() * left.nbatch();
338+
std::vector<elem_t> zeros; // grown lazily; implicit-zero slab for lone cells
327339
for (std::size_t ord = 0; ord < N_cells; ++ord) {
328340
auto& dst = result.data()[ord];
329341
if (dst.empty()) continue;
330-
TA_ASSERT(left.data()[ord].size() == right.data()[ord].size());
331-
TA_ASSERT(left.data()[ord].size() == dst.size());
332-
fill_op(dst.data(), left.data()[ord].data(), right.data()[ord].data(),
333-
dst.size());
342+
const auto& l = left.data()[ord];
343+
const auto& r = right.data()[ord];
344+
const std::size_t n = dst.size();
345+
const bool have_l = !l.empty();
346+
const bool have_r = !r.empty();
347+
TA_ASSERT(!have_l || l.size() == n);
348+
TA_ASSERT(!have_r || r.size() == n);
349+
if (have_l && have_r) {
350+
fill_op(dst.data(), l.data(), r.data(), n);
351+
} else {
352+
if (zeros.size() < n) zeros.assign(n, elem_t{});
353+
const elem_t* l_ptr = have_l ? l.data() : zeros.data();
354+
const elem_t* r_ptr = have_r ? r.data() : zeros.data();
355+
fill_op(dst.data(), l_ptr, r_ptr, n);
356+
}
334357
}
335358
return result;
336359
}

tests/arena_tensor_kernels.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,4 +725,100 @@ BOOST_AUTO_TEST_CASE(tot_axpy_to_accumulates_scaled_operand) {
725725
}
726726
}
727727

728+
// --- mismatched null-inner-cell coverage --------------------------------
729+
// occ_tile_size>1 (and any block-sparse ToT) can produce two operands with
730+
// the same outer shape but different *inner*-cell sparsity within an outer
731+
// tile. Regression coverage for the bug where arena_trivial_binary sized the
732+
// result by the left operand only and read the right unconditionally: a cell
733+
// present in left but null in right read a null slab (segfault), and a cell
734+
// present in right but null in left was silently dropped.
735+
736+
namespace {
737+
738+
/// Build an Outer of `n_outer` cells; cell `ord` is null iff `present[ord]`
739+
/// is false, otherwise a length-`n_inner` ArenaTensor filled deterministically.
740+
Outer make_outer_sparse(std::size_t n_outer, std::size_t n_inner, double base,
741+
const std::vector<bool>& present) {
742+
TA::Range outer_r{static_cast<long>(n_outer)};
743+
auto shape_fn = [n_inner, &present](std::size_t ord) {
744+
return present[ord] ? TA::Range{static_cast<long>(n_inner)} : TA::Range();
745+
};
746+
Outer outer = TA::detail::arena_outer_init<Outer>(outer_r, 1, shape_fn);
747+
for (std::size_t ord = 0; ord < n_outer; ++ord) {
748+
Inner& inner = outer.data()[ord];
749+
if (!inner) continue;
750+
for (std::size_t i = 0; i < inner.size(); ++i)
751+
inner.data()[i] = base + ord * 100.0 + i;
752+
}
753+
return outer;
754+
}
755+
756+
// L present on {0,1,2}, R present on {1,2,4}, over 5 outer cells:
757+
// 0 = lone-left, 1&2 = both, 3 = both-null, 4 = lone-right.
758+
constexpr std::size_t kNo = 5, kNi = 4;
759+
const std::vector<bool> kLpresent{true, true, true, false, false};
760+
const std::vector<bool> kRpresent{false, true, true, false, true};
761+
762+
} // namespace
763+
764+
// Adds two operands whose inner-cell sparsity differs and checks that the
// result is the union of the two: a cell present in either operand is present
// in the sum, with the missing operand treated as zero.
BOOST_AUTO_TEST_CASE(trivial_add_mismatched_null_inners) {
  Outer L = make_outer_sparse(kNo, kNi, 1.0, kLpresent);
  Outer R = make_outer_sparse(kNo, kNi, 0.5, kRpresent);
  Outer sum = L.add(R);  // must not segfault on lone-left cell 0
  for (std::size_t ord = 0; ord < kNo; ++ord) {
    const Inner& l = L.data()[ord];
    const Inner& r = R.data()[ord];
    const Inner& d = sum.data()[ord];
    const bool hl = bool(l);
    const bool hr = bool(r);
    if (!hl && !hr) {
      BOOST_CHECK(!d);  // both null -> null result
      continue;
    }
    BOOST_REQUIRE(bool(d));
    // Pin the result extent: without this a zero-sized cell would pass
    // vacuously and an oversized one would read past the operands.
    BOOST_REQUIRE_EQUAL(static_cast<std::size_t>(d.size()), kNi);
    for (std::size_t i = 0; i < kNi; ++i) {
      const double lv = hl ? l.data()[i] : 0.0;
      const double rv = hr ? r.data()[i] : 0.0;
      BOOST_CHECK_EQUAL(d.data()[i], lv + rv);  // union: lone-right kept too
    }
  }
}
783+
784+
// Subtracts two operands whose inner-cell sparsity differs and checks the
// union result: a lone-left cell yields l, a lone-right cell yields -r.
BOOST_AUTO_TEST_CASE(trivial_subt_mismatched_null_inners) {
  Outer L = make_outer_sparse(kNo, kNi, 5.0, kLpresent);
  Outer R = make_outer_sparse(kNo, kNi, 1.0, kRpresent);
  Outer diff = L.subt(R);
  for (std::size_t ord = 0; ord < kNo; ++ord) {
    const Inner& l = L.data()[ord];
    const Inner& r = R.data()[ord];
    const Inner& d = diff.data()[ord];
    const bool hl = bool(l);
    const bool hr = bool(r);
    if (!hl && !hr) {
      BOOST_CHECK(!d);  // both null -> null result
      continue;
    }
    BOOST_REQUIRE(bool(d));
    // Pin the result extent so the element loop can neither be vacuous nor
    // index past the operands.
    BOOST_REQUIRE_EQUAL(static_cast<std::size_t>(d.size()), kNi);
    for (std::size_t i = 0; i < kNi; ++i) {
      const double lv = hl ? l.data()[i] : 0.0;
      const double rv = hr ? r.data()[i] : 0.0;
      BOOST_CHECK_EQUAL(d.data()[i], lv - rv);  // lone-right -> -r
    }
  }
}
803+
804+
// Multiplies two operands whose inner-cell sparsity differs.
//   - both present : elementwise product
//   - both null    : null result (same expectation as the add/subt tests)
//   - lone cell    : multiplies against an implicit zero; the result may be
//                    an explicit zero tile or absent -- either is accepted.
BOOST_AUTO_TEST_CASE(trivial_mult_mismatched_null_inners) {
  Outer L = make_outer_sparse(kNo, kNi, 2.0, kLpresent);
  Outer R = make_outer_sparse(kNo, kNi, 0.5, kRpresent);
  Outer prod = L.mult(R);
  for (std::size_t ord = 0; ord < kNo; ++ord) {
    const Inner& l = L.data()[ord];
    const Inner& r = R.data()[ord];
    const Inner& d = prod.data()[ord];
    const bool hl = bool(l);
    const bool hr = bool(r);
    if (hl && hr) {
      BOOST_REQUIRE(bool(d));
      BOOST_REQUIRE_EQUAL(static_cast<std::size_t>(d.size()), kNi);
      for (std::size_t i = 0; i < kNi; ++i)
        BOOST_CHECK_EQUAL(d.data()[i], l.data()[i] * r.data()[i]);
    } else if (!hl && !hr) {
      BOOST_CHECK(!d);  // both null -> null result
    } else if (bool(d)) {
      // a lone cell multiplies against an implicit zero -> a zero tile
      // (numerically equivalent to absent); tolerate either policy.
      BOOST_REQUIRE_EQUAL(static_cast<std::size_t>(d.size()), kNi);
      for (std::size_t i = 0; i < kNi; ++i)
        BOOST_CHECK_EQUAL(d.data()[i], 0.0);
    }
  }
}
823+
728824
BOOST_AUTO_TEST_SUITE_END()

tests/arena_tot_trivial.cpp

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ bool inners_share_one_slab(const outer_t& tot) {
5555
return true;
5656
}
5757

58-
}
58+
} // namespace
5959

6060
BOOST_AUTO_TEST_SUITE(arena_tot_trivial_suite, TA_UT_LABEL_SERIAL)
6161

@@ -88,8 +88,8 @@ BOOST_AUTO_TEST_CASE(add_bit_equal_and_one_slab) {
8888
for (std::size_t ord = 0; ord < L.range().volume(); ++ord) {
8989
inner_t inner((L.data() + ord)->range());
9090
for (std::size_t i = 0; i < inner.range().volume(); ++i)
91-
inner.at_ordinal(i) = (L.data() + ord)->at_ordinal(i) +
92-
(R.data() + ord)->at_ordinal(i);
91+
inner.at_ordinal(i) =
92+
(L.data() + ord)->at_ordinal(i) + (R.data() + ord)->at_ordinal(i);
9393
*(baseline.data() + ord) = std::move(inner);
9494
}
9595
BOOST_CHECK(tot_equal(arena_result, baseline));
@@ -104,8 +104,8 @@ BOOST_AUTO_TEST_CASE(subt_bit_equal_and_one_slab) {
104104
for (std::size_t ord = 0; ord < L.range().volume(); ++ord) {
105105
inner_t inner((L.data() + ord)->range());
106106
for (std::size_t i = 0; i < inner.range().volume(); ++i)
107-
inner.at_ordinal(i) = (L.data() + ord)->at_ordinal(i) -
108-
(R.data() + ord)->at_ordinal(i);
107+
inner.at_ordinal(i) =
108+
(L.data() + ord)->at_ordinal(i) - (R.data() + ord)->at_ordinal(i);
109109
*(baseline.data() + ord) = std::move(inner);
110110
}
111111
BOOST_CHECK(tot_equal(arena_result, baseline));
@@ -120,8 +120,8 @@ BOOST_AUTO_TEST_CASE(mult_elementwise_bit_equal_and_one_slab) {
120120
for (std::size_t ord = 0; ord < L.range().volume(); ++ord) {
121121
inner_t inner((L.data() + ord)->range());
122122
for (std::size_t i = 0; i < inner.range().volume(); ++i)
123-
inner.at_ordinal(i) = (L.data() + ord)->at_ordinal(i) *
124-
(R.data() + ord)->at_ordinal(i);
123+
inner.at_ordinal(i) =
124+
(L.data() + ord)->at_ordinal(i) * (R.data() + ord)->at_ordinal(i);
125125
*(baseline.data() + ord) = std::move(inner);
126126
}
127127
BOOST_CHECK(tot_equal(arena_result, baseline));
@@ -141,4 +141,94 @@ BOOST_AUTO_TEST_CASE(arena_outlives_source) {
141141
(9.0 + ord * 100.0 + i) * 2.0);
142142
}
143143

144+
// --- mismatched null-inner-cell coverage (non-arena inner) ---------------
145+
// Same kernel (arena_trivial_binary) backs Tensor<Tensor<double>>; exercise
146+
// the union-sparsity / implicit-zero path with mismatched per-cell nulls.
147+
// An unassigned outer cell is a default (empty) inner Tensor.
148+
149+
namespace {
150+
151+
/// `present[ord]==false` leaves cell `ord` a null (empty) inner tensor.
152+
outer_t make_tot_sparse(std::size_t N_outer, std::size_t n_inner, double base,
153+
const std::vector<bool>& present) {
154+
outer_t outer(TA::Range{static_cast<long>(N_outer)}, 1);
155+
for (std::size_t ord = 0; ord < N_outer; ++ord) {
156+
if (!present[ord]) continue; // leave default-constructed -> empty
157+
inner_t inner(TA::Range{static_cast<long>(n_inner)});
158+
for (std::size_t i = 0; i < n_inner; ++i)
159+
inner.at_ordinal(i) = base + ord * 100.0 + i;
160+
*(outer.data() + ord) = std::move(inner);
161+
}
162+
return outer;
163+
}
164+
165+
// 0 = lone-left, 1&2 = both, 3 = both-null, 4 = lone-right.
166+
const std::vector<bool> nz_L{true, true, true, false, false};
167+
const std::vector<bool> nz_R{false, true, true, false, true};
168+
169+
} // namespace
170+
171+
// Adds two tensor-of-tensors whose inner-cell sparsity differs and checks
// that the sum is the union of the two, with a missing operand cell treated
// as zero.
BOOST_AUTO_TEST_CASE(add_mismatched_null_inners) {
  constexpr std::size_t n_outer = 5;
  constexpr std::size_t n_inner = 4;
  outer_t L = make_tot_sparse(n_outer, n_inner, 1.0, nz_L);
  outer_t R = make_tot_sparse(n_outer, n_inner, 0.5, nz_R);
  outer_t sum = L.add(R);  // must not segfault on lone-left cell 0
  for (std::size_t ord = 0; ord < n_outer; ++ord) {
    const inner_t& l = *(L.data() + ord);
    const inner_t& r = *(R.data() + ord);
    const inner_t& d = *(sum.data() + ord);
    const bool hl = !l.empty();
    const bool hr = !r.empty();
    if (!hl && !hr) {
      BOOST_CHECK(d.empty());  // both null -> null result
      continue;
    }
    BOOST_REQUIRE(!d.empty());
    // Pin the result extent so the element loop cannot index past the
    // operands or silently cover fewer elements than expected.
    BOOST_REQUIRE_EQUAL(static_cast<std::size_t>(d.range().volume()), n_inner);
    for (std::size_t i = 0; i < n_inner; ++i) {
      const double lv = hl ? l.at_ordinal(i) : 0.0;
      const double rv = hr ? r.at_ordinal(i) : 0.0;
      BOOST_CHECK_EQUAL(d.at_ordinal(i), lv + rv);
    }
  }
}
192+
193+
// Subtracts two tensor-of-tensors whose inner-cell sparsity differs and
// checks the union result: a lone-left cell yields l, a lone-right cell -r.
BOOST_AUTO_TEST_CASE(subt_mismatched_null_inners) {
  constexpr std::size_t n_outer = 5;
  constexpr std::size_t n_inner = 4;
  outer_t L = make_tot_sparse(n_outer, n_inner, 5.0, nz_L);
  outer_t R = make_tot_sparse(n_outer, n_inner, 1.0, nz_R);
  outer_t diff = L.subt(R);
  for (std::size_t ord = 0; ord < n_outer; ++ord) {
    const inner_t& l = *(L.data() + ord);
    const inner_t& r = *(R.data() + ord);
    const inner_t& d = *(diff.data() + ord);
    const bool hl = !l.empty();
    const bool hr = !r.empty();
    if (!hl && !hr) {
      BOOST_CHECK(d.empty());  // both null -> null result
      continue;
    }
    BOOST_REQUIRE(!d.empty());
    // Pin the result extent so the element loop cannot index past the
    // operands or silently cover fewer elements than expected.
    BOOST_REQUIRE_EQUAL(static_cast<std::size_t>(d.range().volume()), n_inner);
    for (std::size_t i = 0; i < n_inner; ++i) {
      const double lv = hl ? l.at_ordinal(i) : 0.0;
      const double rv = hr ? r.at_ordinal(i) : 0.0;
      BOOST_CHECK_EQUAL(d.at_ordinal(i), lv - rv);
    }
  }
}
214+
215+
// Multiplies two tensor-of-tensors whose inner-cell sparsity differs.
//   - both present : elementwise product
//   - both null    : empty result (same expectation as the add/subt tests)
//   - lone cell    : multiplies against an implicit zero; the result may be
//                    an explicit zero tensor or empty -- either is accepted.
BOOST_AUTO_TEST_CASE(mult_mismatched_null_inners) {
  constexpr std::size_t n_outer = 5;
  constexpr std::size_t n_inner = 4;
  outer_t L = make_tot_sparse(n_outer, n_inner, 2.0, nz_L);
  outer_t R = make_tot_sparse(n_outer, n_inner, 0.5, nz_R);
  outer_t prod = L.mult(R);
  for (std::size_t ord = 0; ord < n_outer; ++ord) {
    const inner_t& l = *(L.data() + ord);
    const inner_t& r = *(R.data() + ord);
    const inner_t& d = *(prod.data() + ord);
    const bool hl = !l.empty();
    const bool hr = !r.empty();
    if (hl && hr) {
      BOOST_REQUIRE(!d.empty());
      BOOST_REQUIRE_EQUAL(static_cast<std::size_t>(d.range().volume()),
                          n_inner);
      for (std::size_t i = 0; i < n_inner; ++i)
        BOOST_CHECK_EQUAL(d.at_ordinal(i), l.at_ordinal(i) * r.at_ordinal(i));
    } else if (!hl && !hr) {
      BOOST_CHECK(d.empty());  // both null -> null result
    } else if (!d.empty()) {
      // Lone cell emitted as an explicit zero tensor: every element is 0.
      BOOST_REQUIRE_EQUAL(static_cast<std::size_t>(d.range().volume()),
                          n_inner);
      for (std::size_t i = 0; i < n_inner; ++i)
        BOOST_CHECK_EQUAL(d.at_ordinal(i), 0.0);
    }
  }
}
233+
144234
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)