|
33 | 33 |
|
34 | 34 | pytestmark = [pytest.mark.cpu_only, pytest.mark.post_training] |
35 | 35 |
|
| 36 | +import os |
| 37 | +import pickle |
| 38 | +import tempfile |
36 | 39 | import unittest |
37 | 40 | from typing import List, Optional |
38 | 41 |
|
|
41 | 44 | import numpy as np |
42 | 45 | import optax |
43 | 46 | from absl.testing import absltest |
| 47 | +from array_record.python import array_record_module |
44 | 48 | from jax.sharding import Mesh, NamedSharding, PartitionSpec as P |
45 | 49 |
|
46 | 50 | from maxtext.trainers.post_train.distillation import distillation_utils |
@@ -198,6 +202,53 @@ def _run_label_mask_excludes_pad(self, pad_id): |
198 | 202 | expected = -float(log_p[0, 0, 1]) # target = 1 |
199 | 203 | np.testing.assert_allclose(float(total_loss), expected, rtol=1e-5) |
200 | 204 |
|
def test_create_labels_masks_packed_segmentation(self):
  """Labels at positions with targets_segmentation == 0 are all-zero, even for non-pad target tokens.

  The same targets are converted to labels twice: once with a packing
  segmentation and once without. Only the position whose segment id is 0
  may differ between the two results.
  """
  vocab_size = 8
  strategy = _make_strategy(vocab_size, pad_id=0, alpha=0.0)

  # One packed bin: doc1 occupies [0, 1], doc2 occupies [2], position [3] is
  # in-bin padding. Every target token id is non-pad on purpose, so only the
  # segmentation can mask position 3.
  token_targets = jnp.array([[1, 2, 3, 1]], dtype=jnp.int32)
  segment_ids = jnp.array([[1, 1, 2, 0]], dtype=jnp.int32)
  padded_pos = 3
  real_positions = (0, 1, 2)

  with_segmentation = strategy.create_labels(token_targets, targets_segmentation=segment_ids)
  without_segmentation = strategy.create_labels(token_targets)

  # Segment id 0 -> the whole label row is zeroed when segmentation is given.
  masked_row = np.asarray(with_segmentation[0, padded_pos])
  np.testing.assert_array_equal(masked_row, np.zeros(vocab_size))

  # Without segmentation the same position still carries label mass.
  unmasked_mass = float(np.sum(without_segmentation[0, padded_pos]))
  self.assertGreater(unmasked_mass, 0.0)

  # Positions inside real documents are unaffected by the segmentation.
  for idx in real_positions:
    np.testing.assert_array_equal(
        np.asarray(with_segmentation[0, idx]),
        np.asarray(without_segmentation[0, idx]),
    )
| 220 | + |
def test_offline_iterator_preserves_packing_fields(self):
  """Packed segmentation fields survive write -> ArrayRecord -> OfflineArrayRecordIterator -> Tunix adapter.

  A single pickled record is written to a temporary ArrayRecord file, read
  back with ``OfflineArrayRecordIterator`` and then fed through
  ``MaxTextToTunixIterator``. The test asserts that the token, target and
  segmentation arrays come out unchanged at both stages.
  """
  record = {
      "tokens": np.array([[10, 11, 12, 13]], dtype=np.int32),
      "top_k_logits": np.zeros((1, 4, 8), dtype=np.float32),
      "top_k_indices": np.zeros((1, 4, 8), dtype=np.int32),
      "inputs_position": np.array([[0, 1, 0, 0]], dtype=np.int32),
      "inputs_segmentation": np.array([[1, 1, 2, 0]], dtype=np.int32),
      "targets": np.array([[11, 12, 13, 0]], dtype=np.int32),
      "targets_segmentation": np.array([[1, 1, 2, 0]], dtype=np.int32),
  }

  with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "test.array_record")
    writer = array_record_module.ArrayRecordWriter(path, "group_size:1")
    try:
      writer.write(pickle.dumps(record))
    finally:
      # Close even when serialization or the write fails, so the handle is
      # released before TemporaryDirectory removes the file.
      writer.close()

    record_iter = distillation_utils.OfflineArrayRecordIterator(path, epochs=1)
    batch = next(record_iter)

    # The iterator is expected to expose the stored "tokens" under "inputs"
    # and to pass the packing fields through untouched.
    np.testing.assert_array_equal(batch["inputs"], record["tokens"])
    np.testing.assert_array_equal(batch["inputs_segmentation"], record["inputs_segmentation"])
    np.testing.assert_array_equal(batch["targets_segmentation"], record["targets_segmentation"])
    np.testing.assert_array_equal(batch["targets"], record["targets"])

    # The Tunix adapter must carry both segmentations on its output object.
    adapter = distillation_utils.MaxTextToTunixIterator(iter([batch]))
    tunix_input = next(adapter)
    np.testing.assert_array_equal(np.asarray(tunix_input.decoder_segment_ids), record["inputs_segmentation"])
    np.testing.assert_array_equal(np.asarray(tunix_input.targets_segmentation), record["targets_segmentation"])
| 251 | + |
201 | 252 | # --- 4. Temperature^2 scaling of soft loss ---------------------------- |
202 | 253 |
|
203 | 254 | def test_soft_loss_scales_with_temperature_squared_in_high_T_limit(self): |
|
0 commit comments