Skip to content

Commit 68eb7ce

Browse files
committed
NNX: QK-Clip on NNX + NNX-format checkpoint utilities
Closes the QK-Clip TODO and migrates the remaining checkpoint utilities to NNX, plus convert-on-load and serve-mode paths for AQT in MaxEngine. Linen paths preserved byte-for-byte (every NNX edit is gated on config.pure_nnx or runtime state-shape detection). QK-Clip: - qk_clip_utils.apply_qk_clip_nnx mutates state.model in place via nnx.split -> pure-dict tree_map -> nnx.replace_by_pure_dict -> nnx.update. Accepts both the production NNX intermediate shape (self_attention.attention_op.max_logits) and the synthetic-fixture shape from the existing Linen tests (self_attention.max_logits). - train.py train_step dispatches to apply_qk_clip_nnx for NNX, removing the prior TODO at the QK-Clip call site. Checkpoint utilities (NNX paths added): - standalone_checkpointer.checkpoint_loop builds an NNX init_state_fn under pure_nnx; add_entropy_to_checkpoint dispatches across Linen TrainState, NNX TrainStateNNX Module, and post-split nnx.State shapes. - generate_param_only_checkpoint: NNX init_state_fn under pure_nnx; _possibly_unroll_params_nnx slices scanned NNX layers via dict-style state mutation; _save_decode_checkpoint_nnx writes a bf16 pure-dict tree to orbax. Parallel LoRA decode flow operates on the single-nested LoRA delta tree from PR8's get_lora_abstract_state_nnx. - convert_gpt3_ckpt_from_paxml: parallel NNX state_map keystr translation (.params['params']<rest> -> .model<rest>.value, etc.). End-to-end paxml -> NNX conversion is wired but not yet validated on hardware. NNX + AQT in MaxEngine: - model_creation_utils threads quant_mode_str ("train" | "convert" | "serve") through from_config / create_model / get_nnx_create_model_fn / create_nnx_abstract_model / from_pretrained. Default "train" preserves existing behavior; "serve" propagates to configure_quantization so AQT layers don't materialize the full-precision kernel when the on-disk checkpoint already carries qrhs scale factors. - maxengine.__init__ selects the quant mode from config.checkpoint_is_quantized; _load_params_nnx drops the NotImplementedError. Two paths: pre-quantized (checkpoint_is_quantized=True) loads via quant_mode_str="serve"; full-precision + quantization=int8 loads in TRAIN mode and AQT layers quantize per-forward (same numerical result for absmax calibration). - layerwise_quantization._load_and_quantize_nnx: whole-model NNX convert path. Loads full-precision in TRAIN mode, transfers kernels into a CONVERT-mode model, runs forward to populate qrhs.frozen via the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths, saves serve-mode-shaped state. Sharding helpers / from_pretrained QTensor handling: - maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a parallel-tree of replicated NamedSharding leaves when a Variable's value is a composite pytree (AQT serve-mode QTensor with a qvalue int8 leaf and a list of bf16 scale leaves). - model_creation_utils.from_pretrained: dropped a redundant jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode AQT under Flax 0.12.6 (NamedSharding(mesh=AbstractMesh, spec=None) rejected). _build_value_target / _free_device_memory / _unwrap_for_align use Variable.get_value() instead of v[...] indexing for QTensor leaves (QTensor.__getitem__ trips on the LogicallyPartitioned wrapper around qvalue). Also widens the restore filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable type, and skips QTensor leaves in the per-axis shape-alignment dispatch (their saved shape already matches the model). Carries in-progress PR9.5 debugging fixes for the serve-mode reload path (see nnx_migration.md). Tests: - qk_clip_test: 7 new NNX cases covering attention-type guard, MLA wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold, missing-stats resilience, Linen<->NNX numeric parity. - standalone_checkpointer_nnx_test (new): 3 cases for adam mu/nu overwrite on TrainStateNNX Module shape, no mutation of state.model params, post-split nnx.State shape from setup_training_state. - generate_param_only_checkpoint_nnx_test (new): 3 cases for scanned layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA delta unroll on the single-nested NNX shape). - layerwise_quantization_nnx_test (new): 3 cases for _strip_kernels_at_quantized_paths covering quantized removal, non-quantized preservation, mixed-shape trees. - maxengine_test: replaced test_quantize_raises_for_nnx with test_quantize_passes_gate_for_nnx; added test_load_pre_quantized_nnx_passes_quant_gate and test_quantized_prefill_nnx_train_mode (real numerical verification with quantization=int8 + random params + TRAIN mode). End-to-end: convert-mode forward + qrhs.frozen extraction + serve-mode-shape save validated on TPU. Reload in serve mode is the remaining piece; details and status in nnx_migration.md PR9.5 entry.
1 parent 10bfe3f commit 68eb7ce

