@@ -135,37 +135,36 @@ def test_shard_linear(self):
         self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
         self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol))
 
-        # And their quant versions (QuantizedMatmul is not supported on CUDA)
-        if not mx.cuda.is_available():
-            qlin = lin.to_quantized()
-            slin1 = shard_linear(qlin, "all-to-sharded")
-            slin2 = shard_linear(qlin, "sharded-to-all")
-            y = qlin(x)
-            y1 = slin1(x)
-            y2 = slin2(x[part])
-            self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
-            self.assertTrue(mx.allclose(y[part], y1))
-
-            # Test non-affine quantization modes (mxfp8)
-            qlin_mxfp8 = lin.to_quantized(group_size=32, bits=8, mode="mxfp8")
-            self.assertEqual(qlin_mxfp8.mode, "mxfp8")
-
-            slin1_mxfp8 = shard_linear(qlin_mxfp8, "all-to-sharded")
-            slin2_mxfp8 = shard_linear(qlin_mxfp8, "sharded-to-all")
-
-            # Verify mode is propagated
-            self.assertEqual(slin1_mxfp8.mode, "mxfp8")
-            self.assertEqual(slin2_mxfp8.mode, "mxfp8")
-
-            # Verify biases parameter is not set for mxfp8
-            self.assertIsNone(slin1_mxfp8.get("biases"))
-            self.assertIsNone(slin2_mxfp8.get("biases"))
-
-            y = qlin_mxfp8(x)
-            y1 = slin1_mxfp8(x)
-            y2 = slin2_mxfp8(x[part])
-            self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
-            self.assertTrue(mx.allclose(y[part], y1))
+        # And their quant versions
+        qlin = lin.to_quantized()
+        slin1 = shard_linear(qlin, "all-to-sharded")
+        slin2 = shard_linear(qlin, "sharded-to-all")
+        y = qlin(x)
+        y1 = slin1(x)
+        y2 = slin2(x[part])
+        self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
+        self.assertTrue(mx.allclose(y[part], y1))
+
+        # Test non-affine quantization modes (mxfp8)
+        qlin_mxfp8 = lin.to_quantized(group_size=32, bits=8, mode="mxfp8")
+        self.assertEqual(qlin_mxfp8.mode, "mxfp8")
+
+        slin1_mxfp8 = shard_linear(qlin_mxfp8, "all-to-sharded")
+        slin2_mxfp8 = shard_linear(qlin_mxfp8, "sharded-to-all")
+
+        # Verify mode is propagated
+        self.assertEqual(slin1_mxfp8.mode, "mxfp8")
+        self.assertEqual(slin2_mxfp8.mode, "mxfp8")
+
+        # Verify biases parameter is not set for mxfp8
+        self.assertIsNone(slin1_mxfp8.get("biases"))
+        self.assertIsNone(slin2_mxfp8.get("biases"))
+
+        y = qlin_mxfp8(x)
+        y1 = slin1_mxfp8(x)
+        y2 = slin2_mxfp8(x[part])
+        self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
+        self.assertTrue(mx.allclose(y[part], y1))
 
         # Check the backward works as expected
         def dummy_loss(model, x, y):
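A minimal standalone sketch of what the new mxfp8 assertions exercise; it assumes the `Linear.to_quantized` helper and `mode="mxfp8"` keyword exactly as used in the test above, and leaves out the `shard_linear` half, which requires a distributed group:

```python
import mlx.nn as nn

lin = nn.Linear(256, 256)

# Default affine quantization keeps per-group scales and biases.
qlin = lin.to_quantized()
print(qlin.get("biases") is not None)  # True for the affine default

# mxfp8 is a non-affine mode: each group of 32 values shares a single
# exponent-style scale, so no "biases" parameter is created.
qlin_mxfp8 = lin.to_quantized(group_size=32, bits=8, mode="mxfp8")
print(qlin_mxfp8.mode)           # "mxfp8"
print(qlin_mxfp8.get("biases"))  # None, matching the assertIsNone checks
```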