Skip to content

Commit 6d8ca1b

Browse files
fix(lint): Resolve Pylint errors in quantizations.py and nnx_wrappers.py
1 parent f20b375 commit 6d8ca1b

2 files changed

Lines changed: 18 additions & 8 deletions

File tree

src/maxtext/layers/nnx_wrappers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,15 +498,17 @@ def maybe_unbox(x):
498498
paths_str = ""
499499
for path, _ in unknown_state_flat.items():
500500
paths_str += f"\n - {'/'.join(map(str, path))}"
501-
501+
502502
# Dynamically reconstruct the unknown variables
503503
curr = module
504504
for p in path[:-1]:
505505
if not hasattr(curr, p):
506506
setattr(curr, p, nnx.Module())
507507
curr = getattr(curr, p)
508508

509-
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed.")
509+
warnings.warn(
510+
f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed."
511+
)
510512

511513
nnx.update(module, new_state)
512514
_refresh_variable_trace_state(module)

src/maxtext/layers/quantizations.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,13 @@
3838
from flax.linen import fp8_ops
3939
from flax.linen import initializers as flax_initializers
4040
import flax.linen as nn
41+
from flax import nnx
42+
43+
from qwix._src import flax_util
4144

4245
from maxtext.common.common_types import DType, Config
4346
from maxtext.inference.kvcache import KVQuant
47+
from maxtext.layers import nnx_wrappers
4448

4549
# Params used to define mixed precision quantization configs
4650
DEFAULT = "__default__" # default config
@@ -708,18 +712,16 @@ def configure_kv_quant(config):
708712

709713

710714
def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs):
715+
"""Applies a Linen module within an NNX context."""
711716
try:
712-
from qwix._src import flax_util
713717
parent = flax_util.get_current_module()
714-
from flax import nnx
715718
is_nnx = isinstance(parent, nnx.Module)
716-
except Exception:
719+
except Exception: # pylint: disable=broad-exception-caught
717720
is_nnx = False
718721

719722
if is_nnx:
720723
attr_name = f"_qwix_fp8_gpu_{op_id}"
721724
if not hasattr(parent, attr_name):
722-
from maxtext.layers import nnx_wrappers
723725
rngs = getattr(parent, "qwix_rngs", None)
724726
if rngs is None:
725727
rngs = nnx.Rngs(0)
@@ -838,11 +840,17 @@ def maybe_quantize_model(model, config):
838840
if quantization_provider:
839841
if config.pure_nnx:
840842
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
841-
import jax.numpy as jnp
842843
dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
843844
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
844845
dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32)
845-
model = qwix.quantize_model(model, quantization_provider, dummy_tokens, dummy_positions, dummy_segment_ids, enable_dropout=False)
846+
model = qwix.quantize_model(
847+
model,
848+
quantization_provider,
849+
dummy_tokens,
850+
dummy_positions,
851+
dummy_segment_ids,
852+
enable_dropout=False,
853+
)
846854
else:
847855
model = qwix.quantize_model(model, quantization_provider)
848856
return model

0 commit comments

Comments
 (0)