diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 03ce5da0a..961806b15 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -24,6 +24,7 @@ short, uint8, ) +from infinicore.ntops import use_ntops from infinicore.ops.matmul import matmul from infinicore.ops.rearrange import rearrange from infinicore.tensor import ( @@ -62,6 +63,8 @@ "long", "short", "uint8", + # `ntops` integration. + "use_ntops", # Operations. "matmul", "rearrange", diff --git a/python/infinicore/ntops.py b/python/infinicore/ntops.py new file mode 100644 index 000000000..19a600aae --- /dev/null +++ b/python/infinicore/ntops.py @@ -0,0 +1,55 @@ +import sys + +import infinicore + + +def use_ntops(): + import ntops + + return _TemporaryAttributes( + (("ntops.torch.torch", infinicore),) + + tuple( + (f"infinicore.{op_name}", getattr(ntops.torch, op_name)) + for op_name in ntops.torch.__all__ + ) + ) + + +class _TemporaryAttributes: + def __init__(self, attribute_mappings): + self._attribute_mappings = attribute_mappings + + self._original_values = {} + + def __enter__(self): + for attr_path, new_value in self._attribute_mappings: + parent, attr_name = self._resolve_path(attr_path) + + try: + self._original_values[attr_path] = getattr(parent, attr_name) + except AttributeError: + pass + + setattr(parent, attr_name, new_value) + + return self + + def __exit__(self, exc_type, exc_value, traceback): + for attr_path, _ in self._attribute_mappings: + parent, attr_name = self._resolve_path(attr_path) + + if attr_path in self._original_values: + setattr(parent, attr_name, self._original_values[attr_path]) + else: + delattr(parent, attr_name) + + @staticmethod + def _resolve_path(path): + *parent_parts, attr_name = path.split(".") + + curr = sys.modules[parent_parts[0]] + + for part in parent_parts[1:]: + curr = getattr(curr, part) + + return curr, attr_name diff --git a/python/infinicore/tensor.py b/python/infinicore/tensor.py index 2df6df681..d40d8d10d 100644 --- a/python/infinicore/tensor.py +++ b/python/infinicore/tensor.py @@ -32,7 +32,7 @@ def ndim(self): return self._underlying.ndim def data_ptr(self): - return self._underlying.data_ptr + return self._underlying.data_ptr() def size(self, dim=None): if dim is None: diff --git a/src/infinicore/pybind11/tensor.hpp b/src/infinicore/pybind11/tensor.hpp index b7e50d561..36aea199c 100644 --- a/src/infinicore/pybind11/tensor.hpp +++ b/src/infinicore/pybind11/tensor.hpp @@ -17,7 +17,7 @@ inline void bind(py::module &m) { .def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); }) .def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); }) - .def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast(tensor->data()); }) + .def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast(tensor->data()); }) .def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); }) .def("stride", [](const Tensor &tensor, std::size_t dim) { return tensor->stride(dim); }) .def("numel", [](const Tensor &tensor) { return tensor->numel(); })