@@ -430,15 +430,15 @@ PYBIND11_MODULE(_tensor_impl, m)
430430 " Determines if the memory regions indexed by each array overlap" ,
431431 py::arg (" array1" ), py::arg (" array2" ));
432432
433- // auto same_logical_tensors =
434- // [](const dpctl::tensor::usm_ndarray &x1,
435- // const dpctl::tensor::usm_ndarray &x2) -> bool {
436- // auto const &same_logical_tensors = SameLogicalTensors();
437- // return same_logical_tensors(x1, x2);
438- // };
439- // m.def("_same_logical_tensors", same_logical_tensors,
440- // "Determines if the memory regions indexed by each array are the
441- // same", py::arg("array1"), py::arg("array2"));
433+ auto same_logical_tensors =
434+ [](const dpctl::tensor::usm_ndarray &x1,
435+ const dpctl::tensor::usm_ndarray &x2) -> bool {
436+ auto const &same_logical_tensors = SameLogicalTensors ();
437+ return same_logical_tensors (x1, x2);
438+ };
439+ m.def (" _same_logical_tensors" , same_logical_tensors,
440+ " Determines if the memory regions indexed by each array are the same " ,
441+ py::arg (" array1" ), py::arg (" array2" ));
442442
443443 // m.def("_place", &py_place, "", py::arg("dst"), py::arg("cumsum"),
444444 // py::arg("axis_start"), py::arg("axis_end"), py::arg("rhs"),
0 commit comments