16 files changed

Lines changed: 1544 additions & 218 deletions

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import os
4141
import sys
4242

43+
from flax import nnx
4344
import jax
4445
from jax import random
4546
from jax.sharding import Mesh
@@ -48,11 +49,15 @@
4849
from maxtext.common import checkpointing
4950
from maxtext.common.common_types import MODEL_MODE_TRAIN
5051
from maxtext.layers import quantizations
52+
from maxtext.layers import train_state_nnx
5153
from maxtext.models.models import transformer_as_linen
5254
from maxtext.optimizers import optimizers
5355
from maxtext.utils import max_logging
5456
from maxtext.utils import max_utils
5557
from maxtext.utils import maxtext_utils
58+
from maxtext.utils import maxtext_utils_nnx
59+
from maxtext.utils import model_creation_utils
60+
from maxtext.utils import train_utils
5661
import numpy as np
5762
from psutil import Process
5863
import tensorstore as ts
@@ -87,14 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
8792
devices_array = maxtext_utils.create_device_mesh(cfg)
8893
mesh = Mesh(devices_array, cfg.mesh_axes)
8994

90-
# This conversion script reads paxml-format weights and emits a Linen-format
91-
# MaxText checkpoint (downstream uses `.params['params']`, `.opt_state.mu['params']`,
92-
# `.opt_state.nu['params']` keystr paths; the keystr_map below targets the Linen
93-
# tree shape). Use the Linen path regardless of pure_nnx.
94-
quant = quantizations.configure_quantization(cfg)
95-
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
96-
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
97-
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
95+
if cfg.pure_nnx:
96+
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
97+
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
98+
_, tx = train_utils.create_training_optimizer(cfg, model)
99+
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
100+
101+
def init_state_fn():
102+
nnx_model = _create_model_partial()
103+
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
104+
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
105+
106+
else:
107+
quant = quantizations.configure_quantization(cfg)
108+
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
109+
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
110+
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
111+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
98112

99113
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
100114
cfg.checkpoint_dir,
@@ -103,7 +117,6 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
103117
cfg.checkpoint_period,
104118
)
105119

106-
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
107120
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
108121
max_logging.log("start")
109122
max_utils.print_mem_stats("After params initialized")
@@ -188,10 +201,21 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
188201
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
189202
}
190203

191-
state_map = {
192-
".step": ("step", None),
193-
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
194-
}
204+
if cfg.pure_nnx:
205+
# NNX state-tree paths after `nnx.split(TrainStateNNX)`:
206+
# model params -> ['model']<rest>.value
207+
# adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu']<rest>.value
208+
# step -> ['optimizer']['step'].value
209+
# opt count -> ['optimizer']['opt_state']['count'].value
210+
state_map = {
211+
".optimizer.step.value": ("step", None),
212+
".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None),
213+
}
214+
else:
215+
state_map = {
216+
".step": ("step", None),
217+
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
218+
}
195219

196220
def get_layer_prefix(keystr_pax):
197221
# different path format between decoder_layer variable
@@ -203,19 +227,27 @@ def get_layer_prefix(keystr_pax):
203227
return prefix_pax_opt_state
204228

