diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml
index aba1000056..0b85ccc94f 100644
--- a/.github/workflows/run_tests_against_package.yml
+++ b/.github/workflows/run_tests_against_package.yml
@@ -131,12 +131,10 @@ jobs:
           else
             SPLIT_ARGS=""
           fi
-          # TODO: Fix the skipped tests and remove the deselect flags
           .venv/bin/python3 -m pytest ${INPUTS_PYTEST_ADDOPTS} \
             -v \
             -m "${FINAL_PYTEST_MARKER}" \
             --durations=0 \
-            --deselect "tests/unit/tokenizer_test.py::TokenizerTest::test_detokenize" \
             --cov=MaxText \
             --cov=maxtext \
             --cov-report=xml \
diff --git a/src/maxtext/assets/tokenizers/tokenizer.default b/src/maxtext/assets/tokenizers/tokenizer.default
index 65d7ffeada..c4fd951502 100644
Binary files a/src/maxtext/assets/tokenizers/tokenizer.default and b/src/maxtext/assets/tokenizers/tokenizer.default differ
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index e66ecac8fa..3df51ac106 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -2531,6 +2531,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       self.use_grpo = True
     else:
       self.use_grpo = False
+
+    if self.use_batch_split_schedule:
+      if not (self.decoder_block == DecoderBlockType.DEEPSEEK and self.sparse_matmul and self.use_tokamax_gmm):
+        raise ValueError("Batch split only supports deepseek, with `sparse_matmul=True` and `use_tokamax_gmm=True`")
+      if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
+        raise ValueError(
+            "Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
+        )
+
     if self.opt_type == "muon" and self.decoder_block not in [
         DecoderBlockType.DEEPSEEK,
         DecoderBlockType.QWEN3,
@@ -2539,7 +2548,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     ]:
       raise ValueError(
           "Muon dimension numbers haven't been tested for this model. Run this command first: "
-          f"`python3 -m MaxText.muon_utils {self.model_name} True`"
+          f"`python3 -m maxtext.utils.muon_utils {self.model_name} True`"
       )
     if self.force_q_layout and not self.use_jax_splash:
       raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")
diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py
index abc47529b7..503e7e0b04 100644
--- a/src/maxtext/layers/quantizations.py
+++ b/src/maxtext/layers/quantizations.py
@@ -17,7 +17,7 @@
 import functools
 import json
 import re
-from typing import Tuple, Sequence
+from typing import Tuple, Sequence, Callable
 from dataclasses import dataclass
 
 from aqt.jax.v2 import config as aqt_config
@@ -27,6 +27,7 @@
 from aqt.jax.v2 import calibration
 
 import qwix
+from qwix._src.core import dot_general_qt
 
 import jax
 import jax.numpy as jnp
@@ -194,6 +195,88 @@ def einsum(self, mesh_axes: Tuple[str, ...] = ()):
     return aqt_einsum
 
 
+@dataclass
+class QwixQuantization:
+  """Configures Qwix quantization (github.com/google/qwix); training only."""
+
+  quant_mode = "train"  # needed by external call
+  act_calibration_method: str = "absmax"
+  weight_calibration_method: str = "absmax"
+  bwd_calibration_method: str = "absmax"
+
+  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
+    """Returns Qwix dot_general config for fp8_full quantization."""
+    return dot_general_qt.DotGeneralQtConfig(
+        lhs_qtype=jnp.float8_e4m3fn,  # activation
+        rhs_qtype=jnp.float8_e4m3fn,  # weight
+        dlhs_grad_qtype=jnp.float8_e5m2,  # activation gradient
+        drhs_grad_qtype=jnp.float8_e5m2,  # weight gradient
+        lhs_calibration_method=self.act_calibration_method,
+        rhs_calibration_method=self.weight_calibration_method,
+        dlhs_grad_calibration_method=self.bwd_calibration_method,
+        drhs_grad_calibration_method=self.bwd_calibration_method,
+        tile_size=None,
+    )
+
+  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
+    """Returns Qwix dot_general."""
+    return functools.partial(QwixDotGeneral, config=self._get_fp8_full_qwix_config())
+
+  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
+    """Returns Qwix einsum."""
+    return QwixEinsum(config=self._get_fp8_full_qwix_config())
+
+
+class QwixDotGeneral(nn.Module):
+  """A callable class for Qwix dot_general."""
+
+  config: dot_general_qt.DotGeneralQtConfig
+
+  @nn.compact
+  def __call__(
+      self,
+      lhs: jax.Array,
+      rhs: jax.Array,
+      dimension_numbers: jax.lax.DotDimensionNumbers,
+      precision: jax.lax.PrecisionLike = None,
+      preferred_element_type: jax.typing.DTypeLike | None = None,
+      *,
+      out_sharding=None,
+  ) -> jax.Array:
+
+    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)
+
+
+class QwixEinsum(nn.Module):
+  """A callable class for Qwix einsum."""
+
+  config: dot_general_qt.DotGeneralQtConfig
+
+  @nn.compact
+  def __call__(
+      self,
+      einsum_str: str,
+      *operands: jax.Array,
+      precision: jax.lax.PrecisionLike = None,
+      preferred_element_type: jax.typing.DTypeLike | None = None,
+      _dot_general: Callable[..., jax.Array] | None = None,
+      out_sharding=None,
+  ) -> jax.Array:
+
+    def custom_dot_general(*args, **kwargs):
+      return dot_general_qt.dot_general_qt(*args[:3], self.config)
+
+    with jax.disable_jit():
+      return jnp.einsum(
+          einsum_str,
+          *operands,
+          precision=precision,
+          preferred_element_type=preferred_element_type,
+          _dot_general=custom_dot_general,
+          out_sharding=out_sharding,
+      )
+
+
 @dataclass
 class Fp8Quantization(Quantization):
   """Configures Fp8 quantization for NVIDIA GPUs"""
