@@ -13,6 +13,7 @@ def arrangement(
1313 beta ,
1414 alpha ,
1515 output ,
16+ input_precision ,
1617 block_size_m = None ,
1718 block_size_n = None ,
1819 block_size_k = None ,
@@ -26,27 +27,46 @@ def arrangement(
2627 if block_size_k is None :
2728 block_size_k = mm .BLOCK_SIZE_K
2829
29- _ , _ , input_arranged = mm .arrangement (
30+ _ , _ , input_arranged , _ = mm .arrangement (
3031 x ,
3132 y ,
3233 input ,
34+ input_precision ,
3335 block_size_m = block_size_m ,
3436 block_size_n = block_size_n ,
3537 block_size_k = block_size_k ,
3638 )
3739
38- x_arranged , y_arranged , output_arranged = mm .arrangement (x , y , output )
40+ x_arranged , y_arranged , output_arranged , _ = mm .arrangement (
41+ x , y , output , input_precision
42+ )
43+
44+ input_precision_arranged = input_precision
3945
40- return input_arranged , x_arranged , y_arranged , beta , alpha , output_arranged
46+ return (
47+ input_arranged ,
48+ x_arranged ,
49+ y_arranged ,
50+ beta ,
51+ alpha ,
52+ output_arranged ,
53+ input_precision_arranged ,
54+ )
4155
4256
43- def application (input , x , y , beta , alpha , output ):
57+ def application (input , x , y , beta , alpha , output , input_precision ):
4458 mm_output = ntl .zeros (output .shape , dtype = ntl .float32 )
45- mm .application (x , y , mm_output )
59+ mm .application (x , y , mm_output , input_precision )
4660 output = beta * input + alpha * mm_output
4761
4862
49- def premake (dtype = None , block_size_m = None , block_size_n = None , block_size_k = None ):
63+ def premake (
64+ input_precision = None ,
65+ dtype = None ,
66+ block_size_m = None ,
67+ block_size_n = None ,
68+ block_size_k = None ,
69+ ):
5070 arrangement_ = functools .partial (
5171 arrangement ,
5272 block_size_m = block_size_m ,
@@ -61,6 +81,7 @@ def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None)
6181 Tensor (0 , dtype = dtype ),
6282 Tensor (0 , dtype = dtype ),
6383 Tensor (2 , dtype = dtype ),
84+ Tensor (0 , dtype = dtype , constexpr = True , value = input_precision ),
6485 )
6586
6687 return arrangement_ , application , tensors
0 commit comments