205229
for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
206-
# model variable
207-
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
208230
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
209-
# first momentum in optimizer state
210-
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
211-
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
212-
transform_fn,
213-
)
214-
# second momentum in optimizer state
215-
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
216-
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
217-
transform_fn,
218-
)
231+
if cfg.pure_nnx:
232+
state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
233+
state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = (
234+
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
235+
transform_fn,
236+
)
237+
state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = (
238+
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
239+
transform_fn,
240+
)
241+
else:
242+
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
243+
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
244+
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
245+
transform_fn,
246+
)
247+
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
248+
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
249+
transform_fn,
250+
)
219251

220252
def verify_fn(key_path, _):
221253
keystr = jax.tree_util.keystr(key_path)
@@ -267,10 +299,11 @@ def map_fn(key_path, value):
267299
max_logging.log("converted state finished")
268300
max_utils.print_mem_stats("converted state finished")
269301

270-
if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
271-
max_logging.log(f"saved a checkpoint at step {converted_state.step}")
302+
step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
303+
if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
304+
max_logging.log(f"saved a checkpoint at step {step_value}")
272305
# Upon preemption, exit when and only when all ongoing saves are complete.
273-
if checkpoint_manager.reached_preemption(converted_state.step):
306+
if checkpoint_manager.reached_preemption(step_value):
274307
checkpoint_manager.wait_until_finished()
275308
sys.exit()
276309

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,24 @@ def __init__(self, config: Any, devices: Any | None = None):
117117
# Model and Optimizer definition.
118118
quant = quantizations.configure_quantization(config)
119119
if config.pure_nnx:
120+
# `serve` only when the on-disk checkpoint already carries `qrhs.frozen`
121+
# (no full-precision kernel). For `checkpoint_is_quantized=False` with
122+
# quant enabled we stay in `train` mode and let AQT quantize per-forward
123+
# against the full-precision kernel — same numerical result as `serve`
124+
# for absmax calibration, just slower.
125+
nnx_quant_mode_str = "serve" if (quant is not None and config.checkpoint_is_quantized) else "train"
120126
# We need both PREFILL and AR abstract models because the cache vars inherit
121127
# CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and
122128
# bulk_insert searches for the substring "cache_batch" in the AR-mode names.
123129
# Calling nnx.eval_shape directly (instead of create_nnx_abstract_model) avoids
124130
# the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm".
125-
_create_model = model_creation_utils.get_nnx_create_model_fn(config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL)
131+
_create_model = model_creation_utils.get_nnx_create_model_fn(
132+
config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL, quant_mode_str=nnx_quant_mode_str
133+
)
126134
_create_model_ar = model_creation_utils.get_nnx_create_model_fn(
127-
config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE
135+
config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE, quant_mode_str=nnx_quant_mode_str
128136
)
137+
self._nnx_quant_mode_str = nnx_quant_mode_str
129138
with nn_partitioning.axis_rules(config.logical_axis_rules):
130139
abstract_model = nnx.eval_shape(_create_model)
131140
abstract_model_ar = nnx.eval_shape(_create_model_ar)
@@ -371,9 +380,15 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
371380
return params
372381

373382
def _load_params_nnx(self, params, rng):
374-
"""NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings."""
375-
if self.model.quant is not None:
376-
raise NotImplementedError("pure_nnx + quantization not yet supported. Use pure_nnx=False.")
383+
"""NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings.
384+
385+
Quantization handling:
386+
* `checkpoint_is_quantized=True`: model built in `serve` mode (no full
387+
kernel), `from_pretrained` reads `qrhs.frozen` from disk.
388+
* `checkpoint_is_quantized=False` + `quantization=...`: model built in
389+
`train` mode, full-precision kernel loaded; AQT layers quantize per
390+
forward. Same output as serve mode (absmax calibration), slower.
391+
"""
377392