@@ -539,13 +622,20 @@ def get_quant_mode(quant_mode_str: str = "train"):
     return aqt_flax.QuantMode.SERVE
   elif quant_mode_str == "convert":
     return aqt_flax.QuantMode.CONVERT
-  else:
-    raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
-  return None
+  raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
 
 
 def configure_quantization(config: Config, quant_mode_str: str = "train"):
   """Configure quantization based on user config and quant mode."""
+  if config.use_batch_split_schedule and config.quantization:
+    if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
+      raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
+    return QwixQuantization(
+        weight_calibration_method=config.weight_quantization_calibration_method,
+        act_calibration_method=config.act_quantization_calibration_method,
+        bwd_calibration_method=config.bwd_quantization_calibration_method,
+    )
+
   if config.use_qwix_quantization:
     return None
   quant_cfg = _get_quant_config(config)
@@ -726,7 +816,8 @@ def get_qt_provider(config):
 
 def maybe_quantize_model(model, config):
   """Quantize the model if quantization is enabled."""
-  if config.use_qwix_quantization:
+  # Batch split does not use Qwix's interception feature; quantization is plumbed in manually.
+  if config.use_qwix_quantization and not config.use_batch_split_schedule:
     quantization_provider = get_qt_provider(config)
     if quantization_provider:
       model = qwix.quantize_model(model, quantization_provider)
diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py
index fabf0bacad..24ccf1c7b5 100644
--- a/src/maxtext/models/deepseek_batchsplit.py
+++ b/src/maxtext/models/deepseek_batchsplit.py
@@ -168,6 +168,126 @@ def merge(x, split_factor=2):
   return jnp.reshape(x, (-1,) + x.shape[2:])
 
 
