Skip to content

Commit ce03e2f

Browse files
Test
1 parent 8bacbe0 commit ce03e2f

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

tests/unit/quantizations_test.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,14 @@ def compare_fn(path, x, y):
393393

394394
def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1):
395395
"""Run forward pass and backward pass for quantized model and compare with base model."""
396+
rngs = nnx.Rngs(0)
397+
396398
cfg = self.init_pyconfig(quantization=quant)
397-
model = model_creation_utils.create_model(self.cfg, self.mesh)
398-
qt_model = model_creation_utils.create_model(cfg, self.mesh)
399+
model = model_creation_utils.create_model(self.cfg, self.mesh, rngs=rngs)
400+
qt_model = model_creation_utils.create_model(cfg, self.mesh, rngs=rngs)
399401

400402
ids, decoder_segment_ids, decoder_positions = self.get_data()
403+
'''
401404
var = model.init(
402405
{"params": self.rng, "aqt": self.rng, "dropout": self.rng},
403406
ids,
@@ -414,7 +417,8 @@ def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1)
414417
enable_dropout=False,
415418
mutable=True,
416419
)
417-
420+
'''
421+
418422
def loss_base(all_vars, inputs):
419423
logits, _ = model.apply(
420424
all_vars,
@@ -438,18 +442,18 @@ def loss_quant(all_vars, inputs):
438442
# Compute gradients w.r.t. both models
439443
grads_base = jax.grad(loss_base)(var, (ids, decoder_positions, decoder_segment_ids))
440444
grads_quant = jax.grad(loss_quant)(quantized_vars, (ids, decoder_positions, decoder_segment_ids))
441-
442-
logits, _ = model.apply(
443-
var,
445+
446+
logits, _ = model( # model.apply(
447+
# var,
444448
ids,
445449
decoder_positions,
446450
decoder_segment_ids,
447451
enable_dropout=False,
448452
rngs={"params": self.rng},
449453
mutable=True,
450454
)
451-
quant_logits, _ = qt_model.apply(
452-
quantized_vars,
455+
quant_logits, _ = qt_model( # qt_model.apply(
456+
# quantized_vars,
453457
ids,
454458
decoder_positions,
455459
decoder_segment_ids,

0 commit comments

Comments
 (0)