378393
if params:
379394
print("Resharding given NNX params")
@@ -396,13 +411,44 @@ def _load_params_nnx(self, params, rng):
396411
max_logging.log("Loading NNX params via from_pretrained")
397412
with self._mesh:
398413
nnx_model = model_creation_utils.from_pretrained(
399-
self.config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE
414+
self.config,
415+
mesh=self._mesh,
416+
model_mode=MODEL_MODE_AUTOREGRESSIVE,
417+
quant_mode_str=self._nnx_quant_mode_str,
400418
)
401-
# Refresh graphdef from the concrete loaded model so subsequent merges line up.
402-
graphdef, params_state, _, rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...)
419+
# 4-way split keeps the loaded AQT `qrhs.frozen` leaves (and any other
420+
# non-Param/non-Cache vars) in `loaded_rest_state` so they survive into
421+
# `_nnx_rest_state`. Param-only filtering would silently drop them and
422+
# the model would run with random qrhs values.
423+
_, params_state, _, loaded_rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...)
424+
# `_prefill_jit` re-merges with `self.graphdef`, which must be the PREFILL
425+
# graphdef built in `__init__` (matching `_create_model_fn`). Don't
426+
# overwrite with the AR-mode graphdef from `from_pretrained` — the
427+
# PREFILL/AR attention ops have different cache variable shapes, and a
428+
# mismatch trips the `assert prefill_kv_cache` check inside attention_op.
429+
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
430+
concrete_model = self._create_model_fn()
431+
graphdef, _, _, rest_state = nnx.split(concrete_model, nnx.Param, nnx.Cache, ...)
432+
# Overlay loaded non-Param/non-Cache leaves (e.g. AQT qrhs.frozen) onto
433+
# the PREFILL-mode rest_state. The PREFILL concrete_model already has
434+
# placeholder qrhs vars at the right paths; we just swap in the loaded
435+
# values. Anything only in `loaded_rest_state` (e.g. AR-only RNG slots)
436+
# is ignored. We keep PREFILL rest_state as the base so RNG variables
437+
# match the PREFILL graphdef's expectations.
438+
loaded_rest_dict = loaded_rest_state.to_pure_dict()
439+
rest_dict = rest_state.to_pure_dict()
440+
def _overlay(dst, src):
441+
if isinstance(dst, dict):
442+
for k, v in dst.items():
443+
if k in src:
444+
dst[k] = _overlay(v, src[k])
445+
return dst
446+
return src if not isinstance(src, dict) else dst
447+
rest_dict = _overlay(rest_dict, loaded_rest_dict)
448+
nnx.replace_by_pure_dict(rest_state, rest_dict)
403449
self.graphdef = graphdef
404450
self._nnx_rest_state = rest_state
405-
del nnx_model
451+
del nnx_model, concrete_model
406452

