Skip to content

Commit 4de8440

Browse files
committed
Updated code and cleaned up commits
1 parent f3d9f5c commit 4de8440

2 files changed

Lines changed: 144 additions & 27 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
from typing import Any, Literal, NewType, Optional
2828

2929
import jax
30-
from maxtext.common.common_types import AttentionType, DecoderBlockType, ShardMode
31-
from maxtext.utils import gcs_utils
32-
from maxtext.utils import max_utils
3330
from MaxText import accelerator_to_spec_map
31+
from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode
3432
from MaxText.globals import MAXTEXT_ASSETS_ROOT
33+
from maxtext.utils import gcs_utils
34+
from maxtext.utils import max_utils
3535
from pydantic.config import ConfigDict
3636
from pydantic.fields import Field
3737
from pydantic.functional_validators import field_validator, model_validator
@@ -497,8 +497,6 @@ class Attention(BaseModel):
497497
use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
498498
use_jax_splash: bool = Field(False, description="Whether to use jax splash attention.")
499499
force_q_layout: bool = Field(False, description="Force the Q layout")
500-
use_qk_clip: bool = Field(False, description="Whether to use QK-Clip (MuonClip) for training stability.")
501-
qk_clip_threshold: float = Field(100.0, description="Threshold for QK-Clip (tau).")
502500

503501

504502
class MoBa(BaseModel):
@@ -1055,6 +1053,12 @@ class Distillation(BaseModel):
10551053
distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
10561054
distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")
10571055

1056+
# --- Teacher topk distillation ---
1057+
teacher_logits_optional_keys: list[str] = Field(
1058+
default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"],
1059+
description="Optional keys to save from teacher logits",
1060+
)
1061+
10581062

10591063
class TrainingLoop(BaseModel):
10601064
"""Configuration for the main training loop, evaluation, and reproducibility."""
@@ -1560,9 +1564,6 @@ class VLLM(BaseModel):
15601564
max_num_batched_tokens: Optional[int] = Field(None, description="Max number of batched tokens in vLLM.")
15611565
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
15621566
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
1563-
vllm_hf_overrides: dict[str, Any] = Field(
1564-
default_factory=dict, description="Overrides for HuggingFace model config for MaxText model."
1565-
)
15661567
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
15671568

15681569

@@ -1933,13 +1934,6 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
19331934
if self.steps == -1:
19341935
self.steps = self.learning_rate_schedule_steps
19351936

1936-
# Validate deepstack + scan_layers incompatibility
1937-
if self.deepstack_visual_indexes_for_vit and self.scan_layers:
1938-
raise ValueError(
1939-
"Deepstack visual embedding injection requires scan_layers=False. "
1940-
"Set scan_layers=False in your config to use deepstack features."
1941-
)
1942-
19431937
# Validate WSD learning rate schedule fractions
19441938
if self.lr_schedule_type == LearningRateScheduleType.WSD:
19451939
total_fraction = self.warmup_steps_fraction + self.wsd_decay_steps_fraction
@@ -2412,18 +2406,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24122406
if self.force_q_layout and not self.use_jax_splash:
24132407
raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")
24142408

