@@ -4651,6 +4651,7 @@ def matmul(x1: NDArray | np.ndarray, x2: NDArray, **kwargs: Any) -> NDArray | np
46514651 return result
46524652
46534653
# @profile
def tensordot(
    x1: NDArray, x2: NDArray, axes: int | tuple[Sequence[int], Sequence[int]] = 2, **kwargs: Any
) -> NDArray:
    """
    Compute the tensor contraction of two arrays along the specified axes.

    Parameters
    ----------
    x1: NDArray
        First input array.
    x2: NDArray
        Second input array. Must have the same sizes as x1 along the
        contracted (reduction) dimensions.
    axes: int | tuple[Sequence[int], Sequence[int]]
        If an integer N, contract the last N axes of x1 with the first N axes
        of x2, in order. If a two-element tuple of sequences, the first
        sequence lists the axes of x1 and the second the matching axes of x2
        to contract; negative values count from the last dimension.
    kwargs: Any
        Keyword arguments forwarded to the output array constructor. The
        special key ``fast_path`` (testing only) forces the in-memory fast
        path on or off; when absent it is chosen from the estimated working
        set size.

    Returns
    -------
    out: NDArray
        An array containing the tensor contraction whose shape consists of the non-contracted axes (dimensions) of the first array x1, followed by the non-contracted axes (dimensions) of the second array x2. The returned array must have a data type determined by Type Promotion Rules.
    """
    fast_path = kwargs.pop("fast_path", None)  # for testing purposes

    # Added this to pass array-api tests (which use internal getitem to check results)
    if isinstance(x1, np.ndarray) and isinstance(x2, np.ndarray):
        return np.tensordot(x1, x2, axes=axes)

    x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)

    if isinstance(axes, tuple):
        a_axes, b_axes = axes
        a_axes = list(a_axes)
        b_axes = list(b_axes)
        if len(a_axes) != len(b_axes):
            raise ValueError("Lengths of reduction axes for x1 and x2 must be equal!")
        # Need to track the order of b_axes; later we cycle through a_axes sorted
        # for op_chunk. a_sorted[inv_sort][b_sort] matches b_sorted since b_axes
        # pairs elementwise with a_axes.
        inv_sort = np.argsort(np.argsort(a_axes))
        b_sort = np.argsort(b_axes)
        order = inv_sort[b_sort]
        # Boolean masks of the NON-contracted axes of each operand.
        a_keep, b_keep = [True] * x1.ndim, [True] * x2.ndim
        for i, j in zip(a_axes, b_axes, strict=False):
            i = x1.ndim + i if i < 0 else i  # normalize negative axis
            j = x2.ndim + j if j < 0 else j
            a_keep[i] = False
            b_keep[j] = False
        # NOTE: the former ``a_axes = [] if a_axes == () else a_axes`` guards were
        # dead code: a_axes/b_axes are lists here, so they never compare equal to ().
    elif isinstance(axes, int):
        if axes < 0:
            raise ValueError("Integer axes argument must be nonnegative!")
        order = np.arange(axes, dtype=int)  # no reordering required
        a_axes = list(range(x1.ndim - axes, x1.ndim))  # last ``axes`` dims of x1
        b_axes = list(range(0, axes))  # first ``axes`` dims of x2
        a_keep = [i + axes < x1.ndim for i in range(x1.ndim)]
        b_keep = [i >= axes for i in range(x2.ndim)]
    else:
        raise ValueError("Axes argument must be two element tuple of sequences or an integer.")
    x1shape = np.array(x1.shape)
    x2shape = np.array(x2.shape)
    # Chunk and full extents of x1 along the contracted dimensions.
    a_chunks_red = tuple(c for i, c in enumerate(x1.chunks) if not a_keep[i])
    a_shape_red = tuple(c for i, c in enumerate(x1.shape) if not a_keep[i])

    if np.any(x1shape[a_axes] != x2shape[b_axes]):
        raise ValueError("x1 and x2 must have same shapes along reduction dimensions")

    # Result shape: kept axes of x1 followed by kept axes of x2.
    result_shape = tuple(x1shape[a_keep]) + tuple(x2shape[b_keep])
    result = blosc2.zeros(result_shape, dtype=np.result_type(x1, x2), **kwargs)

    op_chunks = [
        slice_to_chunktuple(slice(0, s, 1), c) for s, c in zip(x1shape[a_axes], a_chunks_red, strict=True)
    ]
    res_chunks = [
        slice_to_chunktuple(s, c)
        for s, c in zip([slice(0, r, 1) for r in result.shape], result.chunks, strict=True)
    ]
    a_selection = (slice(None, None, 1),) * x1.ndim
    b_selection = (slice(None, None, 1),) * x2.ndim

    # Estimated bytes needed to hold both operand slabs for one result chunk
    # when the whole reduction extent is loaded at once.
    chunk_memory = np.prod(result.chunks) * (
        np.prod(x1shape[a_axes]) * x1.dtype.itemsize + np.prod(x2shape[b_axes]) * x2.dtype.itemsize
    )
    if fast_path is None:
        # Default: take the fast path only when the working set is small enough.
        # (An explicit fast_path kwarg — used by tests — always wins.)
        fast_path = bool(chunk_memory < blosc2.MAX_FAST_PATH_SIZE)

    # adapted from numpy.tensordot
    a_keep_axes = [i for i, k in enumerate(a_keep) if k]
    b_keep_axes = [i for i, k in enumerate(b_keep) if k]
    newaxes_a = a_keep_axes + a_axes
    newaxes_b = b_axes + b_keep_axes

    def _contract(bx1: np.ndarray, bx2: np.ndarray, out_shape: tuple) -> np.ndarray:
        # Collapse each operand to 2-D (kept axes x reduced axes) and perform a
        # single matrix product; adapted from numpy.tensordot.
        newshape_a = (
            math.prod([bx1.shape[i] for i in a_keep_axes]),
            math.prod([bx1.shape[a] for a in a_axes]),
        )
        newshape_b = (
            math.prod([bx2.shape[b] for b in b_axes]),
            math.prod([bx2.shape[i] for i in b_keep_axes]),
        )
        at = bx1.transpose(newaxes_a).reshape(newshape_a)
        bt = bx2.transpose(newaxes_b).reshape(newshape_b)
        return np.dot(at, bt).reshape(out_shape)

    for rchunk in product(*res_chunks):
        res_chunk = tuple(
            slice(rc * rcs, builtins.min((rc + 1) * rcs, rshape), 1)
            for rc, rcs, rshape in zip(rchunk, result.chunks, result.shape, strict=True)
        )
        rchunk_iter = iter(res_chunk)
        a_selection = tuple(next(rchunk_iter) if a else slice(None, None, 1) for a in a_keep)
        b_selection = tuple(next(rchunk_iter) if b else slice(None, None, 1) for b in b_keep)
        # Renamed from ``res_chunks``, which shadowed the chunk-index list that
        # drives this very loop (it only worked because product() had already
        # consumed its arguments).
        res_chunk_shape = tuple(s.stop - s.start for s in res_chunk)

        if fast_path:  # just load everything along the reduction axes at once
            bx1 = x1[a_selection]
            bx2 = x2[b_selection]
            result[res_chunk] += _contract(bx1, bx2, res_chunk_shape)
        else:  # operands too big, have to go chunk-by-chunk
            for ochunk in product(*op_chunks):
                op_chunk = tuple(
                    slice(rc * rcs, builtins.min((rc + 1) * rcs, x1s), 1)
                    for rc, rcs, x1s in zip(ochunk, a_chunks_red, a_shape_red, strict=True)
                )  # use x1 chunk shape to iterate over reduction axes
                ochunk_iter = iter(op_chunk)
                a_selection = tuple(
                    next(ochunk_iter) if not a else as_ for as_, a in zip(a_selection, a_keep, strict=True)
                )
                # have to permute to match order of a_axes
                order_iter = iter(order)
                b_selection = tuple(
                    op_chunk[next(order_iter)] if not b else bs_
                    for bs_, b in zip(b_selection, b_keep, strict=True)
                )
                bx1 = x1[a_selection]
                bx2 = x2[b_selection]
                result[res_chunk] += _contract(bx1, bx2, res_chunk_shape)
    return result
47604820
47614821
0 commit comments