3333THE SOFTWARE.
3434"""
3535
36+ from dataclasses import dataclass
37+ from functools import cached_property
3638from typing import TYPE_CHECKING , Any , cast , overload
3739
3840import numpy as np
3941from typing_extensions import override
4042
4143import loopy as lp
44+ from pytools .tag import normalize_tags
4245
4346from arraycontext .container .traversal import (
4447 rec_map_array_container as rec_map_array_container ,
4750)
4851from arraycontext .context import (
4952 ArrayContext ,
53+ CSRMatrix as _BaseCSRMatrix ,
5054 UntransformedCodeWarning ,
5155)
5256from arraycontext .typing import (
5357 Array ,
58+ ArrayOrContainer ,
5459 ArrayOrContainerOrScalar ,
5560 ArrayOrContainerOrScalarT ,
5661 ContainerOrScalarT ,
6065
6166
6267if TYPE_CHECKING :
68+ import scipy .sparse
69+
6370 from pymbolic import Scalar
64- from pytools .tag import ToTagSetConvertible
71+ from pytools .tag import Tag , ToTagSetConvertible
6572
6673 from arraycontext .typing import ArrayContainerT
6774
6875
76+ _EMPTY_TAG_SET : frozenset [Tag ] = frozenset ()
77+
78+
6979class NumpyNonObjectArrayMetaclass (type ):
7080 @override
7181 def __instancecheck__ (cls , instance : object ) -> bool :
@@ -76,6 +86,29 @@ class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass):
7686 pass
7787
7888
89+ @dataclass (frozen = True , eq = False , repr = False )
90+ class CSRMatrix (_BaseCSRMatrix ):
91+ @cached_property
92+ def _np_matrix (self ) -> scipy .sparse .csr_matrix :
93+ assert isinstance (self .elem_values , np .ndarray )
94+ assert isinstance (self .elem_col_indices , np .ndarray )
95+ assert isinstance (self .row_starts , np .ndarray )
96+ # FIXME: Not sure if the scipy dependency is OK or if it should just use the
97+ # call_loopy fallback? Currently getting errors with the loopy version:
98+ # loopy.diagnostic.LoopyError: One of the kernels in the program has
99+ # been preprocessed, cannot modify target now.
100+ from scipy .sparse import csr_matrix
101+ return csr_matrix (
102+ (self .elem_values , self .elem_col_indices , self .row_starts ),
103+ shape = self .shape )
104+
105+ @override
106+ def __matmul__ (self , other : ArrayOrContainer ) -> ArrayOrContainer :
107+ return cast (
108+ "ArrayOrContainer" ,
109+ rec_map_container (lambda ary : self ._np_matrix @ ary , other ))
110+
111+
79112class NumpyArrayContext (ArrayContext ):
80113 """
81114 A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays.
@@ -199,6 +232,25 @@ def tag_axis(self,
199232 def einsum (self , spec , * args , arg_names = None , tagged = ()):
200233 return np .einsum (spec , * args , optimize = "optimal" )
201234
235+ # FIXME: Not sure what type annotations to use for shape
236+ @override
237+ def make_csr_matrix (
238+ self ,
239+ shape ,
240+ elem_values : Array ,
241+ elem_col_indices : Array ,
242+ row_starts : Array ,
243+ * ,
244+ tags : ToTagSetConvertible = _EMPTY_TAG_SET ,
245+ axes : tuple [ToTagSetConvertible , ...] | None = None ) -> CSRMatrix :
246+ tags = normalize_tags (tags )
247+ if axes is None :
248+ axes = (frozenset (), frozenset ())
249+ return CSRMatrix (
250+ shape , elem_values , elem_col_indices , row_starts ,
251+ tags = tags , axes = axes ,
252+ _actx = self )
253+
202254 @property
203255 def permits_inplace_modification (self ):
204256 return True
0 commit comments