407453
self.abstract_params = jax.tree.map(
408454
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
@@ -485,7 +531,16 @@ def quantize_params(self, state, rng: PRNGKeyType | None = None):
485531
if rng is None:
486532
rng = jax.random.PRNGKey(0)
487533
if self.config.pure_nnx:
488-
raise NotImplementedError("pure_nnx + quantize_params not yet supported.")
534+
# NNX takes a different code path: convert-on-load lives in `_load_params_nnx`
535+
# via `_convert_and_quantize_nnx`, which runs the dummy forward against a
536+
# CONVERT-mode model and transfers `qrhs.frozen` into the SERVE model.
537+
# The standalone `quantize_params(state, rng)` API expects a Linen-shape
538+
# `state.params` dict and isn't reachable on the NNX pathway in maxengine
539+
# (load_params already dispatched to _load_params_nnx).
540+
raise NotImplementedError(
541+
"Use load_params() on NNX — the convert step runs inside _load_params_nnx via "
542+
"_convert_and_quantize_nnx. quantize_params(state, rng) is the Linen API."
543+
)
489544

490545
self.model.quant.quant_mode = quantizations.get_quant_mode("convert")
491546

src/maxtext/models/gpt3.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from flax import nnx
2929

3030
from maxtext.common.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN
31+
from maxtext.inference import kvcache
3132
from maxtext.layers import initializers, nnx_wrappers
3233
from maxtext.layers.linears import DenseGeneral, MlpBlock, canonicalize_tuple, normalize_axes
3334
from maxtext.layers import quantizations
@@ -235,6 +236,7 @@ def __init__(
235236
self.key_axis_names = key_axis_names
236237
self.value_axis_names = value_axis_names
237238
self.out_axis_names = out_axis_names
239+
self.model_mode = model_mode
238240
self.rngs = rngs
239241
if self.fused_qkv:
240242
self.qkv_proj = self.create_projection_layer(
@@ -252,6 +254,7 @@ def __init__(
252254
mesh=self.mesh,
253255
attention_kernel=self.attention_kernel,
254256
max_target_length=self.max_target_length,
257+
max_prefill_predict_length=self.max_prefill_predict_length,
255258
float32_qk_product=self.float32_qk_product,
256259
float32_logits=self.float32_logits,
257260
quant=self.quant,
@@ -260,6 +263,30 @@ def __init__(
260263
num_kv_heads=self.num_heads,
261264
dtype=self.dtype,
262265
)
266+
# KV cache only matters in non-TRAIN modes. Mirrors Attention.__init__ in
267+
# attentions.py so prefill / autoregressive get a real KVCache_0 module
268+
# whose update_kv_caches() builds the cached_values tuple that
269+
# AttentionOp.__call__ requires.
270+
batch_size, _ = max_utils.get_batch_seq_len_for_mode(config, model_mode)
271+
self.KVCache_0 = (
272+
kvcache.KVCache(
273+
max_prefill_length=self.max_prefill_predict_length,
274+
max_target_length=self.max_target_length,
275+
batch=batch_size,
276+
key_seq_len=1,
277+
value_seq_len=1,
278+
key_heads=self.num_heads,
279+
value_heads=self.num_heads,
280+
key_head_size=self.head_dim,
281+
value_head_size=self.head_dim,
282+
dtype=self.dtype,
283+
kv_quant=self.kv_quant,
284+
model_mode=model_mode,
285+
rngs=self.rngs,
286+
)
287+
if model_mode != MODEL_MODE_TRAIN
288+
else None
289+
)
263290

264291
def create_projection_layer(
265292
self,
@@ -328,7 +355,18 @@ def __call__(
328355
value = nn.with_logical_constraint(value, self.value_axis_names)
329356
value = checkpoint_name(value, "value_proj")
330357

331-
out = self.attention_op(query, key, value, decoder_segment_ids, None, model_mode)
358+
cached_values = [None, None]
359+
if model_mode != MODEL_MODE_TRAIN and self.KVCache_0 is not None:
360+
prefill_kv_cache, ar_kv_cache = self.KVCache_0(
361+
key=key,
362+
value=value,
363+
decoder_segment_ids=decoder_segment_ids,
364+
model_mode=model_mode,
365+
use_ragged_attention=False,
366+
previous_chunk=None,
367+
)
368+
cached_values = [prefill_kv_cache, ar_kv_cache]
369+
out = self.attention_op(query, key, value, decoder_segment_ids, None, model_mode, cached_values)
332370

333371
out = nn.with_logical_constraint(out, self.out_axis_names)
334372

src/maxtext/trainers/pre_train/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,11 +526,11 @@ def move(path, value):
526526
"learning/total_weights": total_weights,
527527
}
528528
if config.use_qk_clip:
529-
# Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX)
530529
if isinstance(model, nn.Module):
531530
new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)
531+
else:
532+
new_state = qk_clip_utils.apply_qk_clip_nnx(new_state, intermediate_outputs, config)
532533

533-
# Report max_logits metric
534534
global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs)
535535
if global_max_logit is not None:
536536
scalar_metrics["learning/max_logits"] = global_max_logit

0 commit comments

Comments
 (0)