@@ -51,31 +51,33 @@ def create_mixed_precision_metadata(
5151 acc_dtype : torch .dtype | None = None ,
5252 meta_dtype : torch .dtype | None = None ,
5353 scaled_activations : bool = False ,
54+ strict : bool = False ,
5455) -> MixedPrecisionMatmulMetadata :
5556 """Verify sizes and dtypes of tensors and deduce metadata parameters."""
56- expected_input_matrix_rank : Final = 2
57+ if strict :
58+ expected_input_matrix_rank : Final = 2
5759
58- if (x_rank := len (x .shape )) != expected_input_matrix_rank :
59- error_msg = f"Unexpected number of dimensions of input tensor x: { x_rank } "
60- raise ValueError (error_msg )
60+ if (x_rank := len (x .shape )) != expected_input_matrix_rank :
61+ error_msg = f"Unexpected number of dimensions of input tensor x: { x_rank } "
62+ raise ValueError (error_msg )
6163
62- if (w_q_packed_rank := len (w_q_packed .shape )) != expected_input_matrix_rank :
63- error_msg = f"Unexpected number of dimensions of input tensor w_q_packed: { w_q_packed_rank } "
64- raise ValueError (error_msg )
64+ if (w_q_packed_rank := len (w_q_packed .shape )) != expected_input_matrix_rank :
65+ error_msg = f"Unexpected number of dimensions of input tensor w_q_packed: { w_q_packed_rank } "
66+ raise ValueError (error_msg )
6567
66- if (w_s_rank := len (w_s .shape )) != expected_input_matrix_rank :
67- error_msg = f"Unexpected number of dimensions of input tensor w_s: { w_s_rank } "
68- raise ValueError (error_msg )
68+ if (w_s_rank := len (w_s .shape )) != expected_input_matrix_rank :
69+ error_msg = f"Unexpected number of dimensions of input tensor w_s: { w_s_rank } "
70+ raise ValueError (error_msg )
6971
70- if w_zp is not None and (w_zp_rank := len (w_zp .shape )) != expected_input_matrix_rank :
71- error_msg = f"Unexpected number of dimensions of input tensor w_zp: { w_zp_rank } "
72- raise ValueError (error_msg )
72+ if w_zp is not None and (w_zp_rank := len (w_zp .shape )) != expected_input_matrix_rank :
73+ error_msg = f"Unexpected number of dimensions of input tensor w_zp: { w_zp_rank } "
74+ raise ValueError (error_msg )
7375
74- # Expecting some form of 32-bit packing
75- expected_packed_dtypes : Final = [torch .uint32 , torch .int32 ]
76- if (packed_dtype := w_q_packed .dtype ) not in expected_packed_dtypes :
77- error_msg = f"Invalid datatype for packed weights: { packed_dtype } "
78- raise ValueError (error_msg )
76+ # Expecting some form of 32-bit packing
77+ expected_packed_dtypes : Final = [torch .uint32 , torch .int32 ]
78+ if (packed_dtype := w_q_packed .dtype ) not in expected_packed_dtypes :
79+ error_msg = f"Invalid datatype for packed weights: { packed_dtype } "
80+ raise ValueError (error_msg )
7981
8082 # Assume 32-bit packing
8183 packed_bitwidth : Final = 32
@@ -86,25 +88,27 @@ def create_mixed_precision_metadata(
8688
8789 unpack_mask = 2 ** weight_size_bits - 1
8890
89- # Verify shape of w_s
90- expected_scales_shape : Final = (k_dim // group_size , n_dim )
91- if (scales_shape := w_s .shape ) != expected_scales_shape :
92- error_msg = f"Invalid w_s shape (expected: { expected_scales_shape } , actual: { scales_shape } )"
93- raise ValueError (error_msg )
94-
9591 # Check if zeros is a scalar value
9692 zero_is_scalar = False if w_zp is None else w_zp .numel () == 1
97- # Expected shape of zeros tensor if 1) it is not scalar 2) it is not None
98- expected_zeros_shape : Final = (k_dim // group_size , n_dim )
99- # Verify shape of w_zp
100- if not zero_is_scalar and w_zp is not None and (zeros_shape := w_zp .shape ) != expected_zeros_shape :
101- error_msg = f"Invalid w_zp shape (expected: { expected_zeros_shape } , actual: { zeros_shape } )"
102- raise ValueError (error_msg )
103-
104- # Not supporting scaled activations right now, but we can add support later if needed. This simplifies the interface
105- if scaled_activations :
106- error_msg = "Scaled activations not yet implemented (need to deduce correct channel_scale_mode)"
107- raise NotImplementedError (error_msg )
93+
94+ if strict :
95+ # Verify shape of w_s
96+ expected_scales_shape : Final = (k_dim // group_size , n_dim )
97+ if (scales_shape := w_s .shape ) != expected_scales_shape :
98+ error_msg = f"Invalid w_s shape (expected: { expected_scales_shape } , actual: { scales_shape } )"
99+ raise ValueError (error_msg )
100+
101+ # Expected shape of zeros tensor if 1) it is not scalar 2) it is not None
102+ expected_zeros_shape : Final = (k_dim // group_size , n_dim )
103+ # Verify shape of w_zp
104+ if not zero_is_scalar and w_zp is not None and (zeros_shape := w_zp .shape ) != expected_zeros_shape :
105+ error_msg = f"Invalid w_zp shape (expected: { expected_zeros_shape } , actual: { zeros_shape } )"
106+ raise ValueError (error_msg )
107+
108+ # Not supporting scaled activations right now, but we can add support later if needed. This simplifies the interface
109+ if scaled_activations :
110+ error_msg = "Scaled activations not yet implemented (need to deduce correct channel_scale_mode)"
111+ raise NotImplementedError (error_msg )
108112
109113 return MixedPrecisionMatmulMetadata (
110114 m_dim = m_dim ,
@@ -139,6 +143,7 @@ def mixed_precision_gemm(
139143 acc_dtype : torch .dtype | None = None ,
140144 meta_dtype : torch .dtype | None = None ,
141145 scaled_activations : bool = False ,
146+ strict : bool = False ,
142147) -> torch .Tensor :
143148 """Mixed precision GEMM operation."""
144149 metadata = create_mixed_precision_metadata (
@@ -153,6 +158,7 @@ def mixed_precision_gemm(
153158 acc_dtype = acc_dtype ,
154159 meta_dtype = meta_dtype ,
155160 scaled_activations = scaled_activations ,
161+ strict = strict ,
156162 )
157163
158164 output = torch .zeros ((metadata .m_dim , metadata .n_dim ), device = x .device , dtype = metadata .output_dtype )
@@ -168,42 +174,45 @@ def create_scaled_metadata(
168174 scale_a : torch .Tensor ,
169175 scale_b : torch .Tensor ,
170176 output_dtype : torch .dtype ,
177+ strict : bool = False ,
171178) -> ScaledMatmulMetadata :
172179 """Verify sizes and dtypes of tensors and deduce metadata parameters."""
173- expected_input_matrix_rank : Final = 2
180+ if strict :
181+ expected_input_matrix_rank : Final = 2
174182
175- if (a_rank := len (a .shape )) != expected_input_matrix_rank :
176- error_msg = f"Unexpected number of dimensions of input tensor a: { a_rank } "
177- raise ValueError (error_msg )
183+ if (a_rank := len (a .shape )) != expected_input_matrix_rank :
184+ error_msg = f"Unexpected number of dimensions of input tensor a: { a_rank } "
185+ raise ValueError (error_msg )
178186
179- if (b_rank := len (b .shape )) != expected_input_matrix_rank :
180- error_msg = f"Unexpected number of dimensions of input tensor b: { b_rank } "
181- raise ValueError (error_msg )
187+ if (b_rank := len (b .shape )) != expected_input_matrix_rank :
188+ error_msg = f"Unexpected number of dimensions of input tensor b: { b_rank } "
189+ raise ValueError (error_msg )
182190
183- if a .dtype != b .dtype :
184- error_msg = f"Input tensors a and b must have the same datatype (a: { a .dtype } , b: { b .dtype } )"
185- raise ValueError (error_msg )
191+ if a .dtype != b .dtype :
192+ error_msg = f"Input tensors a and b must have the same datatype (a: { a .dtype } , b: { b .dtype } )"
193+ raise ValueError (error_msg )
186194
187195 m_dim , k_dim = a .shape
188196 _ , n_dim = b .shape
189197
190- if scale_a .numel () != 1 :
191- if (scale_a_rank := len (scale_a .shape )) != expected_input_matrix_rank :
192- error_msg = f"Unexpected number of dimensions of input tensor scale_a: { scale_a_rank } "
193- raise ValueError (error_msg )
198+ if strict :
199+ if scale_a .numel () != 1 :
200+ if (scale_a_rank := len (scale_a .shape )) != expected_input_matrix_rank :
201+ error_msg = f"Unexpected number of dimensions of input tensor scale_a: { scale_a_rank } "
202+ raise ValueError (error_msg )
194203
195- if scale_a .shape [0 ] != m_dim :
196- error_msg = f"Invalid scale_a shape (expected: ({ m_dim } ,), actual: { scale_a .shape } )"
197- raise ValueError (error_msg )
204+ if scale_a .shape [0 ] != m_dim :
205+ error_msg = f"Invalid scale_a shape (expected: ({ m_dim } ,), actual: { scale_a .shape } )"
206+ raise ValueError (error_msg )
198207
199- if scale_b .numel () != 1 :
200- if (scale_b_rank := len (scale_b .shape )) != expected_input_matrix_rank :
201- error_msg = f"Unexpected number of dimensions of input tensor scale_b: { scale_b_rank } "
202- raise ValueError (error_msg )
208+ if scale_b .numel () != 1 :
209+ if (scale_b_rank := len (scale_b .shape )) != expected_input_matrix_rank :
210+ error_msg = f"Unexpected number of dimensions of input tensor scale_b: { scale_b_rank } "
211+ raise ValueError (error_msg )
203212
204- if scale_b .shape [0 ] != n_dim :
205- error_msg = f"Invalid scale_b shape (expected: ({ n_dim } ,), actual: { scale_b .shape } )"
206- raise ValueError (error_msg )
213+ if scale_b .shape [0 ] != n_dim :
214+ error_msg = f"Invalid scale_b shape (expected: ({ n_dim } ,), actual: { scale_b .shape } )"
215+ raise ValueError (error_msg )
207216
208217 return ScaledMatmulMetadata (
209218 m_dim = m_dim ,
@@ -228,9 +237,10 @@ def scaled_gemm(
228237 scale_b : torch .Tensor ,
229238 output_dtype : torch .dtype ,
230239 bias : torch .Tensor | None = None ,
240+ strict : bool = False ,
231241) -> torch .Tensor :
232242 """Scaled GEMM operation."""
233- metadata = create_scaled_metadata (a , b , scale_a , scale_b , output_dtype )
243+ metadata = create_scaled_metadata (a , b , scale_a , scale_b , output_dtype , strict = strict )
234244
235245 output = torch .zeros ((metadata .m_dim , metadata .n_dim ), device = a .device , dtype = output_dtype )
236246
0 commit comments