Skip to content

Commit 53f3052

Browse files
Fix Linting
1 parent c255340 commit 53f3052

3 files changed

Lines changed: 31 additions & 23 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from maxtext.layers import initializers, linears, mhc, normalizations, quantizations
4242
from maxtext.layers.attentions import Attention
4343
from maxtext.layers.embeddings import Embed, PositionalEmbedding, attend_on_embedding
44-
from maxtext.layers.engram import Engram, NgramHashMapping
4544
from maxtext.layers.normalizations import RMSNorm
4645
from maxtext.layers.quantizations import AqtQuantization as Quant
4746
from maxtext.models import (
@@ -328,11 +327,15 @@ def __init__(
328327
else:
329328
next_boundary = self._find_next_boundary(current_idx, config.first_num_dense_layers, config.engram_layers)
330329
chunk_name = f"dense_layers_{current_idx}_{next_boundary - 1}"
331-
setattr(self, chunk_name, self._create_scanned_layers(
332-
dense_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
333-
))
330+
setattr(
331+
self,
332+
chunk_name,
333+
self._create_scanned_layers(
334+
dense_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
335+
),
336+
)
334337
current_idx = next_boundary
335-
338+
336339
# 2. Create MoE Chunks (Direct setattr, NO nnx.Dict)
337340
current_idx = config.first_num_dense_layers
338341
while current_idx < config.num_decoder_layers:
@@ -343,9 +346,13 @@ def __init__(
343346
else:
344347
next_boundary = self._find_next_boundary(current_idx, config.num_decoder_layers, config.engram_layers)
345348
chunk_name = f"moe_layers_{current_idx}_{next_boundary - 1}"
346-
setattr(self, chunk_name, self._create_scanned_layers(
347-
moe_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
348-
))
349+
setattr(
350+
self,
351+
chunk_name,
352+
self._create_scanned_layers(
353+
moe_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
354+
),
355+
)
349356
current_idx = next_boundary
350357
else:
351358
# Standard DeepSeek logic when Engrams are disabled
@@ -374,7 +381,7 @@ def __init__(
374381
self.layers_remainder = RemattedGemma3Block(
375382
config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
376383
) # pytype: disable=wrong-keyword-args
377-
elif self.is_gemma4: # <-- ADDED BLOCK
384+
elif self.is_gemma4: # <-- ADDED BLOCK
378385
attention_pattern_length = len(gemma4.GEMMA4_ATTENTION_PATTERN)
379386
scan_length = config.num_decoder_layers // attention_pattern_length
380387
num_remaining_layers = config.num_decoder_layers % attention_pattern_length
@@ -424,7 +431,7 @@ def __init__(
424431
layer_kwargs = {}
425432
if config.decoder_block == DecoderBlockType.GEMMA3:
426433
layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)}
427-
elif config.decoder_block == DecoderBlockType.GEMMA4: # <-- ADDED
434+
elif config.decoder_block == DecoderBlockType.GEMMA4: # <-- ADDED
428435
layer_kwargs = {"attention_type": gemma4.get_attention_type(layer_id=lyr)}
429436
elif config.decoder_block == DecoderBlockType.LLAMA4:
430437
layer_kwargs = {
@@ -932,16 +939,11 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices):
932939
def _apply_single_engram_layer(self, y, layer_name, *args, **kwargs):
933940
"""Applies a single, unscanned Engram layer."""
934941
layer = getattr(self, layer_name)
935-
942+
936943
decoder_input_tokens = kwargs.get("decoder_input_tokens")
937944
layer_kwargs = kwargs.get("layer_kwargs", {})
938945

939-
out = layer(
940-
y,
941-
*args,
942-
decoder_input_tokens=decoder_input_tokens,
943-
**layer_kwargs
944-
)
946+
out = layer(y, *args, decoder_input_tokens=decoder_input_tokens, **layer_kwargs)
945947
if isinstance(out, tuple):
946948
y = out[0]
947949
else:
@@ -997,7 +999,7 @@ def _apply_interleaved_scanned_layers(self, y, layer_prefix, start_idx, end_idx,
997999
chunk_name = f"{layer_prefix}_{current_idx}_{next_boundary - 1}"
9981000
chunk_stack = getattr(self, chunk_name)
9991001
scan_length = next_boundary - current_idx
1000-
1002+
10011003
y, chunk_stack = self._apply_layers_sequentially(
10021004
chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {})
10031005
)
@@ -1046,7 +1048,7 @@ def __call__(
10461048
# Extract the bidirectional mask locally for layer configurations
10471049
bidirectional_mask = multimodal_input.bidirectional_mask if multimodal_input is not None else None
10481050

1049-
if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4): # <-- UPDATED
1051+
if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4): # <-- UPDATED
10501052
layer_kwargs["bidirectional_mask"] = bidirectional_mask
10511053

10521054
if attention_metadata is not None:
@@ -1071,7 +1073,13 @@ def __call__(
10711073
)
10721074

10731075
y = self._apply_interleaved_scanned_layers(
1074-
y, "moe_layers", cfg.first_num_dense_layers, cfg.num_decoder_layers, cfg.engram_layers, *layer_args, **common_kwargs
1076+
y,
1077+
"moe_layers",
1078+
cfg.first_num_dense_layers,
1079+
cfg.num_decoder_layers,
1080+
cfg.engram_layers,
1081+
*layer_args,
1082+
**common_kwargs,
10751083
)
10761084
else:
10771085
y, self.dense_layers = self._apply_layers_sequentially(
@@ -1123,7 +1131,7 @@ def __call__(
11231131
previous_chunk,
11241132
page_state,
11251133
slot,
1126-
)
1134+
)
11271135
else:
11281136
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
11291137
if scan_length > 0:
@@ -1303,6 +1311,7 @@ def pure_gemma_fn(graphdef, state_in, y_in):
13031311

13041312
return y
13051313

1314+
13061315
def decoder_as_linen(
13071316
config: Config,
13081317
mesh: Mesh,

src/maxtext/layers/quantizations.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from aqt.jax.v2 import tiled_dot_general
2727
from aqt.jax.v2 import calibration
2828

29-
from maxtext.layers import nnx_wrappers
3029
import qwix
3130
from qwix._src.core import dot_general_qt
3231

tests/unit/nnx_decoder_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,4 +533,4 @@ def test_different_random_seeds_produce_different_logits(self):
533533

534534

535535
if __name__ == "__main__":
536-
unittest.main()
536+
unittest.main()

0 commit comments

Comments
 (0)