Skip to content

Commit f4674bb

Browse files
Revert fix for fp8
1 parent b7d6dc0 commit f4674bb

2 files changed

Lines changed: 2 additions & 17 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -545,14 +545,8 @@ def pure_layer_fn(state_in, y_in):
545545
out = merged_layer(y_in, **kwargs)
546546
return out, nnx.state(merged_layer)
547547

548-
# Linen FP8 ops keep amax_history in mutable Linen scope; jax.checkpoint
549-
# re-traces and hits UnexpectedTracerError. Skip remat for FP8.
550-
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
551-
if uses_linen_fp8_mutable_state:
552-
out, new_state = pure_layer_fn(state, y)
553-
else:
554-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
555-
out, new_state = checkpointed_fn(state, y)
548+
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
549+
out, new_state = checkpointed_fn(state, y)
556550
nnx.update(layer, new_state)
557551

558552
return out

src/maxtext/layers/quantizations.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -812,15 +812,6 @@ def maybe_quantize_model(model, config):
812812
if config.use_qwix_quantization and not config.use_batch_split_schedule:
813813
quantization_provider = get_qt_provider(config)
814814
if quantization_provider:
815-
if config.pure_nnx:
816-
# qwix.quantize_model traces NNX modules to locate quant points, so it
817-
# requires example model inputs (Linen modules are traced lazily and
818-
# take none). Feed dummy decoder tokens/positions of the train shape.
819-
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
820-
dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
821-
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
822-
model = qwix.quantize_model(model, quantization_provider, dummy_tokens, dummy_positions)
823-
else:
824815
model = qwix.quantize_model(model, quantization_provider)
825816
return model
826817

0 commit comments

Comments (0)