Skip to content

Commit 239e912

Browse files
Fix Linting
1 parent f20b375 commit 239e912

3 files changed

Lines changed: 14 additions & 5 deletions

File tree

src/maxtext/layers/nnx_wrappers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,15 +498,17 @@ def maybe_unbox(x):
498498
paths_str = ""
499499
for path, _ in unknown_state_flat.items():
500500
paths_str += f"\n - {'/'.join(map(str, path))}"
501-
501+
502502
# Dynamically reconstruct the unknown variables
503503
curr = module
504504
for p in path[:-1]:
505505
if not hasattr(curr, p):
506506
setattr(curr, p, nnx.Module())
507507
curr = getattr(curr, p)
508508

509-
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed.")
509+
warnings.warn(
510+
f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed."
511+
)
510512

511513
nnx.update(module, new_state)
512514
_refresh_variable_trace_state(module)

src/maxtext/layers/quantizations.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,10 @@ def configure_kv_quant(config):
710710
def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs):
711711
try:
712712
from qwix._src import flax_util
713+
713714
parent = flax_util.get_current_module()
714715
from flax import nnx
716+
715717
is_nnx = isinstance(parent, nnx.Module)
716718
except Exception:
717719
is_nnx = False
@@ -720,6 +722,7 @@ def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs):
720722
attr_name = f"_qwix_fp8_gpu_{op_id}"
721723
if not hasattr(parent, attr_name):
722724
from maxtext.layers import nnx_wrappers
725+
723726
rngs = getattr(parent, "qwix_rngs", None)
724727
if rngs is None:
725728
rngs = nnx.Rngs(0)
@@ -839,10 +842,13 @@ def maybe_quantize_model(model, config):
839842
if config.pure_nnx:
840843
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
841844
import jax.numpy as jnp
845+
842846
dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
843847
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
844848
dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32)
845-
model = qwix.quantize_model(model, quantization_provider, dummy_tokens, dummy_positions, dummy_segment_ids, enable_dropout=False)
849+
model = qwix.quantize_model(
850+
model, quantization_provider, dummy_tokens, dummy_positions, dummy_segment_ids, enable_dropout=False
851+
)
846852
else:
847853
model = qwix.quantize_model(model, quantization_provider)
848854
return model

tests/unit/quantizations_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,10 @@ def loss_quant(model):
442442
# nnx.grad returns a State object which is a mapping of paths to gradients.
443443
# Flatten them to check for tolerance.
444444
from flax.nnx import traversals
445+
445446
grads_base_flat = traversals.flatten_mapping(grads_base)
446447
grads_quant_flat = traversals.flatten_mapping(grads_quant)
447-
448+
448449
# Filter for param collections to compare only parameters and not stats/buffers if any
449450
# Note: NNX grads structure might contain variables like 'kernel', 'bias'.
450451
# For simplicity we compare all matching keys.
@@ -555,7 +556,7 @@ def test_fp8_full_quantization(self):
555556
def test_fp8_gpu_quantization(self):
556557
self.quantization_config("fp8_gpu", grad_tolerance=1.5)
557558

558-
# @pytest.mark.gpu_only
559+
@pytest.mark.gpu_only
559560
@pytest.mark.external_serving
560561
def test_fp8_gpu_quantization(self):
561562
self.quantization_config("fp8_gpu", grad_tolerance=1.5, enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True)

0 commit comments

Comments
 (0)