Skip to content

Commit e994fc0

Browse files
committed
docs: Address review feedback on lazy reducer pitfall section
- Only force evaluation of the returned expression; intermediate lazy expressions are safe as long as they do not outlive the function - Use auto for shifted/expVals/sumExp in the doc example and tests - Remove misleading "Alternatively" sections (eval on intermediates is not required; evaluation_strategy::immediate does not fix the issue) - Fix logSoftmax_eval test variant to drop unnecessary xt::eval wrappers
1 parent 420bc89 commit e994fc0

File tree

2 files changed

+16
-27
lines changed

2 files changed

+16
-27
lines changed

docs/source/pitfall.rst

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -92,32 +92,21 @@ is still an lvalue and thus captured by reference.
9292
that hold references to local variables. When the function returns, these local
9393
variables are destroyed, and the returned expression contains dangling references.
9494

95-
The fix is to use explicit container types to force evaluation:
95+
The fix is to force evaluation of only the returned expression — intermediate
96+
lazy expressions are safe as long as they do not outlive the function:
9697

9798
.. code::
9899
99100
template <typename T>
100101
xt::xtensor<T, 2> logSoftmax(const xt::xtensor<T, 2> &matrix)
101102
{
102103
xt::xtensor<T, 2> maxVals = xt::amax(matrix, {1}, xt::keep_dims);
103-
xt::xtensor<T, 2> shifted = matrix - maxVals;
104-
xt::xtensor<T, 2> expVals = xt::exp(shifted);
105-
xt::xtensor<T, 2> sumExp = xt::sum(expVals, {1}, xt::keep_dims);
106-
return shifted - xt::log(sumExp);
104+
auto shifted = matrix - maxVals;
105+
auto expVals = xt::exp(shifted);
106+
auto sumExp = xt::sum(expVals, {1}, xt::keep_dims);
107+
return xt::xtensor<T, 2>(shifted - xt::log(sumExp));
107108
}
108109
109-
Alternatively, you can use :cpp:func:`xt::eval` to force evaluation:
110-
111-
.. code::
112-
113-
auto shifted = xt::eval(matrix - maxVals);
114-
115-
Or use the immediate evaluation strategy for reducers:
116-
117-
.. code::
118-
119-
auto sumExp = xt::sum(expVals, {1}, xt::evaluation_strategy::immediate | xt::keep_dims);
120-
121110
Random numbers not consistent
122111
-----------------------------
123112

test/test_xmath.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -976,26 +976,26 @@ namespace xt
976976
TEST(xmath, issue_2871_intermediate_result_handling)
977977
{
978978
// This test verifies the correct pattern for using reducers with
979-
// intermediate results. Using 'auto' with lazy expressions can lead
980-
// to dangling references when the function returns.
979+
// intermediate results. Returning a lazy expression from a function can lead
980+
// to dangling references — only the returned expression must be evaluated.
981981

982-
// The CORRECT way: use explicit container types for intermediate results
982+
// The CORRECT way: use auto for intermediates, force evaluation only at return
983983
auto logSoftmax_correct = [](const xt::xtensor<double, 2>& matrix)
984984
{
985985
xt::xtensor<double, 2> maxVals = xt::amax(matrix, {1}, xt::keep_dims);
986-
xt::xtensor<double, 2> shifted = matrix - maxVals;
987-
xt::xtensor<double, 2> expVals = xt::exp(shifted);
988-
xt::xtensor<double, 2> sumExp = xt::sum(expVals, {1}, xt::keep_dims);
986+
auto shifted = matrix - maxVals;
987+
auto expVals = xt::exp(shifted);
988+
auto sumExp = xt::sum(expVals, {1}, xt::keep_dims);
989989
return xt::xtensor<double, 2>(shifted - xt::log(sumExp));
990990
};
991991

992-
// Alternative CORRECT way: use xt::eval for intermediate results
992+
// Alternative CORRECT way: use xt::eval on the reducer result
993993
auto logSoftmax_eval = [](const xt::xtensor<double, 2>& matrix)
994994
{
995995
auto maxVals = xt::eval(xt::amax(matrix, {1}, xt::keep_dims));
996-
auto shifted = xt::eval(matrix - maxVals);
997-
auto expVals = xt::eval(xt::exp(shifted));
998-
auto sumExp = xt::eval(xt::sum(expVals, {1}, xt::keep_dims));
996+
auto shifted = matrix - maxVals;
997+
auto expVals = xt::exp(shifted);
998+
auto sumExp = xt::sum(expVals, {1}, xt::keep_dims);
999999
return xt::xtensor<double, 2>(shifted - xt::log(sumExp));
10001000
};
10011001

0 commit comments

Comments
 (0)