     ((10, 5, 1024), (3072, 1024), (3072,), False),
 ]
 
+# Alpha test cases: (x_shape, weight_shape, bias_shape, has_bias, alpha)
+_ALPHA_TEST_CASES_DATA = [
+    ((2, 5, 256), (512, 256), (512,), True, 2.5),
+    ((2, 5, 256), (512, 256), (512,), False, 0.5),
+    ((1, 1024), (3072, 1024), (3072,), True, 1.0),
+]
+
 # Tolerance configuration
 _TOLERANCE_MAP = {
     infinicore.float16: {"atol": 0, "rtol": 1e-2},
@@ -74,6 +81,25 @@ def parse_test_cases():
                 )
             )
 
+    # Alpha test cases
+    for x_shape, weight_shape, bias_shape, has_bias, alpha in _ALPHA_TEST_CASES_DATA:
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
+            x_spec = TensorSpec.from_tensor(x_shape, None, dtype, name="x")
+            weight_spec = TensorSpec.from_tensor(weight_shape, None, dtype, name="weight")
+            bias_spec = TensorSpec.from_tensor(bias_shape, None, dtype, name="bias")
+
+            test_cases.append(
+                TestCase(
+                    inputs=[x_spec, weight_spec, bias_spec],
+                    kwargs={"has_bias": has_bias, "alpha": alpha},
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tolerance,
+                    description=f"nn.Linear - ALPHA={alpha}",
+                )
+            )
+
     return test_cases
 
 
@@ -123,7 +149,7 @@ def __init__(self):
     def get_test_cases(self):
         return parse_test_cases()
 
-    def torch_operator(self, x, weight, bias, has_bias):
+    def torch_operator(self, x, weight, bias, has_bias, alpha=None):
         """PyTorch nn.Linear implementation"""
         out_features, in_features = weight.shape
         params_dict = {"l.weight": weight}
@@ -141,9 +167,13 @@ def torch_operator(self, x, weight, bias, has_bias):
 
         with torch.no_grad():
             y = model(x)
+            if alpha is not None:
+                # alpha scales only the matmul, not the bias: alpha * (x @ W^T) + b
+                y_matmul = torch.nn.functional.linear(x, weight)
+                y = alpha * y_matmul + (bias if has_bias else 0)
         return y
 
-    def infinicore_operator(self, x, weight, bias, has_bias):
+    def infinicore_operator(self, x, weight, bias, has_bias, alpha=None):
         """InfiniCore nn.Linear implementation"""
 
         out_features, in_features = weight.shape
@@ -158,6 +188,8 @@ def infinicore_operator(self, x, weight, bias, has_bias):
             device=weight.device,
             dtype=weight.dtype,
         )
+        if alpha is not None:
+            model.l.alpha = alpha
         model.load_state_dict(params_dict)
 
         y = model(x)
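
For reference, a minimal standalone sketch (plain PyTorch, shapes borrowed from one of the new alpha cases; not code from this commit) of the semantics the added test cases and the updated torch_operator assume: alpha scales only the matmul term, and the bias, when present, is added unscaled.

import torch

# Hypothetical shapes mirroring the (2, 5, 256) / (512, 256) alpha test case above.
x = torch.randn(2, 5, 256)
weight = torch.randn(512, 256)
bias = torch.randn(512)
alpha = 2.5

# Reference form used by torch_operator: alpha * (x @ W^T) + b
reference = alpha * torch.nn.functional.linear(x, weight) + bias

# Explicit form showing that only the matmul is scaled, not the bias.
explicit = alpha * (x @ weight.t()) + bias

print(torch.allclose(reference, explicit))  # True (up to float32 rounding)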