2415-
if self.use_qk_clip and self.attention_type != "mla":
2416-
raise ValueError(
2417-
f"QK-Clip is only supported when attention_type='mla', but found attention_type='{self.attention_type}'."
2418-
)
2419-
2420-
if self.use_qk_clip and self.attn_logits_soft_cap is not None:
2421-
raise ValueError(
2422-
"QK-Clip monitors raw dot products, but attn_logits_soft_cap is enabled. "
2423-
"Recording pre-cap max_logits is not fully supported yet. "
2424-
"Please disable attn_logits_soft_cap when using use_qk_clip."
2425-
)
2426-
24272409
# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
24282410
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
24292411
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""
2+
This module provides functionality to save top-k teacher logits
3+
for distillation purposes in MaxText.
4+
5+
Example command: python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py src/maxtext/configs/post_train/distillation.yml --top_k=128
6+
"""
7+
8+
import os
9+
import pickle
10+
from typing import Sequence
11+
import argparse
12+
import time
13+
import sys
14+
import tensorflow as tf
15+
16+
import jax
17+
import numpy as np
18+
import functools
19+
from itertools import islice
20+
21+
from absl import app
22+
from MaxText import pyconfig
23+
from maxtext.utils import model_creation_utils
24+
from maxtext.input_pipeline import input_pipeline_interface
25+
from maxtext.utils import maxtext_utils
26+
from maxtext.utils import max_logging
27+
28+
from jax.experimental import multihost_utils
29+
from array_record.python import array_record_module
30+
31+
32+
def get_top_k_logits(logits: jax.Array, k: int):
  """Return the k largest logit values and their vocabulary indices.

  Thin wrapper over ``jax.lax.top_k``; the result is a ``(values, indices)``
  pair taken along the last axis of ``logits``.
  """
  return jax.lax.top_k(logits, k)
36+
37+
38+
def get_local_cpu_array(arr):
  """Copy the host-addressable shards of a JAX array into a single numpy array.

  Shards are stacked along axis 0, so on a single-host/single-shard array this
  simply materializes the whole array on the host CPU.
  """
  host_pieces = []
  for shard in arr.addressable_shards:
    host_pieces.append(np.array(shard.data))
  return np.concatenate(host_pieces, axis=0)
41+
42+
43+
def generate_and_save_data(config, k_val):
  """Generate top-k teacher logits for `config.steps` batches and persist them.

  Each host writes its own ArrayRecord file (one pickled dict per batch) so
  that multi-host runs never contend on a single output file.

  Args:
    config: MaxText config object (project type) providing the mesh axes,
      teacher checkpoint path, dataset pipeline, step count, and output paths.
    k_val: Number of top logits (and their vocabulary indices) to keep per
      token position.
  """
  devices = jax.devices()
  devices_array = maxtext_utils.create_device_mesh(config, devices)
  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

  # Loading teacher model and dataset iterator.
  max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
  teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
  train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)

  output_dir = config.base_output_directory
  if config.run_name:
    output_dir = os.path.join(output_dir, config.run_name)

  # Only the first host creates the directory; the barrier below makes the
  # other hosts wait until it exists.
  if jax.process_index() == 0:
    if not tf.io.gfile.exists(output_dir):
      tf.io.gfile.makedirs(output_dir)

  # Sync all hosts to ensure directory exists before writers open files.
  multihost_utils.sync_global_devices("create_output_dir")

  # Each host writes to a unique file based on its process index to avoid write conflicts.
  filename = f"teacher_top_k_process_{jax.process_index()}.array_record"
  output_path = os.path.join(output_dir, filename)

  max_logging.log(f"Process {jax.process_index()} writing directly to: {output_path}")
  writer = array_record_module.ArrayRecordWriter(output_path, "group_size:1000")

  max_logging.log(f"Starting Top-K generation loop for {config.steps} steps...")
  loop_start = time.time()
  try:
    for step, batch in enumerate(islice(train_iter, config.steps)):
      step_start = time.time()
      tokens = batch["inputs"]
      logits = teacher_model(
          decoder_input_tokens=tokens,
          decoder_positions=batch["inputs_position"],
          enable_dropout=False,
      )
      top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)

      # Extract only the local data for this host (distributed writing).
      local_vals = get_local_cpu_array(top_k_vals)
      local_idx = get_local_cpu_array(top_k_idx)
      local_tokens = get_local_cpu_array(tokens)

      # Optional batch keys (positions, segmentations, targets) are saved only
      # when the input pipeline actually provides them.
      optional_keys = config.teacher_logits_optional_keys
      local_optionals = {key: get_local_cpu_array(batch[key]) for key in optional_keys if key in batch}

      record_dict = {
          "tokens": local_tokens,
          "top_k_logits": local_vals,
          "top_k_indices": local_idx,
          **local_optionals,
      }

      # NOTE(security): pickle records are only safe to load from trusted
      # files; these are consumed by the matching distillation reader.
      writer.write(pickle.dumps(record_dict))

      if step % 50 == 0:
        max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")

    max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s")
  finally:
    # Always close the writer so buffered records are flushed (and the file
    # handle released) even if a step raises mid-loop.
    writer.close()
  max_logging.log(f"Finished writing to {output_path}.")
111+
112+
113+
def main(argv: Sequence[str], local_args):
  """Entry point: derive the teacher config from CLI args and run generation.

  Args:
    argv: MaxText-style argv (program name, config path, overrides) handed to
      ``pyconfig.initialize``.
    local_args: Namespace from the local argparse parser; only ``top_k`` is used.
  """
  # Build the full config first, then re-initialize with just the base argv
  # plus the teacher-specific overrides it declares.
  base_config = pyconfig.initialize(argv)
  overrides = base_config.teacher_overrides
  teacher_config = pyconfig.initialize([argv[0], argv[1]], **overrides)

  generate_and_save_data(teacher_config, local_args.top_k)
122+
123+
if __name__ == "__main__":
  # Parse the script-local --top_k flag; everything else is passed through to
  # absl/pyconfig untouched.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      "--top_k",
      type=int,
      required=False,
      default=128,
      help="Top K value for logits.",
  )
  parsed_local, passthrough = arg_parser.parse_known_args()

  app.run(functools.partial(main, local_args=parsed_local), argv=[sys.argv[0]] + passthrough)

0 commit comments

Comments
 (0)