Skip to content

Commit 7949c17

Browse files
Enable _same_logical_tensors in _tensor_impl
1 parent 79d40f2 commit 7949c17

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

dpctl_ext/tensor/libtensor/source/tensor_ctors.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,15 @@ PYBIND11_MODULE(_tensor_impl, m)
430430
"Determines if the memory regions indexed by each array overlap",
431431
py::arg("array1"), py::arg("array2"));
432432

433-
// auto same_logical_tensors =
434-
// [](const dpctl::tensor::usm_ndarray &x1,
435-
// const dpctl::tensor::usm_ndarray &x2) -> bool {
436-
// auto const &same_logical_tensors = SameLogicalTensors();
437-
// return same_logical_tensors(x1, x2);
438-
// };
439-
// m.def("_same_logical_tensors", same_logical_tensors,
440-
// "Determines if the memory regions indexed by each array are the
441-
// same", py::arg("array1"), py::arg("array2"));
433+
auto same_logical_tensors =
434+
[](const dpctl::tensor::usm_ndarray &x1,
435+
const dpctl::tensor::usm_ndarray &x2) -> bool {
436+
auto const &same_logical_tensors = SameLogicalTensors();
437+
return same_logical_tensors(x1, x2);
438+
};
439+
m.def("_same_logical_tensors", same_logical_tensors,
440+
"Determines if the memory regions indexed by each array are the same",
441+
py::arg("array1"), py::arg("array2"));
442442

443443
// m.def("_place", &py_place, "", py::arg("dst"), py::arg("cumsum"),
444444
// py::arg("axis_start"), py::arg("axis_end"), py::arg("rhs"),

0 commit comments

Comments
 (0)