Skip to content

Commit 15a6613

Browse files
committed
add tentative numpy actx support
1 parent 376e131 commit 15a6613

1 file changed

Lines changed: 53 additions & 1 deletion

File tree

arraycontext/impl/numpy/__init__.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,15 @@
3333
THE SOFTWARE.
3434
"""
3535

36+
from dataclasses import dataclass
37+
from functools import cached_property
3638
from typing import TYPE_CHECKING, Any, cast, overload
3739

3840
import numpy as np
3941
from typing_extensions import override
4042

4143
import loopy as lp
44+
from pytools.tag import normalize_tags
4245

4346
from arraycontext.container.traversal import (
4447
rec_map_array_container as rec_map_array_container,
@@ -47,10 +50,12 @@
4750
)
4851
from arraycontext.context import (
4952
ArrayContext,
53+
CSRMatrix as _BaseCSRMatrix,
5054
UntransformedCodeWarning,
5155
)
5256
from arraycontext.typing import (
5357
Array,
58+
ArrayOrContainer,
5459
ArrayOrContainerOrScalar,
5560
ArrayOrContainerOrScalarT,
5661
ContainerOrScalarT,
@@ -60,12 +65,17 @@
6065

6166

6267
if TYPE_CHECKING:
68+
import scipy.sparse
69+
6370
from pymbolic import Scalar
64-
from pytools.tag import ToTagSetConvertible
71+
from pytools.tag import Tag, ToTagSetConvertible
6572

6673
from arraycontext.typing import ArrayContainerT
6774

6875

76+
_EMPTY_TAG_SET: frozenset[Tag] = frozenset()
77+
78+
6979
class NumpyNonObjectArrayMetaclass(type):
7080
@override
7181
def __instancecheck__(cls, instance: object) -> bool:
@@ -76,6 +86,29 @@ class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass):
7686
pass
7787

7888

89+
@dataclass(frozen=True, eq=False, repr=False)
90+
class CSRMatrix(_BaseCSRMatrix):
91+
@cached_property
92+
def _np_matrix(self) -> scipy.sparse.csr_matrix:
93+
assert isinstance(self.elem_values, np.ndarray)
94+
assert isinstance(self.elem_col_indices, np.ndarray)
95+
assert isinstance(self.row_starts, np.ndarray)
96+
# FIXME: Not sure if the scipy dependency is OK or if it should just use the
97+
# call_loopy fallback? Currently getting errors with the loopy version:
98+
# loopy.diagnostic.LoopyError: One of the kernels in the program has
99+
# been preprocessed, cannot modify target now.
100+
from scipy.sparse import csr_matrix
101+
return csr_matrix(
102+
(self.elem_values, self.elem_col_indices, self.row_starts),
103+
shape=self.shape)
104+
105+
@override
106+
def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
107+
return cast(
108+
"ArrayOrContainer",
109+
rec_map_container(lambda ary: self._np_matrix @ ary, other))
110+
111+
79112
class NumpyArrayContext(ArrayContext):
80113
"""
81114
A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays.
@@ -199,6 +232,25 @@ def tag_axis(self,
199232
def einsum(self, spec, *args, arg_names=None, tagged=()):
200233
return np.einsum(spec, *args, optimize="optimal")
201234

235+
# FIXME: Not sure what type annotations to use for shape
236+
@override
237+
def make_csr_matrix(
238+
self,
239+
shape,
240+
elem_values: Array,
241+
elem_col_indices: Array,
242+
row_starts: Array,
243+
*,
244+
tags: ToTagSetConvertible = _EMPTY_TAG_SET,
245+
axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
246+
tags = normalize_tags(tags)
247+
if axes is None:
248+
axes = (frozenset(), frozenset())
249+
return CSRMatrix(
250+
shape, elem_values, elem_col_indices, row_starts,
251+
tags=tags, axes=axes,
252+
_actx=self)
253+
202254
@property
203255
def permits_inplace_modification(self):
204256
return True

0 commit comments

Comments
 (0)