+def gather_weights(weights, mesh):
+  """All-gathers FSDP-sharded weights."""
+
+  def fn(weights):
+    (
+        (pre_attn_norm, post_attn_norm),
+        (wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out),
+    ), (
+        (gate, bias),
+        (routed_wi_0, routed_wi_1, routed_wo),
+        (shared_wi_0, shared_wi_1, shared_wo),
+    ) = weights
+    # All-gather across FSDP axis. Expert axis is used for FSDP in attention.
+    wq_a = jax.lax.all_gather(wq_a, axis_name="expert", tiled=True, axis=1)
+    wq_a = jax.lax.all_gather(wq_a, axis_name="fsdp", tiled=True)
+    wq_b = jax.lax.all_gather(wq_b, axis_name="expert", tiled=True, axis=1)
+    wq_b = jax.lax.all_gather(wq_b, axis_name="fsdp", tiled=True)
+    wkv_a = jax.lax.all_gather(wkv_a, axis_name="expert", tiled=True, axis=1)
+    wkv_a = jax.lax.all_gather(wkv_a, axis_name="fsdp", tiled=True)
+    wkv_b = jax.lax.all_gather(wkv_b, axis_name="expert", tiled=True, axis=1)
+    wkv_b = jax.lax.all_gather(wkv_b, axis_name="fsdp", tiled=True)
+    out = jax.lax.all_gather(out, axis_name="expert", tiled=True)
+    out = jax.lax.all_gather(out, axis_name="fsdp", tiled=True, axis=2)
+    gate = jax.lax.all_gather(gate, axis_name="fsdp", tiled=True)
+    routed_wi_0 = jax.lax.all_gather(routed_wi_0, axis_name="fsdp", tiled=True)
+    routed_wi_1 = jax.lax.all_gather(routed_wi_1, axis_name="fsdp", tiled=True)
+    routed_wo = jax.lax.all_gather(routed_wo, axis_name="fsdp", tiled=True)
+    shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="expert", tiled=True, axis=1)
+    shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="fsdp", tiled=True)
+    shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="expert", tiled=True, axis=1)
+    shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="fsdp", tiled=True)
+    shared_wo = jax.lax.all_gather(shared_wo, axis_name="expert", tiled=True)
+    shared_wo = jax.lax.all_gather(shared_wo, axis_name="fsdp", tiled=True, axis=1)
+    return (
+        (
+            (pre_attn_norm, post_attn_norm),
+            (wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out),
+        ),
+        (
+            (gate, bias),
+            (routed_wi_0, routed_wi_1, routed_wo),
+            (shared_wi_0, shared_wi_1, shared_wo),
+        ),
+    )
+
+  return jax.shard_map(
+      fn,
+      mesh=mesh,
+      in_specs=(
+          (
+              (
+                  (
+                      jax.sharding.PartitionSpec(None),
+                      jax.sharding.PartitionSpec(None),
+                  ),
+                  (
+                      jax.sharding.PartitionSpec("fsdp", "expert"),
+                      jax.sharding.PartitionSpec("fsdp", "expert", None),
+                      jax.sharding.PartitionSpec(None),
+                      jax.sharding.PartitionSpec("fsdp", "expert"),
+                      jax.sharding.PartitionSpec("fsdp", "expert", None),
+                      jax.sharding.PartitionSpec(None),
+                      jax.sharding.PartitionSpec("expert", None, "fsdp"),
+                  ),
+              ),
+              (
+                  (
+                      jax.sharding.PartitionSpec("fsdp", None),
+                      jax.sharding.PartitionSpec(None),
+                  ),
+                  (
+                      jax.sharding.PartitionSpec("fsdp", None, "expert"),
+                      jax.sharding.PartitionSpec("fsdp", None, "expert"),
+                      jax.sharding.PartitionSpec("fsdp", "expert", None),
+                  ),
+                  (
+                      jax.sharding.PartitionSpec("fsdp", "expert"),
+                      jax.sharding.PartitionSpec("fsdp", "expert"),
+                      jax.sharding.PartitionSpec("expert", "fsdp"),
+                  ),
+              ),
+          ),
+      ),
+      out_specs=(
+          (
+              (
+                  jax.sharding.PartitionSpec(None),
+                  jax.sharding.PartitionSpec(None),
+              ),
+              (
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None, None, None),
+                  jax.sharding.PartitionSpec(None),
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None, None, None),
+                  jax.sharding.PartitionSpec(None),
+                  jax.sharding.PartitionSpec(None, None, None),
+              ),
+          ),
+          (
+              (
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None),
+              ),
+              (
+                  jax.sharding.PartitionSpec(None, None, "expert"),
+                  jax.sharding.PartitionSpec(None, None, "expert"),
+                  jax.sharding.PartitionSpec(None, "expert", None),
+              ),
+              (
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None, None),
+              ),
+          ),
+      ),
+      check_vma=False,
+  )(weights)
+
+
 def scan_batch_split_layers(
     inputs,
     params,
@@ -183,6 +303,7 @@ def scan_batch_split_layers(
   """Scans the layers with batch-split schedule."""
 
   def batch_split_scan_fn(inputs, weights, dpos, dseg):
+    weights = gather_weights(weights, mesh)
     xs = batch_split_schedule(
         inputs,
         weights,
@@ -285,6 +406,7 @@ def batch_split_schedule(
         rope_factor=cfg.rope_factor,
         mscale=cfg.mscale,
         dtype=cfg.dtype,
+        quant=quant,
     )
 
     xs = moe(
@@ -297,6 +419,7 @@ def batch_split_schedule(
         expert_axis_name="expert",
         use_gather_mosaic_kernel=False,
         config=cfg,
+        quant=quant,
     )
 
     return xs
@@ -319,7 +442,21 @@ def with_data_parallel_constraint(x, mesh):
   return jax.lax.with_sharding_constraint(x, jax.NamedSharding(mesh, activation_pspec))
 
 
-def dot(x, y, axes=1):
+def dot(x, y, quant=None, axes=1):
+  """Computes the dot product of two arrays, optionally using quantization."""
+  if quant is not None:
+    # Convert axes to jax.lax.dot_general dimension_numbers
+    if isinstance(axes, int):
+      x_contract = tuple(range(x.ndim - axes, x.ndim))
+      y_contract = tuple(range(axes))
+    else:
+      x_contract, y_contract = axes
+    dimension_numbers = ((x_contract, y_contract), ((), ()))
+    # Instantiate and call qwix dot_general
+    custom_dot = quant.dot_general_cls()()
+    return custom_dot(lhs=x, rhs=y, dimension_numbers=dimension_numbers)
+
+  # Unquantized
   return jnp.tensordot(x, y, axes=axes)
 
 
@@ -345,6 +482,7 @@ def mla_with_norms(
     rope_factor,
     mscale,
     dtype,
+    quant,
 ):
   """Performs MLA with pre- and post-normalization."""
   (pre_attn_scale, post_attn_scale), attn_ws = weights
@@ -379,6 +517,7 @@ def fn(args):
           dtype=dtype,
           mscale=mscale,
           attention_op_fn=attn_op,
+          quant=quant,
       ),
       mesh,
   )
@@ -414,6 +553,7 @@ def mla(
     mscale,
     attention_op_fn,
     dtype,
+    quant,
 ):
   """Performs MLA."""
   (
@@ -442,6 +582,7 @@ def mla(
       dtype=dtype,
       qk_nope_head_dim=qk_nope_head_dim,
       mscale=mscale,
+      quant=quant,
   )
   query = jax.ad_checkpoint.checkpoint_name(query, "query_proj")
   key, value = kv_projection(
@@ -462,6 +603,7 @@ def mla(
       dtype=dtype,
       qk_nope_head_dim=qk_nope_head_dim,
       num_query_heads=num_query_heads,
+      quant=quant,
   )
   key = jax.ad_checkpoint.checkpoint_name(key, "key_proj")
   value = jax.ad_checkpoint.checkpoint_name(value, "value_proj")
@@ -474,7 +616,7 @@ def mla(
       cached_values=[None, None],
   )
   out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
-  out = dot(out, out_weights, axes=2)
+  out = dot(out, out_weights, quant=quant, axes=2)
   out = jax.ad_checkpoint.checkpoint_name(out, "out_proj")
 
   return out
@@ -497,6 +639,7 @@ def query_projection(
     rope_factor,
     dtype,
     mscale,
+    quant,
 ):
   """Performs query projection."""
   # Set softmax scaling.
@@ -507,7 +650,7 @@ def query_projection(
     softmax_scale = softmax_scale * m * m
 
   # LoRA path
-  low_rank_q = dot(inputs_q, wq_a_weights)
+  low_rank_q = dot(inputs_q, wq_a_weights, quant=quant)
   low_rank_q = rms_norm(
       low_rank_q,
       q_norm_scale_weights,
@@ -515,7 +658,7 @@ def query_projection(
       dtype=dtype,
   )
   low_rank_q = jax.ad_checkpoint.checkpoint_name(low_rank_q, "mla_q")
-  q = dot(low_rank_q, wq_b_weights)
+  q = dot(low_rank_q, wq_b_weights, quant=quant)
 
   # Split into non-positional and rotary parts.
   q_nope, q_pe = jnp.split(q, [qk_nope_head_dim], axis=-1)
@@ -554,9 +697,10 @@ def kv_projection(
     dtype,
     qk_nope_head_dim,
     num_query_heads,
+    quant,
 ):
   """Performs KV projection."""
-  low_rank = dot(inputs, wkv_a_weights)
+  low_rank = dot(inputs, wkv_a_weights, quant=quant)
   low_rank_main, low_rank_rope = jnp.split(low_rank, [kv_lora_rank], axis=-1)
   low_rank_main = rms_norm(
       low_rank_main,
@@ -585,12 +729,13 @@ def kv_projection(
       wkv_b_weights,
       qk_nope_head_dim=qk_nope_head_dim,
       num_query_heads=num_query_heads,
+      quant=quant,
   )
 
 
-def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads):
+def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads, quant):
   """Gets key and value from compressed KV latent vector and key rope."""
-  kv_out = dot(low_rank_main, wkv_b_weights)
+  kv_out = dot(low_rank_main, wkv_b_weights, quant=quant)
 
   # Split kv_out into key_nope and value parts.
   key_nope, value = jnp.split(kv_out, [qk_nope_head_dim], axis=-1)
@@ -686,6 +831,7 @@ def moe(
     expert_axis_name,
     use_gather_mosaic_kernel,
     config,
+    quant,
 ):
   """Performs dropless MoE with tensor/expert parallelism."""
   xs, ys = list(zip(*inputs))
@@ -700,6 +846,7 @@ def moe(
           expert_axis_name=expert_axis_name,
           use_gather_mosaic_kernel=use_gather_mosaic_kernel,
           config=config,
+          quant=quant,
       ),
       mesh,
   )
@@ -730,9 +877,10 @@ def expert_selection(
     num_experts,
     num_experts_per_tok,
     routed_scaling_factor,
+    quant,
 ):
   """Selects experts for each token and calculates group sizes for each expert."""
-  pre_bias_logits = jax.nn.sigmoid(dot(x, routing_kernel))
+  pre_bias_logits = jax.nn.sigmoid(dot(x, routing_kernel, quant=quant))
   logits = pre_bias_logits + routing_bias
 
   selected_experts, weights = expert_indices_and_weights(
@@ -946,6 +1094,7 @@ def route_compute_unroute(
     use_gather_mosaic_kernel,
     config,
     mesh,
+    quant,
 ):
   """Routes, processes, and unroutes activations."""
   orig_shape = xs[0].shape
@@ -957,7 +1106,9 @@ def route_compute_unroute(
 
   def route_fn(inputs):
     # Shared expert.
-    y = dot(jax.nn.silu(dot(inputs, shared_w0)) * dot(inputs, shared_w1), shared_wo)
+    y = dot(
+        jax.nn.silu(dot(inputs, shared_w0, quant=quant)) * dot(inputs, shared_w1, quant=quant), shared_wo, quant=quant
+    )
     inputs = jnp.reshape(inputs, (-1, inputs.shape[-1]))
 
     selected_experts, weights, group_sizes = expert_selection(
@@ -967,6 +1118,7 @@ def route_fn(inputs):
         num_experts=num_experts,
         num_experts_per_tok=num_experts_per_tok,
         routed_scaling_factor=routed_scaling_factor,
+        quant=quant,
     )
     x, selected_experts, weights, group_sizes = route(
         inputs,
@@ -1019,6 +1171,7 @@ def process_activations(
     expert_axis_name,
     use_gather_mosaic_kernel,
     config,
+    quant,
 ):
   """Processes activations, which are fully sharded on the batch axis, with tensor/expert sharded weights."""
   activation_pspec = jax.sharding.PartitionSpec(
@@ -1043,6 +1196,7 @@ def process_activations(
           use_gather_mosaic_kernel=use_gather_mosaic_kernel,
           config=config,
           mesh=mesh,
+          quant=quant,
       ),
       mesh=mesh,
       in_specs=(
diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py
index 71a63c1ce2..00a2ab4418 100644
--- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py
+++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py
@@ -248,6 +248,9 @@ def compute_loss(
         "distill/teacher_loss": teacher_hard_loss,
         "distill/out_proj_feature_loss": feature_loss,
         "distill/total_loss": total_loss,
+        "distill/temperature": self.temperature,
+        "distill/alpha": self.alpha,
+        "distill/beta_feature": self.beta_feature,
     }
 
     return total_loss, metrics
diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py
index 85eb045bfe..7f5541b69a 100644
--- a/src/maxtext/trainers/post_train/distillation/train_distill.py
+++ b/src/maxtext/trainers/post_train/distillation/train_distill.py
@@ -464,8 +464,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   )
 
   # 4. Optimizer & Config
-  total_updates = student_config.steps // student_config.gradient_accumulation_steps
-  optimizer = get_distillation_optimizer(student_config, total_updates)
+  optimizer = get_distillation_optimizer(student_config, student_config.steps)
 
   checkpointing_options = checkpoint.CheckpointManagerOptions(
       save_interval_steps=student_config.checkpoint_period,
diff --git a/src/maxtext/trainers/tokenizer/train_tokenizer.py b/src/maxtext/trainers/tokenizer/train_tokenizer.py
index 30756e03b0..ba1e161e10 100644
--- a/src/maxtext/trainers/tokenizer/train_tokenizer.py
+++ b/src/maxtext/trainers/tokenizer/train_tokenizer.py
@@ -13,13 +13,25 @@
 # limitations under the License.
 
 """
 Train tokenizer
-Example usage: python3 -m MaxText.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1
+Example usage (parquet):
+  python3 -m MaxText.train_tokenizer \
+    --grain_train_files=gs://my-bucket/data/*.parquet \
+    --grain_file_type=parquet
+
+Example usage (arrayrecord):
+  python3 -m MaxText.train_tokenizer \
+    --grain_train_files=gs://my-bucket/data/*.arrayrecord \
+    --grain_file_type=arrayrecord \
+    --data_column=text
 """
+import glob
 import os
-import sys
+import shutil
 import tempfile
 import time
+from collections.abc import Iterator
+from pathlib import Path
 
 from absl import app
 from absl import flags
@@ -28,44 +40,101 @@
 from sentencepiece import SentencePieceTrainer
 
 import jax
-
-import tensorflow as tf
-import tensorflow_datasets as tfds
+import grain.python as grain
 
 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
+from maxtext.utils import gcs_utils
+
 
-_DATASET_PATH = flags.DEFINE_string("dataset_path", None, "Path to the dataset", required=True)
-_DATASET_NAME = flags.DEFINE_string("dataset_name", None, "Name to the dataset", required=True)
+_GRAIN_TRAIN_FILES = flags.DEFINE_string(
+    "grain_train_files", None, "File pattern for training data (local or gs://)", required=True
+)
+_GRAIN_FILE_TYPE = flags.DEFINE_string(
+    "grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord'."
+)
+_DATA_COLUMN = flags.DEFINE_string("data_column", "text", "Column name to extract text from (used for arrayrecord).")
 _VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
 _MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars")
-_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Name to the dataset")
-_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Name to the dataset")
+_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Path to assets directory")
+_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Output tokenizer model name")
+
+
+def build_grain_iterator(data_file_pattern: str, data_file_type: str, data_keys: tuple[str, ...] = ("text",)) -> Iterator:
+  """Build a grain iterator from a file pattern for tokenizer training.
+
+  Args:
+    data_file_pattern: Glob pattern for data files (local path or gs://).
+    data_file_type: One of 'arrayrecord' or 'parquet'.
+    data_keys: Column names to extract from each example (used for arrayrecord).
+
+  Returns:
+    A Python iterator yielding examples as dicts.
+  """
+  if data_file_pattern.startswith("gs://"):
+    data_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
+  else:
+    data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
+  if not data_files:
+    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
+  logging.info("Found %d files for tokenizer training.", len(data_files))
+
+  if data_file_type == "parquet":
+    dataset = grain.MapDataset.source(data_files)
+    dataset = dataset.map(grain.experimental.ParquetIterDataset)
+    dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files))
+    return iter(dataset)
+  elif data_file_type == "arrayrecord":
+    from maxtext.input_pipeline.protos import example_pb2  # pylint: disable=import-outside-toplevel
+
+    source = grain.ArrayRecordDataSource(data_files)
+    dataset = grain.MapDataset.source(source)
+
+    def _parse_example(raw_bytes):
+      example = example_pb2.Example()
+      example.ParseFromString(raw_bytes)
+      features = example.features.feature
+      parsed = {}
+      for col in data_keys:
+        if col in features:
+          parsed[col] = features[col].bytes_list.value[0]
+      return parsed
+
+    dataset = dataset.map(_parse_example)
+    return iter(dataset)
+  else:
+    raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet' or 'arrayrecord'.")
+
+def _dump_chars_to_textfile(dataset_iter: Iterator, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
+  """Write part of a grain dataset to lines in a text file.
-def _dump_chars_to_textfile(dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
-  """Write part of a TFDS sentence dataset to lines in a text file.
   Args:
-    dataset: tf.dataset containing string-data.
-    maxchars: int: approximate number of characters to save from dataset.
-    data_keys: tuple[str]: what keys in dataset to dump from.
+    dataset_iter: Iterator yielding examples as dicts.
+    maxchars: Approximate number of characters to save from dataset.
+    data_keys: Keys in each example to dump.
+
   Returns:
-    name of temp file with dataset bytes, exact number of characters dumped.
+    Name of temp file with dataset bytes, exact number of characters dumped.
   """
   char_count = 0
-  ds_iter = dataset.as_numpy_iterator()
   temp_dir = tempfile.gettempdir()
-  with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "ds_chars")) as outfp:
+  with tempfile.NamedTemporaryFile(
+      delete=False, prefix=os.path.join(temp_dir, "ds_chars"), mode="w", encoding="utf-8"
+  ) as outfp:
     while char_count < maxchars:
-      example = next(ds_iter)
+      example = next(dataset_iter)
       for k in data_keys:
-        line = example[k] + b"\n"
+        val = example[k]
+        if isinstance(val, bytes):
+          val = val.decode("utf-8")
+        line = val + "\n"
         char_count += len(line)
         outfp.write(line)
   return outfp.name, char_count
 
 
 def _train_sentencepiece(
-    dataset: tf.data.Dataset,
+    dataset_iter: Iterator,
     *,
     vocab_size: int,
     maxchars: int = int(1e7),
@@ -74,25 +143,25 @@ def _train_sentencepiece(
     character_coverage: float = 1.0,
     data_keys=("text",),
 ):
-  """Train SentencePiece tokenizer from subset of tf dataset.
+  """Train SentencePiece tokenizer from subset of a grain dataset.
+
   Args:
-    dataset: tf.dataset
-    vocab_size: int: size of vocab tokens to train.
-    maxchars: int: number of characters to use for sentencepiece training.
-    model_path: str: path of model file to save vocab model to.
-    model_type: str: type of sentencepiece vocab to train.
-    character_coverage: amount of characters covered by the model, good defaults
-      are 0.9995 for languages with rich character set like Japanese or Chinese
-      and 1.0 for other languages with small character set.
-    data_keys: tuple[str]: keys of dataset to use for training.
+    dataset_iter: Iterator yielding examples as dicts.
+    vocab_size: Size of vocab tokens to train.
+    maxchars: Number of characters to use for sentencepiece training.
+    model_path: Path to save vocab model to (local or gs://).
+    model_type: Type of sentencepiece vocab to train.
+    character_coverage: Amount of characters covered by the model.
+    data_keys: Keys of dataset to use for training.
+
   Returns:
-    path to the trained sentencepiece vocabulary model.
+    Path to the trained sentencepiece vocabulary model.
   """
   if model_path.startswith("gs://"):
     abs_model_path = model_path
   else:
     abs_model_path = os.path.abspath(os.path.expanduser(model_path))
-  fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys)
+  fname, _ = _dump_chars_to_textfile(dataset_iter, maxchars=maxchars, data_keys=data_keys)
   temp_dir = tempfile.gettempdir()
   with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "sp_tmp")) as model_fp:
     pass  # we just want a prefix'd tmp-filename
@@ -107,32 +176,38 @@ def _train_sentencepiece(
   )
   SentencePieceTrainer.Train(argstr)
   if jax.process_index() == 0:
-    # Use an intermediate filename that is renamed to the target name to address
-    # create and fill delays.
-    copy_rename_path = abs_model_path + ".rntmp"
-    tf.io.gfile.makedirs(os.path.dirname(abs_model_path))
-    tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True)
-    tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True)
-    logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path)
+    if abs_model_path.startswith("gs://"):
+      gcs_utils.upload_blob(abs_model_path, model_fp.name + ".model")
+      logging.info("Uploaded %s to %s", model_fp.name + ".model", abs_model_path)
+    else:
+      parent = os.path.dirname(abs_model_path)
+      if parent:
+        os.makedirs(parent, exist_ok=True)
+      shutil.copy(model_fp.name + ".model", abs_model_path)
+      logging.info("Copied %s to %s", model_fp.name + ".model", abs_model_path)
   else:
-    while not tf.io.gfile.exists(abs_model_path):
-      time.sleep(1)
+    if abs_model_path.startswith("gs://"):
+      while not gcs_utils.gcs_path_exists(abs_model_path):
+        time.sleep(1)
+    else:
+      while not os.path.exists(abs_model_path):
+        time.sleep(1)
     time.sleep(1)
   return abs_model_path
 
 
 def train_tokenizer(
-    dataset: tf.data.Dataset,
+    dataset_iter: Iterator,
     *,
     vocab_path: str,
     vocab_size: int,
     max_corpus_chars: int,
     data_keys: tuple[str] = ("text",),
 ):
-  """tokenizer training function"""
+  """Tokenizer training function."""
   logging.info("SentencePiece vocab not found, building one from data.")
   vocab_path = _train_sentencepiece(
-      dataset,
+      dataset_iter,
       vocab_size=vocab_size,
       maxchars=max_corpus_chars,
       model_path=vocab_path,
@@ -143,19 +218,14 @@ def main(argv):
   del argv
-  flags.FLAGS(sys.argv)
-  os.environ["TFDS_DATA_DIR"] = _DATASET_PATH.value
-
-  read_config = tfds.ReadConfig(
-      shuffle_seed=0,
-  )
-  train_ds_builder = tfds.builder(_DATASET_NAME.value)
-  train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+  data_keys = (_DATA_COLUMN.value,)
+  dataset_iter = build_grain_iterator(_GRAIN_TRAIN_FILES.value, _GRAIN_FILE_TYPE.value, data_keys=data_keys)
   train_tokenizer(
-      train_ds,
+      dataset_iter,
       vocab_path=os.path.join(_ASSETS_PATH.value, _VOCAB_MODEL_NAME.value),
       vocab_size=_VOCAB_SIZE.value,
       max_corpus_chars=_MAX_CORPUS_CHARS.value,
+      data_keys=data_keys,
   )
diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py
index e059986162..9b4249a8f7 100644
--- a/tests/post_training/unit/train_distill_test.py
+++ b/tests/post_training/unit/train_distill_test.py
@@ -399,6 +399,9 @@ def _test_monitored_strategy(self, sft_mode: bool):
         "distill/teacher_loss",
         "distill/out_proj_feature_loss",
         "distill/total_loss",
+        "distill/temperature",
+        "distill/alpha",
+        "distill/beta_feature",
     ]
     for key in expected_keys:
       self.assertIn(key, metrics)
diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py
index e5a4aa8ec2..4bb9831f60 100644
--- a/tests/post_training/unit/train_rl_test.py
+++ b/tests/post_training/unit/train_rl_test.py
@@ -331,6 +331,7 @@ def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files
         train_fraction=1.0,
         num_epoch=1,
         num_test_batches=1,
+        test_batch_start_index=0,
     )
 
     # Patch everything!
diff --git a/tests/unit/tokenizer_test.py b/tests/unit/tokenizer_test.py
index e58387c23a..59767f6a55 100644
--- a/tests/unit/tokenizer_test.py
+++ b/tests/unit/tokenizer_test.py
@@ -21,18 +21,17 @@
 
 import unittest
 import pytest
-import tensorflow_datasets as tfds
 import subprocess
 import os
 
 
-class TokenizerTest(unittest.TestCase):
-  """Tests for train_tokenizer.py"""
+class TrainTokenizerTest(unittest.TestCase):
+  """Tests for train_tokenizer.py using data from Parquet files"""
 
   @classmethod
   def setUpClass(cls):
-    dataset_name = "c4/en:3.0.1"
-    dataset_path = "gs://maxtext-dataset"
+    # The test only uses ~10 MB of data; one file is enough, and more files slow it down.
+    grain_train_files = "gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet"
     cls.vocab_size = 32_768
     cls.max_corpus_chars = 10_000_000
     assets_path = "tests"
@@ -44,14 +43,9 @@ def setUpClass(cls):
         add_bos=False,
         add_eos=False,
     )
