@@ -393,11 +393,14 @@ def compare_fn(path, x, y):
393393
394394 def quantization_config (self , quant , logits_tolerance = 2e-1 , grad_tolerance = 5e-1 ):
395395 """Run forward pass and backward pass for quantized model and compare with base model."""
396+ rngs = nnx .Rngs (0 )
397+
396398 cfg = self .init_pyconfig (quantization = quant )
397- model = model_creation_utils .create_model (self .cfg , self .mesh )
398- qt_model = model_creation_utils .create_model (cfg , self .mesh )
399+ model = model_creation_utils .create_model (self .cfg , self .mesh , rngs = rngs )
400+ qt_model = model_creation_utils .create_model (cfg , self .mesh , rngs = rngs )
399401
400402 ids , decoder_segment_ids , decoder_positions = self .get_data ()
403+ '''
401404 var = model.init(
402405 {"params": self.rng, "aqt": self.rng, "dropout": self.rng},
403406 ids,
@@ -414,7 +417,8 @@ def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1)
414417 enable_dropout=False,
415418 mutable=True,
416419 )
417-
420+ '''
421+
418422 def loss_base (all_vars , inputs ):
419423 logits , _ = model .apply (
420424 all_vars ,
@@ -438,18 +442,18 @@ def loss_quant(all_vars, inputs):
438442 # Compute gradients w.r.t. both models
439443 grads_base = jax .grad (loss_base )(var , (ids , decoder_positions , decoder_segment_ids ))
440444 grads_quant = jax .grad (loss_quant )(quantized_vars , (ids , decoder_positions , decoder_segment_ids ))
441-
442- logits , _ = model .apply (
443- var ,
445+
446+ logits , _ = model ( # model .apply(
447+ # var,
444448 ids ,
445449 decoder_positions ,
446450 decoder_segment_ids ,
447451 enable_dropout = False ,
448452 rngs = {"params" : self .rng },
449453 mutable = True ,
450454 )
451- quant_logits , _ = qt_model .apply (
452- quantized_vars ,
455+ quant_logits , _ = qt_model ( # qt_model .apply(
456+ # quantized_vars,
453457 ids ,
454458 decoder_positions ,
455459 decoder_segment_ids ,
0 commit comments