Skip to content

Commit 135d9bc

Browse files
fix(lint): Resolve Pylint errors in train.py and quantizations_test.py
1 parent 6d8ca1b commit 135d9bc

2 files changed

Lines changed: 10 additions & 6 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,9 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
349349
is_train=True,
350350
)
351351
else:
352-
OverwriteWithGradient = nnx.variablelib.variable_type_from_name(maxtext_utils.OVERWRITE_WITH_GRADIENT, allow_register=True)
352+
OverwriteWithGradient = nnx.variablelib.variable_type_from_name(
353+
maxtext_utils.OVERWRITE_WITH_GRADIENT, allow_register=True
354+
)
353355
model_graphdef, curr_params, overwrite_vars, rest = nnx.split(state.model, nnx.Param, OverwriteWithGradient, ...)
354356
if config.parameter_memory_host_offload:
355357
# Params are kept on host (pinned_host) in in_shardings. Move only Param
@@ -379,7 +381,9 @@ def diff_wrapper(param, overwrite_vars, rest, config, data):
379381
return loss, (aux, new_overwrite_vars, new_rest)
380382

381383
grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True)
382-
(loss, (aux, new_overwrite_vars, new_rest)), (raw_grads, overwrite_grads) = grad_func(curr_params, overwrite_vars, rest, config, data)
384+
(loss, (aux, _, new_rest)), (raw_grads, overwrite_grads) = grad_func(
385+
curr_params, overwrite_vars, rest, config, data
386+
)
383387
nnx.update(state.model, new_rest)
384388
nnx.update(state.model, overwrite_grads)
385389

tests/unit/quantizations_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from aqt.jax.v2 import aqt_tensor
2323
from aqt.jax.v2.flax import aqt_flax
2424
from flax import nnx
25+
from flax.nnx import traversals
2526
import jax
2627
from jax import lax
2728
from jax import numpy as jnp
@@ -48,7 +49,7 @@ def __init__(
4849
self,
4950
quantization: quantizations.AqtQuantization,
5051
data_type: Any,
51-
rngs: nnx.Rngs,
52+
rngs: nnx.Rngs, # pylint: disable=unused-argument
5253
):
5354
self.quantization = quantization
5455
self.identity = jnp.identity(2, dtype=data_type)
@@ -441,10 +442,9 @@ def loss_quant(model):
441442

442443
# nnx.grad returns a State object which is a mapping of paths to gradients.
443444
# Flatten them to check for tolerance.
444-
from flax.nnx import traversals
445445
grads_base_flat = traversals.flatten_mapping(grads_base)
446446
grads_quant_flat = traversals.flatten_mapping(grads_quant)
447-
447+
448448
# Filter for param collections to compare only parameters and not stats/buffers if any
449449
# Note: NNX grads structure might contain variables like 'kernel', 'bias'.
450450
# For simplicity we compare all matching keys.
@@ -557,7 +557,7 @@ def test_fp8_gpu_quantization(self):
557557

558558
# @pytest.mark.gpu_only
559559
@pytest.mark.external_serving
560-
def test_fp8_gpu_quantization(self):
560+
def test_fp8_gpu_quantization_nnx(self):
561561
self.quantization_config("fp8_gpu", grad_tolerance=1.5, enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True)
562562

563563
@pytest.mark.gpu_only

0 commit comments

Comments
 (0)