11import functools
22
3- import ninetoothed
43from ninetoothed import Tensor
54
65from ntops .kernels .mm import BLOCK_SIZE_K , BLOCK_SIZE_M , BLOCK_SIZE_N , application
76
87
9- def arrangement (input , other , output ):
10- output_arranged = output .tile ((1 , BLOCK_SIZE_M , BLOCK_SIZE_N ))
8+ def arrangement (
9+ input , other , output , block_size_m = None , block_size_n = None , block_size_k = None
10+ ):
11+ if block_size_m is None :
12+ block_size_m = BLOCK_SIZE_M
13+
14+ if block_size_n is None :
15+ block_size_n = BLOCK_SIZE_N
16+
17+ if block_size_k is None :
18+ block_size_k = BLOCK_SIZE_K
19+
20+ output_arranged = output .tile ((1 , block_size_m , block_size_n ))
1121 output_arranged .dtype = output_arranged .dtype .squeeze (0 )
1222
13- input_arranged = input .tile ((1 , BLOCK_SIZE_M , BLOCK_SIZE_K ))
23+ input_arranged = input .tile ((1 , block_size_m , block_size_k ))
1424 input_arranged = input_arranged .tile ((1 , 1 , - 1 ))
1525 input_arranged = input_arranged .expand ((- 1 , - 1 , output_arranged .shape [- 1 ]))
1626 input_arranged .dtype = input_arranged .dtype .squeeze ((0 , 1 ))
1727 input_arranged .dtype .dtype = input_arranged .dtype .dtype .squeeze (0 )
1828
19- other_arranged = other .tile ((1 , BLOCK_SIZE_K , BLOCK_SIZE_N ))
29+ other_arranged = other .tile ((1 , block_size_k , block_size_n ))
2030 other_arranged = other_arranged .tile ((1 , - 1 , 1 ))
2131 other_arranged = other_arranged .expand ((- 1 , output_arranged .shape [- 2 ], - 1 ))
2232 other_arranged .dtype = other_arranged .dtype .squeeze ((0 , 2 ))
@@ -25,6 +35,14 @@ def arrangement(input, other, output):
2535 return input_arranged , other_arranged , output_arranged
2636
2737
28- @functools .cache
29- def make ():
30- return ninetoothed .make (arrangement , application , (Tensor (3 ), Tensor (3 ), Tensor (3 )))
38+ def premake (dtype = None , block_size_m = None , block_size_n = None , block_size_k = None ):
39+ arrangement_ = functools .partial (
40+ arrangement ,
41+ block_size_m = block_size_m ,
42+ block_size_n = block_size_n ,
43+ block_size_k = block_size_k ,
44+ )
45+
46+ tensors = (Tensor (3 , dtype = dtype ), Tensor (3 , dtype = dtype ), Tensor (3 , dtype = dtype ))
47+
48+ return arrangement_ , application , tensors
0 commit comments