
Commit 5cad246

Merge pull request #3764 from AI-Hypercomputer:gagik-offline
PiperOrigin-RevId: 907264204
2 parents 8c1a4d2 + 5cf13c3

3 files changed: 57 additions & 0 deletions


src/maxtext/configs/post_train/distillation.yml

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@ tokenizer_path: "meta-llama/Llama-3.1-8B"
 tokenizer_type: "huggingface"

 max_target_length: 2048
+packing: True

 # --- Training Loop ---
 steps: 200000
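
For context on this one-line change: packing: True tells the input pipeline to pack multiple short documents into each max_target_length-sized row, and the inputs_segmentation / inputs_position fields used throughout the rest of this commit describe that layout. A minimal sketch of the convention (illustration only, not part of this commit; the packing itself is done by MaxText's input pipeline):

import numpy as np

# Two short documents packed into one row of length 8 (pad id 0).
doc1 = [5, 6, 7]
doc2 = [8, 9]
inputs = np.array([doc1 + doc2 + [0] * 3], dtype=np.int32)

# Segment ids: a 1-based document id per position, 0 on in-bin padding.
inputs_segmentation = np.array([[1, 1, 1, 2, 2, 0, 0, 0]], dtype=np.int32)

# Positions restart at 0 at each document boundary.
inputs_position = np.array([[0, 1, 2, 0, 1, 0, 0, 0]], dtype=np.int32)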

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 5 additions & 0 deletions

@@ -137,9 +137,14 @@ def generate_and_save_data(config, local_args):
   writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1000")

   tokens = batch["inputs"]
+  # segment_ids prevents cross-document attention under packing; target_tokens/mask are consumed by
+  # the MTP block when enabled.
   logits = teacher_model(
       decoder_input_tokens=tokens,
       decoder_positions=batch["inputs_position"],
+      decoder_segment_ids=batch.get("inputs_segmentation"),
+      decoder_target_tokens=batch.get("targets"),
+      decoder_target_mask=batch.get("targets_segmentation"),
       enable_dropout=False,
   )
   top_k_vals, top_k_idx = get_top_k_logits(logits, k=k_val)
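
get_top_k_logits itself is unchanged and not shown in this diff. A minimal sketch of what such a helper typically computes, assuming it wraps jax.lax.top_k over the vocabulary axis (the actual implementation may differ):

import jax

def get_top_k_logits(logits, k):
  """Sketch: keep the k largest teacher logits per position.

  logits: [batch, seq, vocab] array; returns (values, indices),
  each [batch, seq, k]. Storing k << vocab entries per position
  is what keeps the offline teacher records small.
  """
  # jax.lax.top_k selects along the last axis, the vocab axis here.
  top_k_vals, top_k_idx = jax.lax.top_k(logits, k)
  return top_k_vals, top_k_idx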

tests/post_training/unit/distillation_metrics_test.py

Lines changed: 51 additions & 0 deletions

@@ -33,6 +33,9 @@

 pytestmark = [pytest.mark.cpu_only, pytest.mark.post_training]

+import os
+import pickle
+import tempfile
 import unittest
 from typing import List, Optional

@@ -41,6 +44,7 @@
 import numpy as np
 import optax
 from absl.testing import absltest
+from array_record.python import array_record_module
 from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

 from maxtext.trainers.post_train.distillation import distillation_utils

@@ -198,6 +202,53 @@ def _run_label_mask_excludes_pad(self, pad_id):
     expected = -float(log_p[0, 0, 1])  # target = 1
     np.testing.assert_allclose(float(total_loss), expected, rtol=1e-5)

+  def test_create_labels_masks_packed_segmentation(self):
+    """Positions where targets_segmentation == 0 must be zeroed even when the target token is non-pad."""
+    vocab_size = 8
+    strategy = _make_strategy(vocab_size, pad_id=0, alpha=0.0)
+    # Bin layout: doc1 at [0,1], doc2 at [2], in-bin pad at [3]. All targets non-pad.
+    targets = jnp.array([[1, 2, 3, 1]], dtype=jnp.int32)
+    targets_segmentation = jnp.array([[1, 1, 2, 0]], dtype=jnp.int32)
+
+    labels_packed = strategy.create_labels(targets, targets_segmentation=targets_segmentation)
+    labels_unpacked = strategy.create_labels(targets)
+
+    np.testing.assert_array_equal(np.asarray(labels_packed[0, 3]), np.zeros(vocab_size))
+    self.assertGreater(float(np.sum(labels_unpacked[0, 3])), 0.0)
+    for pos in (0, 1, 2):
+      np.testing.assert_array_equal(np.asarray(labels_packed[0, pos]), np.asarray(labels_unpacked[0, pos]))
+
+  def test_offline_iterator_preserves_packing_fields(self):
+    """Packed segmentation fields survive write -> ArrayRecord -> OfflineArrayRecordIterator -> Tunix adapter."""
+    record = {
+        "tokens": np.array([[10, 11, 12, 13]], dtype=np.int32),
+        "top_k_logits": np.zeros((1, 4, 8), dtype=np.float32),
+        "top_k_indices": np.zeros((1, 4, 8), dtype=np.int32),
+        "inputs_position": np.array([[0, 1, 0, 0]], dtype=np.int32),
+        "inputs_segmentation": np.array([[1, 1, 2, 0]], dtype=np.int32),
+        "targets": np.array([[11, 12, 13, 0]], dtype=np.int32),
+        "targets_segmentation": np.array([[1, 1, 2, 0]], dtype=np.int32),
+    }
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+      path = os.path.join(tmpdir, "test.array_record")
+      writer = array_record_module.ArrayRecordWriter(path, "group_size:1")
+      writer.write(pickle.dumps(record))
+      writer.close()
+
+      it = distillation_utils.OfflineArrayRecordIterator(path, epochs=1)
+      batch = next(it)
+
+      np.testing.assert_array_equal(batch["inputs"], record["tokens"])
+      np.testing.assert_array_equal(batch["inputs_segmentation"], record["inputs_segmentation"])
+      np.testing.assert_array_equal(batch["targets_segmentation"], record["targets_segmentation"])
+      np.testing.assert_array_equal(batch["targets"], record["targets"])
+
+      adapter = distillation_utils.MaxTextToTunixIterator(iter([batch]))
+      tunix_input = next(adapter)
+      np.testing.assert_array_equal(np.asarray(tunix_input.decoder_segment_ids), record["inputs_segmentation"])
+      np.testing.assert_array_equal(np.asarray(tunix_input.targets_segmentation), record["targets_segmentation"])
+
   # --- 4. Temperature^2 scaling of soft loss ----------------------------

   def test_soft_loss_scales_with_temperature_squared_in_high_T_limit(self):
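
The first new test pins down the masking behavior of create_labels under packing. Its implementation lives in distillation_utils and is not part of this diff; a minimal sketch of the semantics the test asserts, with a hypothetical signature inferred from the test call:

import jax

def create_labels_sketch(targets, vocab_size, pad_id=0, targets_segmentation=None):
  """Sketch of the masking the test asserts; not the real create_labels."""
  labels = jax.nn.one_hot(targets, vocab_size)  # [batch, seq, vocab]
  valid = targets != pad_id
  if targets_segmentation is not None:
    # Packed case: in-bin padding carries segment id 0 and must be
    # zeroed even when its target token happens to be non-pad.
    valid = valid & (targets_segmentation != 0)
  return labels * valid[..., None].astype(labels.dtype)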
