@@ -1525,6 +1525,12 @@ def __mul__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
15251525 def __rmul__ (self , other : TyArrayBase | PyScalar ) -> TyArrayBase :
15261526 return self ._apply (other , op .mul , forward = False , result_type = TyArrayNumber )
15271527
1528+ def __rmod__ (self , other ) -> TyArrayBase :
1529+ if isinstance (other , int | float ):
1530+ b , a = promote (self , other )
1531+ return a .__mod__ (b )
1532+ return NotImplemented
1533+
15281534 def __pow__ (self , other : TyArrayBase | PyScalar ) -> TyArrayBase :
15291535 return self ._apply (other , op .pow , forward = True , result_type = TyArrayNumber )
15301536
@@ -1697,15 +1703,16 @@ def __or__(self, other) -> TyArrayBase:
16971703 def __ror__ (self , other ) -> TyArrayBase :
16981704 return self ._apply_int_only (other , op .bitwise_or , forward = False )
16991705
1700- def __mod__ (self , other ) -> TyArrayInteger :
1701- return self . _apply_int_only (
1702- other , lambda a , b : op .mod (a , b , fmod = 0 ), forward = True
1703- )
1706+ def __mod__ (self , other ) -> TyArrayBase :
1707+ if isinstance ( other , type ( self )):
1708+ var = op .mod (self . _var , other . _var , fmod = 0 )
1709+ return safe_cast ( type ( self ), _var_to_tyarray ( var ) )
17041710
1705- def __rmod__ (self , other ) -> TyArrayInteger :
1706- return self ._apply_int_only (
1707- other , lambda a , b : op .mod (a , b , fmod = 0 ), forward = False
1708- )
1711+ if isinstance (other , TyArrayNumber | int | float ):
1712+ a , b = promote (self , other )
1713+ return a .__mod__ (b )
1714+
1715+ return NotImplemented
17091716
17101717 def __xor__ (self , other ) -> TyArrayBase :
17111718 return self ._apply_int_only (other , op .bitwise_xor , forward = True )
@@ -1884,33 +1891,30 @@ def __floordiv__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
18841891 return NotImplemented
18851892
18861893 def __mod__ (self , other ) -> TyArrayBase :
1887- if isinstance (other , TyArrayFloating | float ):
1894+ if isinstance (other , type ( self ) ):
18881895 # This function is complicated for two reasons:
18891896 # 1. The ONNX standard is undefined if dividend is 0, but the array-api is not.
18901897 # 2. The array-api follows the Python semantics, which are rather odd.
1891- a , b = promote (self , other )
1892- var = op .mod (a ._var , b ._var , fmod = 1 )
1898+ var = op .mod (self ._var , other ._var , fmod = 1 )
18931899 mod = safe_cast (TyArrayFloating , _var_to_tyarray (var ))
18941900 # NOTE: onnxruntime appears to have a bug where the sign
18951901 # of zeros is only preserved if they are on the
18961902 # false-branch!
18971903 # TODO: File a bug!
1898- fixed_mod = where ((b < 0 ) == (mod < 0 ), mod , mod + b )
1899- fixed_zeros = where (b > 0 , const (0.0 , mod .dtype ), const (- 0.0 , mod .dtype ))
1904+ fixed_mod = where ((other < 0 ) == (mod < 0 ), mod , mod + other )
1905+ fixed_zeros = where (
1906+ other > 0 , const (0.0 , mod .dtype ), const (- 0.0 , mod .dtype )
1907+ )
19001908 return where (
1901- safe_cast (TyArrayBool , ~ ((mod == 0.0 ) & (b != 0.0 ))),
1909+ safe_cast (TyArrayBool , ~ ((mod == 0.0 ) & (other != 0.0 ))),
19021910 fixed_mod ,
19031911 fixed_zeros ,
19041912 )
1913+ if isinstance (other , TyArrayNumber | int | float ):
1914+ a , b = promote (self , other )
1915+ return a .__mod__ (b )
19051916
1906- return super ().__mod__ (other )
1907-
1908- def __rmod__ (self , other ) -> TyArrayBase :
1909- if isinstance (other , TyArrayFloating | float ):
1910- b , a = promote (self , other )
1911- var = op .mod (a ._var , b ._var , fmod = 1 )
1912- return safe_cast (TyArrayFloating , _var_to_tyarray (var ))
1913- return super ().__mod__ (other )
1917+ return NotImplemented
19141918
19151919 def __ndx_logaddexp__ (self , x2 : TyArrayBase | int | float , / ) -> TyArrayFloating :
19161920 if isinstance (x2 , TyArrayNumber | int | float ):
0 commit comments