-    os.environ["TFDS_DATA_DIR"] = dataset_path
-    read_config = tfds.ReadConfig(
-        shuffle_seed=0,
-    )
-    train_ds_builder = tfds.builder(dataset_name)
-    cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+    dataset_iter = train_tokenizer.build_grain_iterator(grain_train_files, "parquet")
     train_tokenizer.train_tokenizer(
-        cls.dataset,
+        dataset_iter,
         vocab_path=cls.tokenizer_path,
         vocab_size=cls.vocab_size,
         max_corpus_chars=cls.max_corpus_chars,
@@ -76,24 +70,18 @@ def test_detokenize(self):
 
 
 class TikTokenTest(unittest.TestCase):
-  """Tests for train_tokenizer.py"""
+  """Tests for TikToken"""
 
   @classmethod
   def setUpClass(cls):
-    dataset_name = "c4/en:3.0.1"
-    dataset_path = "gs://maxtext-dataset"
+    grain_train_files = "gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet"
     cls.source_tokenizer = input_pipeline_utils.get_tokenizer(
         os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
         "tiktoken",
         add_bos=False,
        add_eos=False,
     )
-    os.environ["TFDS_DATA_DIR"] = dataset_path
-    read_config = tfds.ReadConfig(
-        shuffle_seed=0,
-    )
-    train_ds_builder = tfds.builder(dataset_name)
-    cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+    cls.dataset = train_tokenizer.build_grain_iterator(grain_train_files, "parquet")
 
   @pytest.mark.tpu_only
   def test_tokenize(self):