22
33import ninetoothed
44import ninetoothed .language as ntl
5- import torch
65from ninetoothed import Tensor
76
8- from ntops import element_wise
7+ from ntops . kernels . element_wise import arrangement
98
109
1110def default_application (input , other , output ):
@@ -20,26 +19,15 @@ def floor_application(input, other, output):
2019 output = ntl .floor (input / other ) # noqa: F841
2120
2221
23- def div (input , other , rounding_mode = None , output = None ):
24- if output is None :
25- output = torch .empty_like (input )
26-
27- kernel = _make (input .ndim , rounding_mode )
28-
29- kernel (input , other , output )
30-
31- return output
32-
33-
3422@functools .cache
35- def _make (ndim , rounding_mode ):
36- tensors = (Tensor (ndim ), Tensor (ndim ), Tensor (ndim ))
37-
23+ def make (ndim , rounding_mode ):
3824 if rounding_mode == "trunc" :
3925 application = trunc_application
4026 elif rounding_mode == "floor" :
4127 application = floor_application
4228 else :
4329 application = default_application
4430
45- return ninetoothed .make (element_wise .arrangement , application , tensors )
31+ tensors = (Tensor (ndim ), Tensor (ndim ), Tensor (ndim ))
32+
33+ return ninetoothed .make (arrangement , application , tensors )
0 commit comments