Skip to content

Commit 06657a2

Browse files
Fix linting
1 parent c255340 commit 06657a2

2 files changed

Lines changed: 31 additions & 21 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 30 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -328,11 +328,15 @@ def __init__(
328328
else:
329329
next_boundary = self._find_next_boundary(current_idx, config.first_num_dense_layers, config.engram_layers)
330330
chunk_name = f"dense_layers_{current_idx}_{next_boundary - 1}"
331-
setattr(self, chunk_name, self._create_scanned_layers(
332-
dense_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
333-
))
331+
setattr(
332+
self,
333+
chunk_name,
334+
self._create_scanned_layers(
335+
dense_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
336+
),
337+
)
334338
current_idx = next_boundary
335-
339+
336340
# 2. Create MoE Chunks (Direct setattr, NO nnx.Dict)
337341
current_idx = config.first_num_dense_layers
338342
while current_idx < config.num_decoder_layers:
@@ -343,9 +347,13 @@ def __init__(
343347
else:
344348
next_boundary = self._find_next_boundary(current_idx, config.num_decoder_layers, config.engram_layers)
345349
chunk_name = f"moe_layers_{current_idx}_{next_boundary - 1}"
346-
setattr(self, chunk_name, self._create_scanned_layers(
347-
moe_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
348-
))
350+
setattr(
351+
self,
352+
chunk_name,
353+
self._create_scanned_layers(
354+
moe_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
355+
),
356+
)
349357
current_idx = next_boundary
350358
else:
351359
# Standard DeepSeek logic when Engrams are disabled
@@ -374,7 +382,7 @@ def __init__(
374382
self.layers_remainder = RemattedGemma3Block(
375383
config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
376384
) # pytype: disable=wrong-keyword-args
377-
elif self.is_gemma4: # <-- ADDED BLOCK
385+
elif self.is_gemma4: # <-- ADDED BLOCK
378386
attention_pattern_length = len(gemma4.GEMMA4_ATTENTION_PATTERN)
379387
scan_length = config.num_decoder_layers // attention_pattern_length
380388
num_remaining_layers = config.num_decoder_layers % attention_pattern_length
@@ -424,7 +432,7 @@ def __init__(
424432
layer_kwargs = {}
425433
if config.decoder_block == DecoderBlockType.GEMMA3:
426434
layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)}
427-
elif config.decoder_block == DecoderBlockType.GEMMA4: # <-- ADDED
435+
elif config.decoder_block == DecoderBlockType.GEMMA4: # <-- ADDED
428436
layer_kwargs = {"attention_type": gemma4.get_attention_type(layer_id=lyr)}
429437
elif config.decoder_block == DecoderBlockType.LLAMA4:
430438
layer_kwargs = {
@@ -932,16 +940,11 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices):
932940
def _apply_single_engram_layer(self, y, layer_name, *args, **kwargs):
933941
"""Applies a single, unscanned Engram layer."""
934942
layer = getattr(self, layer_name)
935-
943+
936944
decoder_input_tokens = kwargs.get("decoder_input_tokens")
937945
layer_kwargs = kwargs.get("layer_kwargs", {})
938946

939-
out = layer(
940-
y,
941-
*args,
942-
decoder_input_tokens=decoder_input_tokens,
943-
**layer_kwargs
944-
)
947+
out = layer(y, *args, decoder_input_tokens=decoder_input_tokens, **layer_kwargs)
945948
if isinstance(out, tuple):
946949
y = out[0]
947950
else:
@@ -997,7 +1000,7 @@ def _apply_interleaved_scanned_layers(self, y, layer_prefix, start_idx, end_idx,
9971000
chunk_name = f"{layer_prefix}_{current_idx}_{next_boundary - 1}"
9981001
chunk_stack = getattr(self, chunk_name)
9991002
scan_length = next_boundary - current_idx
1000-
1003+
10011004
y, chunk_stack = self._apply_layers_sequentially(
10021005
chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {})
10031006
)
@@ -1046,7 +1049,7 @@ def __call__(
10461049
# Extract the bidirectional mask locally for layer configurations
10471050
bidirectional_mask = multimodal_input.bidirectional_mask if multimodal_input is not None else None
10481051

1049-
if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4): # <-- UPDATED
1052+
if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4): # <-- UPDATED
10501053
layer_kwargs["bidirectional_mask"] = bidirectional_mask
10511054

10521055
if attention_metadata is not None:
@@ -1071,7 +1074,13 @@ def __call__(
10711074
)
10721075

10731076
y = self._apply_interleaved_scanned_layers(
1074-
y, "moe_layers", cfg.first_num_dense_layers, cfg.num_decoder_layers, cfg.engram_layers, *layer_args, **common_kwargs
1077+
y,
1078+
"moe_layers",
1079+
cfg.first_num_dense_layers,
1080+
cfg.num_decoder_layers,
1081+
cfg.engram_layers,
1082+
*layer_args,
1083+
**common_kwargs,
10751084
)
10761085
else:
10771086
y, self.dense_layers = self._apply_layers_sequentially(
@@ -1123,7 +1132,7 @@ def __call__(
11231132
previous_chunk,
11241133
page_state,
11251134
slot,
1126-
)
1135+
)
11271136
else:
11281137
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
11291138
if scan_length > 0:
@@ -1303,6 +1312,7 @@ def pure_gemma_fn(graphdef, state_in, y_in):
13031312

13041313
return y
13051314

1315+
13061316
def decoder_as_linen(
13071317
config: Config,
13081318
mesh: Mesh,

tests/unit/nnx_decoder_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -533,4 +533,4 @@ def test_different_random_seeds_produce_different_logits(self):
533533

534534

535535
if __name__ == "__main__":
536-
unittest.main()
536+
unittest.main()

0 commit comments

Comments
 (0)