7878
7979.. autoclass:: ArrayContext
8080
81+ .. autoclass:: SparseMatrix
82+ .. autoclass:: CSRMatrix
83+
8184.. autofunction:: tag_axes
8285
8386.. class:: P
114117"""
115118
116119
120+ import dataclasses
117121from abc import ABC , abstractmethod
118122from collections .abc import Callable , Hashable , Mapping
119123from typing import (
120124 TYPE_CHECKING ,
121125 Any ,
122126 ParamSpec ,
123127 TypeAlias ,
128+ cast ,
124129 overload ,
125130)
126131from warnings import warn
129134
130135from pytools import memoize_method
131136
137+ from arraycontext .container .traversal import (
138+ rec_map_container ,
139+ )
140+
132141
133142if TYPE_CHECKING :
134143 import numpy as np
135144 from numpy .typing import DTypeLike
136145
137146 import loopy
138- from pytools .tag import ToTagSetConvertible
147+ from pytools .tag import Tag , ToTagSetConvertible
139148
140149 from .fake_numpy import BaseFakeNumpyNamespace
141150 from .typing import (
142151 Array ,
143152 ArrayContainerT ,
144153 ArrayOrArithContainerOrScalarT ,
154+ ArrayOrContainer ,
145155 ArrayOrContainerOrScalar ,
146156 ArrayOrContainerOrScalarT ,
157+ ArrayOrScalar ,
147158 ContainerOrScalarT ,
148159 NumpyOrContainerOrScalar ,
149160 ScalarLike ,
152163
153164P = ParamSpec ("P" )
154165
166+ _EMPTY_TAG_SET : frozenset [Tag ] = frozenset ()
167+
168+
169+ @dataclasses .dataclass (frozen = True , eq = False , repr = False )
170+ class SparseMatrix (ABC ):
171+ shape : tuple [int , int ]
172+ tags : ToTagSetConvertible = dataclasses .field (kw_only = True )
173+ axes : tuple [ToTagSetConvertible , ...] = dataclasses .field (kw_only = True )
174+ _actx : ArrayContext = dataclasses .field (kw_only = True )
175+
176+ def __matmul__ (self , other : ArrayOrContainer ) -> ArrayOrContainer :
177+ return self ._actx .sparse_matmul (self , other )
178+
179+
180+ @dataclasses .dataclass (frozen = True , eq = False , repr = False )
181+ class CSRMatrix (SparseMatrix ):
182+ elem_values : Array
183+ elem_col_indices : Array
184+ row_starts : Array
185+
155186
156187# {{{ ArrayContext
157188
@@ -169,6 +200,8 @@ class ArrayContext(ABC):
169200 .. automethod:: to_numpy
170201 .. automethod:: call_loopy
171202 .. automethod:: einsum
203+ .. automethod:: make_csr_matrix
204+ .. automethod:: sparse_matmul
172205 .. attribute:: np
173206
174207 Provides access to a namespace that serves as a work-alike to
@@ -421,6 +454,166 @@ def einsum(self,
421454 )["out" ]
422455 return self .tag (tagged , out_ary )
423456
457+ def make_csr_matrix (
458+ self ,
459+ shape : tuple [int , int ],
460+ elem_values : Array ,
461+ elem_col_indices : Array ,
462+ row_starts : Array ,
463+ * ,
464+ tags : ToTagSetConvertible = _EMPTY_TAG_SET ,
465+ axes : tuple [ToTagSetConvertible , ...] | None = None ) -> CSRMatrix :
466+ """Return a sparse matrix in compressed sparse row (CSR) format, to be used
467+ with :meth:`sparse_matmul`.
468+
469+ :arg shape: the (two-dimensional) shape of the matrix
470+ :arg elem_values: a one-dimensional array containing the values of all of the
471+ nonzero entries of the matrix, grouped by row.
472+ :arg elem_col_indices: a one-dimensional array containing the column index
473+ values corresponding to each entry in *elem_values*.
474+ :arg row_starts: a one-dimensional array of length ``nrows+1``, where each entry
475+ gives the starting index in *elem_values* and *elem_col_indices* for the
476+ given row, with the last entry being equal to ``len(elem_values)``.
477+ """
478+ if axes is None :
479+ axes = (frozenset (), frozenset ())
480+
481+ return CSRMatrix (
482+ shape , elem_values , elem_col_indices , row_starts ,
483+ tags = tags , axes = axes ,
484+ _actx = self )
485+
486+ @memoize_method
487+ def _get_csr_matmul_prg (self , out_ndim : int ) -> loopy .TranslationUnit :
488+ import loopy as lp
489+
490+ out_extra_inames = tuple (f"i{ n } " for n in range (1 , out_ndim ))
491+ out_inames = ("irow" , * out_extra_inames )
492+ out_inames_set = frozenset (out_inames )
493+
494+ out_extra_shape_comp_names = tuple (f"n{ n } " for n in range (1 , out_ndim ))
495+ out_shape_comp_names = ("nrows" , * out_extra_shape_comp_names )
496+
497+ domains : list [str ] = []
498+ domains .append (
499+ "{ [" + "," .join (out_inames ) + "] : "
500+ + " and " .join (
501+ f"0 <= { iname } < { shape_comp_name } "
502+ for iname , shape_comp_name in zip (
503+ out_inames , out_shape_comp_names , strict = True ))
504+ + " }" )
505+ domains .append (
506+ "{ [iel] : iel_lbound <= iel < iel_ubound }" )
507+
508+ temporary_variables : Mapping [str , lp .TemporaryVariable ] = {
509+ "iel_lbound" : lp .TemporaryVariable (
510+ "iel_lbound" ,
511+ shape = (),
512+ ),
513+ "iel_ubound" : lp .TemporaryVariable (
514+ "iel_ubound" ,
515+ shape = (),
516+ )}
517+
518+ from loopy .kernel .instruction import make_assignment
519+ from pymbolic import var
520+ instructions : list [lp .Assignment | lp .CallInstruction ] = [
521+ make_assignment (
522+ (var ("iel_lbound" ),),
523+ var ("row_starts" )[var ("irow" )],
524+ id = "insn0" ,
525+ within_inames = out_inames_set ),
526+ make_assignment (
527+ (var ("iel_ubound" ),),
528+ var ("row_starts" )[var ("irow" ) + 1 ],
529+ id = "insn1" ,
530+ within_inames = out_inames_set ),
531+ make_assignment (
532+ (var ("out" )[tuple (var (iname ) for iname in out_inames )],),
533+ lp .Reduction (
534+ "sum" ,
535+ (var ("iel" ),),
536+ var ("elem_values" )[var ("iel" ),]
537+ * var ("array" )[(
538+ var ("elem_col_indices" )[var ("iel" ),],
539+ * (var (iname ) for iname in out_extra_inames ))]),
540+ id = "insn2" ,
541+ within_inames = out_inames_set ,
542+ depends_on = frozenset ({"insn0" , "insn1" }))]
543+
544+ from loopy .version import MOST_RECENT_LANGUAGE_VERSION
545+
546+ from .loopy import _DEFAULT_LOOPY_OPTIONS
547+
548+ knl = lp .make_kernel (
549+ domains = domains ,
550+ instructions = instructions ,
551+ temporary_variables = temporary_variables ,
552+ kernel_data = [
553+ lp .ValueArg ("nrows" , is_input = True ),
554+ lp .ValueArg ("ncols" , is_input = True ),
555+ lp .ValueArg ("nels" , is_input = True ),
556+ * (
557+ lp .ValueArg (shape_comp_name , is_input = True )
558+ for shape_comp_name in out_extra_shape_comp_names ),
559+ lp .GlobalArg ("elem_values" , shape = (var ("nels" ),), is_input = True ),
560+ lp .GlobalArg ("elem_col_indices" , shape = (var ("nels" ),), is_input = True ),
561+ lp .GlobalArg ("row_starts" , shape = lp .auto , is_input = True ),
562+ lp .GlobalArg (
563+ "array" ,
564+ shape = (
565+ var ("ncols" ),
566+ * (
567+ var (shape_comp_name )
568+ for shape_comp_name in out_extra_shape_comp_names ),),
569+ is_input = True ),
570+ lp .GlobalArg (
571+ "out" ,
572+ shape = (
573+ var ("nrows" ),
574+ * (
575+ var (shape_comp_name )
576+ for shape_comp_name in out_extra_shape_comp_names ),),
577+ is_input = False ),
578+ ...],
579+ name = "csr_matmul_kernel" ,
580+ lang_version = MOST_RECENT_LANGUAGE_VERSION ,
581+ options = _DEFAULT_LOOPY_OPTIONS ,
582+ default_order = lp .auto ,
583+ default_offset = lp .auto ,
584+ )
585+
586+ idx_dtype = knl .default_entrypoint .index_dtype
587+
588+ return lp .add_and_infer_dtypes (
589+ knl ,
590+ {
591+ "," .join ([
592+ "ncols" , "nrows" , "nels" ,
593+ * out_extra_shape_comp_names ]): idx_dtype ,
594+ "elem_col_indices,row_starts" : idx_dtype })
595+
596+ def sparse_matmul (
597+ self , x1 : SparseMatrix , x2 : ArrayOrContainer ) -> ArrayOrContainer :
598+ """Multiply a sparse matrix by an array.
599+
600+ :arg x1: the sparse matrix.
601+ :arg x2: the array.
602+ """
603+ if isinstance (x1 , CSRMatrix ):
604+ def _matmul (ary : ArrayOrScalar ) -> ArrayOrScalar :
605+ assert self .is_array_type (ary )
606+ prg = self ._get_csr_matmul_prg (len (ary .shape ))
607+ return self .call_loopy (
608+ prg , elem_values = x1 .elem_values ,
609+ elem_col_indices = x1 .elem_col_indices ,
610+ row_starts = x1 .row_starts , array = ary )["out" ]
611+
612+ return cast ("ArrayOrContainer" , rec_map_container (_matmul , x2 ))
613+
614+ else :
615+ raise TypeError (f"unrecognized sparse matrix type '{ type (x1 ).__name__ } '" )
616+
424617 @abstractmethod
425618 def clone (self ) -> Self :
426619 """If possible, return a version of *self* that is semantically
0 commit comments