diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index eef22245bd..7edbd15314 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -108,6 +108,8 @@ class DecoderBlockType(enum.Enum): LLAMA4 = "llama4" OLMO3 = "olmo3" + LLAMA2LTI = "llama2-learn-to-init" + class AttentionType(enum.Enum): GLOBAL = "global" # default, with causality diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index af809e2b42..14e9815379 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1194,6 +1194,26 @@ class Distillation(BaseModel): "constant", description="Schedule type for beta annealing ('constant', 'linear', or 'cosine')." ) + # --- Learn to init related parameters -- + learn_to_init_mode: bool = Field(False, description="Runs in the learn-to-init mode only") + + lti_use_general_linear_map: bool = Field( + False, + description="enable general map (i.e. single learnable projection instead of the bi-linear mapping. " + "Needs much more HBM.", + ) + + distill_weights_copy_map: dict[str, Any] = Field( + default_factory=dict, + description="Dictionary of copying original teacher weights to the student model.", + ) + + distill_student_weights_share_map: dict[str, Any] = Field( + default_factory=dict, + description="Experimental weight sharing map inside the student model for learn-to-init phase", + ) + # --------------------------------------- + # --- Distillation freezing filter -- student_params_to_update: None | list = Field( None, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index eb5630968f..021a5f8c29 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -449,6 +449,8 @@ def get_decoder_layers(self): return [DecoderLayer] case DecoderBlockType.LLAMA2: return [llama2.LlamaDecoderLayerToLinen] + case DecoderBlockType.LLAMA2LTI: + return [llama2.LlamaLTIDecoderLayerToLinen] case DecoderBlockType.MISTRAL: # TODO(ranran): update to Mistral with sliding window attention return [mistral.MistralDecoderLayerToLinen] @@ -543,6 +545,7 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.SIMPLE_MLP, DecoderBlockType.LLAMA4, DecoderBlockType.OLMO3, + DecoderBlockType.LLAMA2LTI, ): return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) elif self.config.decoder_block == DecoderBlockType.GPT3: diff --git a/src/maxtext/layers/learn_to_init_layer.py b/src/maxtext/layers/learn_to_init_layer.py new file mode 100644 index 0000000000..d36bd98780 --- /dev/null +++ b/src/maxtext/layers/learn_to_init_layer.py @@ -0,0 +1,439 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""nxx module overrides and utility methods for LTI distillation""" + +import jax +from flax import nnx +from maxtext.layers import linears, initializers +from maxtext.common.common_types import Config +from jax.sharding import Mesh, NamedSharding +import jax.numpy as jnp +from typing import Iterable, Optional + +from maxtext.common.common_types import DType, ShardMode, Array +from maxtext.layers.quantizations import AqtQuantization as Quant +from maxtext.layers.initializers import NdInitializer, nd_dense_init +from maxtext.utils import max_logging, max_utils + + +class LearnToInitDecoderLayer(nnx.Module): + """ + A generic wrapper that initializes a base decoder layer and dynamically swaps + its DenseGeneral modules for learn-to-init distillation. + + This class instantiates a standard base decoder layer (e.g., LlamaDecoderLayer) + and replaces specific attention projection sub-modules ("query", "key", "value", + "out") with customized `LearnToInitDense` modules. + + Attributes: + learn_to_init_wrapper: The instantiated base decoder layer containing the mutable NNX graph. + config: The model configuration parameters. + rngs: The random number generator state used for initialization. + self_attention_module_name: The target name of the attention module to customize. + """ + + def __init__( + self, + base_layer_cls, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant=None, + **kwargs, + ): + # Instantiate the original layer (e.g., LlamaDecoderLayer) + self.learn_to_init_wrapper = base_layer_cls( + config=config, model_mode=model_mode, mesh=mesh, rngs=rngs, quant=quant, **kwargs + ) + + self.config = config + self.rngs = rngs + + self.self_attention_module_name = "self_attention" + + # replace relevant nnx modules with customized LearnToInit modules + self._customize_attention_modules(self.learn_to_init_wrapper) + + def _customize_attention_modules(self, module: nnx.Module): + """Replaces specific DenseGeneral modules (q, k, v projections) in the attention module.""" + attention_module = getattr(module, self.self_attention_module_name, None) + if attention_module is None: + return + + # Target Q, K, V projections sub module names + target_names = ["query", "key", "value", "out"] + + use_general_linear_map = self.config.lti_use_general_linear_map + teacher_config = self.config.teacher_config + + for name in target_names: + child = getattr(attention_module, name, None) + if isinstance(child, linears.DenseGeneral): + orig_proj_shape = child.kernel.shape + assert len(orig_proj_shape) == 3 + if name in ("query", "key", "value"): + teacher_heads_num = teacher_config.base_num_query_heads if name == "query" else teacher_config.base_num_kv_heads + teacher_shape = (orig_proj_shape[0], teacher_heads_num, teacher_config.head_dim) + elif name == "out": + teacher_shape = (teacher_config.base_num_query_heads, teacher_config.head_dim, orig_proj_shape[2]) + else: + max_logging.warning(f"Non handled LTI projection type {name}") + continue + new_module = LearnToInitDense( + in_features_shape=child.in_features_shape, + out_features_shape=child.out_features_shape, + C=jnp.empty(teacher_shape), + axis=child.axis, + weight_dtype=child.weight_dtype, + dtype=child.dtype, + kernel_init=child.kernel_init, + kernel_axes=child.kernel_axes, + quant=child.quant, + use_bias=child.use_bias, + shard_mode=child.shard_mode, + matmul_precision=child.matmul_precision, + is_output_projection=(name == "out"), + use_general_linear_map=use_general_linear_map, + rngs=self.rngs, # Reuse the layer's RNG stream + ) + # Swap the module in the mutable NNX graph + setattr(attention_module, name, new_module) + + def __call__(self, *args, **kwargs): + # Just forward the forward pass arguments to the base layer + return self.learn_to_init_wrapper(*args, **kwargs) + + +class LearnToInitDense(nnx.Module): + """ + A customized Dense layer used exclusively during the learn-to-init phase of distillation. + + This module replaces standard `DenseGeneral` projections within the attention mechanism. + Instead of a single standard kernel, it computes the effective projection weights + dynamically during the forward pass by combining learnable student parameters + (either A and B matrices, or a general linear map W) with frozen teacher weights (C). + + The projection math adapts automatically based on whether the layer is used for + Q/K/V projections or the final output projection. + + Attributes: + C: The frozen, pre-trained teacher tensor. + A: The first learnable projection matrix (used if use_general_linear_map is False). + B: The second learnable projection matrix (used if use_general_linear_map is False). + W: A single, general learnable linear map (used if use_general_linear_map is True). + bias: An optional learnable bias parameter. + """ + + TENSOR_A = "A" + TENSOR_B = "B" + TENSOR_C = "C" + TENSOR_W = "W" + + def __init__( + self, + in_features_shape: Iterable[int] | int, + out_features_shape: Iterable[int] | int, + C: Optional[jax.Array] = None, # C is assumed to be the teacher tensor + axis: Iterable[int] | int = -1, + weight_dtype: DType = jnp.float32, + is_output_projection: bool = False, + use_general_linear_map: bool = False, + dtype: DType = jnp.float32, + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes: tuple[None | str, ...] = (), + quant: None | Quant = None, + use_bias: bool = False, + shard_mode: ShardMode = ShardMode.AUTO, + matmul_precision: str = "default", + parameter_memory_host_offload: bool = False, + *, # Following arguments are keyword-only + rngs: nnx.Rngs = None, + ): + self.in_features_shape = linears.canonicalize_tuple(in_features_shape) + self.out_features_shape = linears.canonicalize_tuple(out_features_shape) + self.axis = linears.canonicalize_tuple(axis) + self.weight_dtype = weight_dtype + self.is_output_projection = is_output_projection + self.dtype = dtype + self.kernel_init = kernel_init + self.kernel_axes = kernel_axes + self.quant = quant + self.use_bias = use_bias + self.shard_mode = shard_mode + self.matmul_precision = matmul_precision + self.parameter_memory_host_offload = parameter_memory_host_offload + self.use_general_linear_map = use_general_linear_map + + self.C = nnx.Param(C, sharding=self.kernel_axes) + + kernel_shape = self.in_features_shape + self.out_features_shape + assert len(kernel_shape) == 3, "LearnToInitDense currently only supports 3D kernels for attention." + assert len(self.C.value.shape) == 3, "The teacher tensor C must be 3D." + + if self.is_output_projection: + # For output projection: student(u,v,b_s), teacher(x,y,b_t) + u, v, b_s = kernel_shape + x, y, b_t = self.C.value.shape + assert b_s == b_t, f"Embedding dimension mismatch for output projection: {b_s} != {b_t}" + if self.use_general_linear_map: + self.W = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (x, y, u, v), self.weight_dtype), + sharding=(None, None, None, None), + ) + else: + self.A = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (x, u), self.weight_dtype), + sharding=(None, None), + ) + self.B = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (v, y), self.weight_dtype), + sharding=(None, None), + ) + else: + # For Q,K,V projections: student(b_s,u,v), teacher(b_t,x,y) + b_s, u, v = kernel_shape + b_t, x, y = self.C.value.shape + + assert b_s == b_t, f"Dimension mismatch for QKV projection: {b_s} != {b_t}" + if self.use_general_linear_map: + self.W = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (x, y, u, v), self.weight_dtype), + sharding=(None, None, None, None), + ) + else: + self.A = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (x, u), self.weight_dtype), + sharding=(None, None), + ) + self.B = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (y, v), self.weight_dtype), + sharding=(None, None), + ) + + if self.use_bias: + bias_axes = self.kernel_axes[-len(self.out_features_shape) :] + bias_shape = self.out_features_shape + self.bias = nnx.Param( + initializers.default_bias_init(rngs.params(), bias_shape, self.weight_dtype), + sharding=bias_axes, + ) + else: + self.bias = None + + def __call__(self, inputs: Array, _initializing: bool = False, out_sharding: NamedSharding | None = None) -> Array: + inputs = jnp.asarray(inputs, self.dtype) + norm_axis = linears.normalize_axes(self.axis, inputs.ndim) + + for i, ax in enumerate(norm_axis): + if inputs.shape[ax] != self.in_features_shape[i]: + raise ValueError( + f"Input dimension {inputs.shape[ax]} at axis {ax} " + f"does not match expected input feature size {self.in_features_shape[i]}" + ) + + if self.C.value.shape[0] == 0: + raise ValueError( + "The 'C' tensor in LearnToInitDense has not been initialized. " + "Please inject the teacher weights before training." + ) + + if self.use_general_linear_map: + kernel = _calc_attn_weight( + None, + None, + self.C, + general_map=self.W, + is_output_projection=self.is_output_projection, + matmul_precision=self.matmul_precision, + ) + else: + kernel = _calc_attn_weight( + self.A, self.B, self.C, is_output_projection=self.is_output_projection, matmul_precision=self.matmul_precision + ) + + if self.parameter_memory_host_offload: + max_logging.log("linear.py: Moving parameter logits_dense kernel to device") + kernel = jax.device_put(kernel, max_utils.device_space()) + kernel = jnp.asarray(kernel, self.dtype) + + # out_sharding should be None for auto mesh axis + if self.shard_mode != ShardMode.EXPLICIT: + out_sharding = None + + contract_ind = tuple(range(0, len(self.axis))) + output = linears._compute_dot_general_nnx( + inputs, + kernel, + norm_axis, + contract_ind, + self.matmul_precision, + None, + _initializing, + out_sharding, + ) + + if self.bias is not None: + bias = jnp.asarray(self.bias[...], self.dtype) + output += bias + return output + + +def _calc_attn_weight( + A: jax.Array | nnx.Param | None, + B: jax.Array | nnx.Param | None, + C: jax.Array | nnx.Param | None, + general_map: Optional[jax.Array | nnx.Param] = None, + is_output_projection: bool = False, + matmul_precision: str = "default", + scan_dim: str = "", +): + """Computes the effective attention weights from teacher weight and learnable projection(s). + See the description of calculate_attn_weight() below for details. + """ + if general_map is not None: + if is_output_projection: + kernel = jnp.einsum(f"x{scan_dim}yb,x{scan_dim}yuv->u{scan_dim}vb", C, general_map, precision=matmul_precision) + else: + kernel = jnp.einsum(f"b{scan_dim}xy,x{scan_dim}yuv->b{scan_dim}uv", C, general_map, precision=matmul_precision) + return kernel + + if is_output_projection: + intermediate = jnp.einsum(f"x{scan_dim}yb,x{scan_dim}u->y{scan_dim}ub", C, A, precision=matmul_precision) + kernel = jnp.einsum(f"y{scan_dim}ub,v{scan_dim}y->u{scan_dim}vb", intermediate, B, precision=matmul_precision) + else: + intermediate = jnp.einsum(f"b{scan_dim}xy,x{scan_dim}u->b{scan_dim}uy", C, A, precision=matmul_precision) + kernel = jnp.einsum(f"b{scan_dim}uy,y{scan_dim}v->b{scan_dim}uv", intermediate, B, precision=matmul_precision) + return kernel + + +def calculate_attn_weight( + A: jax.Array | None, + B: jax.Array | None, + C: jax.Array, + general_map: Optional[jax.Array] = None, + is_output_projection: bool = False, + matmul_precision: str = "default", +) -> jax.Array: + """ + Helper function to dynamically compute the effective attention weights using `jnp.einsum`. + + Computes the kernel by contracting the frozen teacher tensor (C) with the learnable + student representations. It handles both factorized maps (A and B) and general linear + maps (general_map/W), adjusting the tensor contractions based on whether the module + is an output projection or a Q/K/V projection. + + Args: + A: The first learned factorized matrix. + B: The second learned factorized matrix. + C: The frozen teacher tensor. + general_map: An optional unified learnable projection tensor used instead of A and B. + is_output_projection: Boolean flag indicating if this computes the output projection weight. + matmul_precision: The precision for the einsum matrix multiplications. + scan_dim: A string representing the scan dimension for einsum (e.g., "l" for scanned layers, or ""). + + Returns: + The computed effective kernel tensor. + """ + + # In scan mode, tensors have an extra 2-nd dimension for the layer. + # We add 'l' to the einsum string to handle this batch dimension. + scan_dim = "l" if C.ndim == 4 else "" + return _calc_attn_weight( + A, + B, + C, + general_map=general_map, + is_output_projection=is_output_projection, + matmul_precision=matmul_precision, + scan_dim=scan_dim, + ) + + +def apply_lti_model_update(student_model, student_config): + """ + Applies the finalized learn-to-init weights to the student model and cleans up the NNX graph. + + This function iterates over the `LearnToInitDense` layers in the trained student model, + calculates their final, static effective kernels using `calculate_attn_weight`, and + replaces the dynamically-computed LTI modules with standard kernel representations. + It effectively collapses the learn-to-init parameterization back into a standard + decoder architecture, modifying the `student_model` in-place. + + NOTE: works for ToNXX decoder model and layer-scan mode only + + Args: + student_model: The trained student model to be updated in-place. + student_config: The configuration of the student model containing parameters like `matmul_precision`. + """ + + # Access the nested ToNNX dictionary directly + lti_wrapped_node = student_model.decoder.layers["learn_to_init_wrapper"] + attn_state_dict = lti_wrapped_node["self_attention"] + + # Iterate through known projections and compute final weights + for proj_name in ["query", "key", "value", "out"]: + if proj_name not in attn_state_dict: + raise ValueError("Unsupported structure of LTI-augmented Attention module.") + + proj_params = attn_state_dict[proj_name] + is_output_proj = proj_name == "out" + + C_param = proj_params.get(LearnToInitDense.TENSOR_C) + + if C_param is None: + raise ValueError("Attention LTI-augmented module has no C parameter.") + + if LearnToInitDense.TENSOR_W in proj_params: + max_logging.log(f"Computing final learn-to-init weight (general map) for: {proj_name}") + W_param = proj_params[LearnToInitDense.TENSOR_W] + final_kernel = calculate_attn_weight( + A=None, + B=None, + C=C_param, + general_map=W_param, + is_output_projection=is_output_proj, + matmul_precision=student_config.matmul_precision, + ) + elif LearnToInitDense.TENSOR_A in proj_params and LearnToInitDense.TENSOR_B in proj_params: + max_logging.log(f"Computing final learn-to-init weight for: {proj_name}") + A_param = proj_params[LearnToInitDense.TENSOR_A] + B_param = proj_params[LearnToInitDense.TENSOR_B] + final_kernel = calculate_attn_weight( + A=A_param, + B=B_param, + C=C_param, + is_output_projection=is_output_proj, + matmul_precision=student_config.matmul_precision, + ) + else: + continue + + # 3. Overwrite C with the final computed kernel + C_param.set_value(final_kernel) + + # 4. Standardize the structure by placing it under the 'kernel' key + proj_params["kernel"] = C_param + + # 5. Clean up the LTI-specific parameters using .pop() + # Using pop(key, None) avoids KeyErrors if a tensor was omitted or already shared/deleted + proj_params.pop(LearnToInitDense.TENSOR_W, None) + proj_params.pop(LearnToInitDense.TENSOR_A, None) + proj_params.pop(LearnToInitDense.TENSOR_B, None) + proj_params.pop(LearnToInitDense.TENSOR_C, None) + + # unpack the learn_to_init_wrapper to match the standard model structure + del student_model.decoder.layers["learn_to_init_wrapper"] + student_model.decoder.layers.update(lti_wrapped_node) diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index 3f1036dbd4..eb81d596d9 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -614,7 +614,10 @@ def __init__( # Set the class name correctly to avoid issues like ScanToLinenPartial_0 # Instead of ToLinenPartial_0, we can use the base class name + 'ToLinen' - class_name = f"{base_nnx_class.__name__}ToLinen" + if isinstance(base_nnx_class, partial): + class_name = f"{base_nnx_class.func.__name__}ToLinen" + else: + class_name = f"{base_nnx_class.__name__}ToLinen" class ToLinenPartial(ToLinen): """A dynamically created Linen Module that wraps a specific NNX Module.""" diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 9ae60abe25..3f18ab8a61 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -33,6 +33,7 @@ from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.utils import max_utils from maxtext.utils.sharding import create_sharding, maybe_shard_with_logical +from maxtext.layers.learn_to_init_layer import LearnToInitDecoderLayer # ----------------------------------------- # The Decoder Layer specific for Llama2 @@ -50,7 +51,6 @@ def __init__( rngs: nnx.Rngs, quant: None | Quant = None, ): - self.config = config self.mesh = mesh self.quant = quant @@ -212,6 +212,13 @@ def __call__( return layer_output, kv_cache +LlamaLTIDecoderLayer = functools.partial(LearnToInitDecoderLayer, base_layer_cls=LlamaDecoderLayer) + +LlamaLTIDecoderLayerToLinen = nnx_wrappers.to_linen_class( + LlamaLTIDecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + LlamaDecoderLayerToLinen = nnx_wrappers.to_linen_class( LlamaDecoderLayer, base_metadata_fn=initializers.variable_to_logically_partitioned, diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index 9cb233bab1..83fff343ff 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -600,9 +600,11 @@ def __init__( self, raw_iterator: Any | None, root_directory: str | None = None, + student_config: Any | None = None, options: checkpoint.CheckpointManagerOptions | None = None, ): super().__init__(root_directory=root_directory, options=options) + self.student_config = student_config self._iterator = raw_iterator # Re-initialize internal Orbax manager with MaxText's Grain handler @@ -629,7 +631,15 @@ def __init__( ) # pylint: enable=access-member-before-definition - def save(self, step, model, optimizer=None, save_only_lora_params=False, force=False, custom_metadata=None): + def save( + self, + step, + model, + optimizer=None, + save_only_lora_params=False, + force=False, + custom_metadata=None, + ): """Saves the checkpoint including the input pipeline state (if available).""" if self._checkpoint_manager is None: return False @@ -651,7 +661,10 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F item=params, save_args=jax.tree.map(lambda _: default_save_args, params) ), } - if optimizer is not None: + # Exclude optimizer state if the flag is set OR if learn_to_init_mode is active. + exclude_opt = self.student_config.learn_to_init_mode + + if optimizer is not None and not exclude_opt: optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState) cp_save_args["optimizer_state"] = checkpoint.args.PyTreeSave( item=optimizer_state, save_args=jax.tree.map(lambda _: default_save_args, optimizer_state) diff --git a/src/maxtext/trainers/post_train/distillation/lti_utils.py b/src/maxtext/trainers/post_train/distillation/lti_utils.py new file mode 100644 index 0000000000..d618d1873c --- /dev/null +++ b/src/maxtext/trainers/post_train/distillation/lti_utils.py @@ -0,0 +1,117 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Learn-to-Init phase of the distillation""" + +from flax import nnx +from maxtext.utils import max_logging + + +def _nav_to_attr(root, parts): + """Walk `root.parts[0].parts[1]...` and return the terminal Python object. + + Handles both attribute access and mapping/sequence access (integer-looking + components are tried as sequence indices as a last resort) so string paths + from `nnx.graph.iter_graph` round-trip across scanned and unscanned models. + """ + obj = root + for p in parts: + if hasattr(obj, p): + obj = getattr(obj, p) + continue + try: + obj = obj[p] + continue + except (KeyError, TypeError, IndexError): + pass + try: + obj = obj[int(p)] + continue + except (ValueError, TypeError, KeyError, IndexError): + pass + raise AttributeError( + f"prepare_student_weights: cannot resolve path component {p!r} on " + f"{type(obj).__name__}. Make sure copy/share map paths match " + f"nnx.graph.iter_graph output for this model." + ) + return obj + + +def prepare_student_weights( + student_model: nnx.Module, + teacher_model: nnx.Module, + teacher_weights_copy_map: dict[str, str], + student_weights_share_map: dict[str, str], +): + """Injects weights from a teacher model into a student model in-place as well as + shares specific pairs of weights inside the student model - might be useful for learn-to-init experiments. + + This function iterates through the provided mapping dictionaries and + copies the corresponding weights from the teacher. + + It works by matching the graph path of a module in the student to the same + path in the teacher. + + Args: + student_model: The student model (NNX Module), which will be modified. + teacher_model: The teacher model (NNX Module), used as the source. + teacher_weights_copy_map: A dictionary mapping teacher parameter paths (as strings) + to student parameter paths. + student_weights_share_map: A dictionary mapping student parameter paths to be shared. + """ + max_logging.log("Starting teacher weight injection...") + + # Get a dictionary view of the teacher graph for efficient lookups + teacher_graph = {"/".join(map(str, path)): node for path, node in nnx.graph.iter_graph(teacher_model)} + student_graph = {"/".join(map(str, path)): node for path, node in nnx.graph.iter_graph(student_model)} + + # --- Weight sharing (alias destination -> source Variable) --- + for source_path, dest_path in student_weights_share_map.items(): + source_node = student_graph.get(source_path) + dest_node = student_graph.get(dest_path) + assert ( + source_node is not None + ), f"Student parameter sharing: Could not find source_node model parameter at path: {source_path}" + assert ( + dest_node is not None + ), f"Student parameter sharing: Could not find dest_node model parameter at path: {dest_path}" + + assert source_node.value.shape == dest_node.value.shape, ( + f"Shape mismatch for sharing parameter between {source_path} and {dest_path}: " + f"{source_node.value.shape} vs {dest_node.value.shape}" + ) + + max_logging.info(f"Sharing parameter {source_path} with {dest_path}") + dest_parts = dest_path.split("/") + + dest_parent = _nav_to_attr(student_model, dest_parts[:-1]) + dest_attr = dest_parts[-1] + + if hasattr(dest_parent, dest_attr): + setattr(dest_parent, dest_attr, source_node) + else: + dest_parent[dest_attr] = source_node + + for teacher_path, student_path in teacher_weights_copy_map.items(): + teacher_node = teacher_graph.get(teacher_path) + student_node = student_graph.get(student_path) + assert teacher_node is not None, f"Could not find teacher model parameter at path: {teacher_path}" + assert student_node is not None, f"Could not find student model parameter at path: {student_path}" + assert ( + student_node.value.shape == teacher_node.value.shape + ), f"Shape mismatch for {teacher_path}. Teacher: {teacher_node.value.shape}, Student: {student_node.value.shape}" + student_node.value = teacher_node.value + max_logging.info(f"Inserted teacher weight parameter {teacher_path} to the student at {student_path}") + + max_logging.info("Teacher weight injection complete.") diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index b0a6396118..5e557610ed 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -48,8 +48,9 @@ from maxtext.configs import pyconfig from maxtext.input_pipeline import tokenizer from maxtext.input_pipeline import input_pipeline_interface +from maxtext.layers.learn_to_init_layer import apply_lti_model_update from maxtext.optimizers import optimizers -from maxtext.trainers.post_train.distillation import distillation_utils +from maxtext.trainers.post_train.distillation import distillation_utils, lti_utils from maxtext.utils import max_logging from maxtext.utils import maxtext_utils from maxtext.utils import model_creation_utils @@ -178,7 +179,7 @@ def _log_config_details(config: pyconfig.HyperParameters, label: str) -> None: class ModelBundle(nnx.Module): """Wrapper for teacher and student modules.""" - def __init__(self, teacher_model: nnx.Module, student_model: nnx.Module): + def __init__(self, teacher_model: nnx.Module | None, student_model: nnx.Module): self.teacher_model = teacher_model self.student_model = student_model self.training_step = nnx.Variable(jnp.zeros((), dtype=jnp.int32)) @@ -366,6 +367,7 @@ def _post_process_train_step(self, aux: dict[str, tuple[jax.Array, jax.Array]]) self._buffered_train_metrics.additional_metrics[name] = ([], distillation_utils.weighted_mean) self._buffered_train_metrics.additional_metrics[name][0].append(value) + max_logging.log(f"Distillation metrics: {aux}") def setup_checkpoint_manager_and_restore(self, raw_train_iter, config): """Configures the trainer's CheckpointManager and restores states. @@ -410,6 +412,7 @@ def setup_checkpoint_manager_and_restore(self, raw_train_iter, config): self.checkpoint_manager = distillation_utils.MaxTextCheckpointManager( raw_iterator=iterator_to_manage, root_directory=config.checkpoint_dir, + student_config=config, # Pass the config here options=self.config.checkpointing_options, ) @@ -581,24 +584,35 @@ def train_distill( with mesh, nn_partitioning.axis_rules(student_config.logical_axis_rules): # 2. Load Models + if is_offline: + max_logging.log("Offline Distillation: Skipping Teacher Model loading.") + teacher_model = None + else: + max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...") + _log_config_details(teacher_config, "Teacher") + teacher_model = get_maxtext_model(teacher_config, mesh) + teacher_model.eval() + + # LTI phase needs the student initialization step to know about the teacher configuration + student_config.get_keys()["teacher_config"] = teacher_config + max_logging.log(f"Loading Student from {student_config.load_parameters_path}...") _log_config_details(student_config, "Student") student_model = get_maxtext_model(student_config, mesh) - student_params_to_update = getattr(student_config, "student_params_to_update", []) def student_freeze_param_fn(path) -> bool: path_str = "/".join(str(p) for p in path) return not any(template in path_str for template in student_params_to_update) - if is_offline: - max_logging.log("Offline Distillation: Skipping Teacher Model loading.") - teacher_model = None - else: - max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...") - _log_config_details(teacher_config, "Teacher") - teacher_model = get_maxtext_model(teacher_config, mesh) - teacher_model.eval() + # Inject the teacher's frozen weights into the student model + if teacher_model: + lti_utils.prepare_student_weights( + student_model, + teacher_model, + teacher_weights_copy_map=getattr(student_config, "distill_weights_copy_map", {}), + student_weights_share_map=getattr(student_config, "distill_student_weights_share_map", {}), + ) student_model.train() model_bundle = ModelBundle(teacher_model, student_model) @@ -676,6 +690,11 @@ def custom_gen_model_input_fn(batch): # Pass both iterators to the trainer trainer.train(train_iter, eval_iter) + if student_config.learn_to_init_mode: + # If learn_to_init_mode is enabled, generate the final weights and update the model structure + max_logging.log("Learn-to-init mode enabled. Finalizing student model...") + apply_lti_model_update(student_model, student_config) + # 9. Final Save (Conditional) if student_config.save_checkpoint_on_completion: should_save = student_config.steps % student_config.checkpoint_period @@ -683,8 +702,12 @@ def custom_gen_model_input_fn(batch): if should_save: max_logging.log(f"Saving final checkpoint to {student_config.checkpoint_dir}...") try: + # TODO: tmp solution for learn_to_init_mode - we need to save the changed model checkpoint, + # force=True doesn't work and orbax can keep skip saving the most recent model + # temporal hack is to simply bump the step number + # you are supppsed to run regular distillation from scratch afterwards anyway saved = trainer.checkpoint_manager.save( - trainer.train_steps, + trainer.train_steps + (1 if student_config.learn_to_init_mode else 0), trainer.model, optimizer=trainer.optimizer, save_only_lora_params=getattr(trainer, "_lora_enabled", False), diff --git a/tests/post_training/unit/distillation_checkpointing_test.py b/tests/post_training/unit/distillation_checkpointing_test.py index 8d4c8e8ce0..2d6fa26db2 100644 --- a/tests/post_training/unit/distillation_checkpointing_test.py +++ b/tests/post_training/unit/distillation_checkpointing_test.py @@ -86,8 +86,10 @@ def test_save_and_restore_iterator(self): self.assertEqual(iterator.counter, 10) # 2. Save Checkpoint + mock_student_config = mock.Mock() + mock_student_config.learn_to_init_mode = False manager = distillation_utils.MaxTextCheckpointManager( - raw_iterator=iterator, root_directory=self.test_dir, options=self.options + raw_iterator=iterator, root_directory=self.test_dir, student_config=mock_student_config, options=self.options ) # Create dummy model so 'model_params' is not empty @@ -115,8 +117,13 @@ def test_save_and_restore_iterator(self): new_iterator = FakeGrainIterator() self.assertEqual(new_iterator.counter, 0) + mock_student_config_restore = mock.Mock() + mock_student_config_restore.learn_to_init_mode = False restore_manager = distillation_utils.MaxTextCheckpointManager( - raw_iterator=new_iterator, root_directory=self.test_dir, options=self.options + raw_iterator=new_iterator, + root_directory=self.test_dir, + student_config=mock_student_config_restore, + options=self.options, ) with mock.patch.object(jax, "process_index", return_value=0), mock.patch.object(jax, "process_count", return_value=1): diff --git a/tests/post_training/unit/learn-to-init_test.py b/tests/post_training/unit/learn-to-init_test.py new file mode 100644 index 0000000000..42571b8e43 --- /dev/null +++ b/tests/post_training/unit/learn-to-init_test.py @@ -0,0 +1,307 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the LearnToInitDense layer in Learn-To-Init (LTI) distillation.""" + +import unittest +import jax +import jax.numpy as jnp +from flax import nnx +from absl.testing import absltest + +# Import the module under test +from maxtext.layers.learn_to_init_layer import LearnToInitDense +from maxtext.trainers.post_train.distillation.lti_utils import prepare_student_weights +from unittest import mock +from maxtext.models.llama2 import LlamaDecoderLayer +from maxtext.layers.learn_to_init_layer import LearnToInitDecoderLayer + + +# Minimal dummy models for testing +class DummyLayer(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + # A simple linear kernel initialized with random normal distribution + self.kernel = nnx.Param(jax.random.normal(rngs.params(), (4, 4))) + + +class DummyModel(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.layer1 = DummyLayer(rngs) + self.layer2 = DummyLayer(rngs) + + +class PrepareStudentWeightsTest(unittest.TestCase): + + def test_prepare_student_weights_copy_only(self): + """Verifies that teacher weights are correctly copied into the student model.""" + teacher = DummyModel(nnx.Rngs(0)) + student = DummyModel(nnx.Rngs(1)) + + # Ensure the initialized kernels are initially different + self.assertFalse(jnp.array_equal(teacher.layer1.kernel.value, student.layer1.kernel.value)) + self.assertFalse(jnp.array_equal(teacher.layer2.kernel.value, student.layer2.kernel.value)) + + # Map teacher's layer1 to student's layer1 + copy_map = {"layer1/kernel": "layer1/kernel"} + + prepare_student_weights( + student_model=student, teacher_model=teacher, teacher_weights_copy_map=copy_map, student_weights_share_map={} + ) + + # Verify that layer1 was copied over + self.assertTrue(jnp.array_equal(student.layer1.kernel.value, teacher.layer1.kernel.value)) + + # Verify that layer2 remained untouched + self.assertFalse(jnp.array_equal(student.layer2.kernel.value, teacher.layer2.kernel.value)) + + def test_prepare_student_weights_share_and_copy(self): + """Verifies the behavior when using the experimental weight sharing map.""" + teacher = DummyModel(nnx.Rngs(0)) + student = DummyModel(nnx.Rngs(1)) + + # We share layer1's node to layer2's path in the local dictionary. + # This means any subsequent copy targeted at "layer2/kernel" will actually write + # into the student's layer1 node. + share_map = {"layer1/kernel": "layer2/kernel"} + + # We copy the teacher's layer2 into the student's layer2 path + copy_map = {"layer2/kernel": "layer2/kernel"} + + prepare_student_weights( + student_model=student, + teacher_model=teacher, + teacher_weights_copy_map=copy_map, + student_weights_share_map=share_map, + ) + + # Since student's layer2 was "shared" from layer1, the copy operation + # overwrites student's layer1. + self.assertTrue(jnp.array_equal(student.layer1.kernel.value, teacher.layer2.kernel.value)) + + # The actual layer2 of the student remains unchanged because the dictionary + # reference was rerouted. We verify it still has its original initialization. + student_original_layer2 = DummyModel(nnx.Rngs(1)).layer2.kernel.value + self.assertTrue(jnp.array_equal(student.layer2.kernel.value, student_original_layer2)) + + def test_prepare_student_weights_shape_mismatch(self): + """Verifies that an error is raised when trying to copy misaligned shapes.""" + teacher = DummyModel(nnx.Rngs(0)) + student = DummyModel(nnx.Rngs(1)) + + # Modify student shape manually to force a mismatch + student.layer1.kernel.value = jnp.zeros((8, 8)) + + copy_map = {"layer1/kernel": "layer1/kernel"} + + with self.assertRaisesRegex(AssertionError, "Shape mismatch for layer1/kernel"): + prepare_student_weights( + student_model=student, teacher_model=teacher, teacher_weights_copy_map=copy_map, student_weights_share_map={} + ) + + +class LearnToInitDenseTest(unittest.TestCase): + + def test_qkv_projection_standard_map(self): + """Verifies parameter shapes and forward pass for QKV-like projection (is_output_projection=False).""" + embed_dim = 16 + teacher_heads = 4 + teacher_head_dim = 8 + + # C shape for Q,K,V projections: (embed_dim, teacher_heads, teacher_head_dim) + C = jnp.ones((embed_dim, teacher_heads, teacher_head_dim)) + + student_heads = 2 + student_head_dim = 16 + + layer = LearnToInitDense( + in_features_shape=(embed_dim,), + out_features_shape=(student_heads, student_head_dim), + C=C, + is_output_projection=False, + use_general_linear_map=False, + rngs=nnx.Rngs(0), + ) + + # Verify initialized parameters + # A maps teacher_heads -> student_heads + self.assertEqual(layer.A.value.shape, (teacher_heads, student_heads)) + # B maps teacher_head_dim -> student_head_dim + self.assertEqual(layer.B.value.shape, (teacher_head_dim, student_head_dim)) + self.assertEqual(layer.C.value.shape, (embed_dim, teacher_heads, teacher_head_dim)) + + # Verify forward pass shape + batch_size = 2 + seq_len = 5 + x = jnp.ones((batch_size, seq_len, embed_dim)) + out = layer(x) + self.assertEqual(out.shape, (batch_size, seq_len, student_heads, student_head_dim)) + + def test_out_projection_standard_map(self): + """Verifies parameter shapes and forward pass for Output projection (is_output_projection=True).""" + embed_dim = 16 + teacher_heads = 4 + teacher_head_dim = 8 + + # C shape for Output projection: (teacher_heads, teacher_head_dim, embed_dim) + C = jnp.ones((teacher_heads, teacher_head_dim, embed_dim)) + + student_heads = 2 + student_head_dim = 16 + + layer = LearnToInitDense( + in_features_shape=(student_heads, student_head_dim), + out_features_shape=(embed_dim,), + C=C, + axis=(-2, -1), # Reduce over the student heads and head_dim + is_output_projection=True, + use_general_linear_map=False, + rngs=nnx.Rngs(0), + ) + + # Verify initialized parameters + # A maps teacher_heads -> student_heads + self.assertEqual(layer.A.value.shape, (teacher_heads, student_heads)) + # B maps student_head_dim -> teacher_head_dim + self.assertEqual(layer.B.value.shape, (student_head_dim, teacher_head_dim)) + + # Verify forward pass shape + batch_size = 2 + seq_len = 5 + x = jnp.ones((batch_size, seq_len, student_heads, student_head_dim)) + out = layer(x) + self.assertEqual(out.shape, (batch_size, seq_len, embed_dim)) + + def test_qkv_projection_general_map(self): + """Verifies parameter shapes and forward pass for QKV-like projection with a general map (W).""" + embed_dim = 16 + teacher_heads = 4 + teacher_head_dim = 8 + C = jnp.ones((embed_dim, teacher_heads, teacher_head_dim)) + + student_heads = 2 + student_head_dim = 16 + + layer = LearnToInitDense( + in_features_shape=(embed_dim,), + out_features_shape=(student_heads, student_head_dim), + C=C, + is_output_projection=False, + use_general_linear_map=True, + rngs=nnx.Rngs(0), + ) + + # Verify W tensor shape is correctly formatted as (x, y, u, v) + self.assertEqual(layer.W.value.shape, (teacher_heads, teacher_head_dim, student_heads, student_head_dim)) + + # Verify forward pass shape + batch_size = 2 + seq_len = 5 + x = jnp.ones((batch_size, seq_len, embed_dim)) + out = layer(x) + self.assertEqual(out.shape, (batch_size, seq_len, student_heads, student_head_dim)) + + +class LearnToInitDecoderLayerTest(unittest.TestCase): + + def test_llama_lti_decoder_layer_initialization(self): + """Verifies LearnToInitDecoderLayer initializes and modifies LlamaDecoderLayer correctly.""" + + # 1. Setup mock teacher config + mock_teacher_config = mock.MagicMock() + mock_teacher_config.base_num_query_heads = 4 + mock_teacher_config.base_num_kv_heads = 2 + mock_teacher_config.head_dim = 16 + + # 2. Setup mock student config + mock_config = mock.MagicMock() + mock_config.lti_use_general_linear_map = False + mock_config.teacher_config = mock_teacher_config + + # Add attributes strictly required by LlamaDecoderLayer and Attention sub-layers + mock_config.emb_dim = 64 + mock_config.dtype = jnp.float32 + mock_config.weight_dtype = jnp.float32 + mock_config.shard_mode = "auto" + mock_config.normalization_layer_epsilon = 1e-6 + mock_config.num_query_heads = 4 + mock_config.num_kv_heads = 2 + mock_config.head_dim = 16 + mock_config.max_target_length = 32 + mock_config.max_prefill_predict_length = 32 + mock_config.attention = "dot_product" + mock_config.dropout_rate = 0.0 + mock_config.float32_qk_product = False + mock_config.float32_logits = False + mock_config.prefill_cache_axis_order = "0,1,2,3" + mock_config.ar_cache_axis_order = "0,1,2,3" + mock_config.compute_axis_order = "0,1,2,3" + mock_config.reshape_q = False + mock_config.use_ragged_attention = False + mock_config.ragged_block_size = 16 + mock_config.attn_logits_soft_cap = 0.0 + mock_config.mlp_dim = 128 + mock_config.mlp_activations = ["silu", "linear"] + mock_config.debug_sharding = False + mock_config.record_internal_nn_metrics = False + mock_config.scan_layers = False + mock_config.ici_context_autoregressive_parallelism = 1 + mock_config.fused_qkv = False + + # 3. Dummy Jax sharding mesh and NNX Rngs + mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + rngs = nnx.Rngs(0) + + # Patch utility functions to isolate the test from deeper external dependencies + with ( + mock.patch("maxtext.utils.max_utils.get_batch_seq_len_for_mode", return_value=(2, 32)), + mock.patch("maxtext.layers.quantizations.configure_kv_quant", return_value=None), + ): + + # This effectively initializes LlamaLTIDecoderLayer and implicitly calls _customize_attention_modules + layer = LearnToInitDecoderLayer( + base_layer_cls=LlamaDecoderLayer, + config=mock_config, + model_mode="train", + mesh=mesh, + rngs=rngs, + ) + + # 4. Verify initialization result + self.assertIsInstance(layer.learn_to_init_wrapper, LlamaDecoderLayer) + self.assertEqual(layer.self_attention_module_name, "self_attention") + + # 5. Verify the behavior of _customize_attention_modules + # It should correctly replace query, key, value, and out with LearnToInitDense + attention_module = layer.learn_to_init_wrapper.self_attention + + for proj_name in ["query", "key", "value", "out"]: + child = getattr(attention_module, proj_name) + self.assertIsInstance(child, LearnToInitDense, f"{proj_name} was not swapped to LearnToInitDense") + + # Validate that the dummy Teacher Tensor C is dimensioned correctly + if proj_name == "query": + # (emb_dim, teacher_heads, head_dim) -> (64, 4, 16) + self.assertEqual(child.C.value.shape, (64, 4, 16)) + elif proj_name in ("key", "value"): + # (emb_dim, teacher_kv_heads, head_dim) -> (64, 2, 16) + self.assertEqual(child.C.value.shape, (64, 2, 16)) + elif proj_name == "out": + # (teacher_heads, head_dim, emb_dim) -> (4, 16, 64) + self.assertEqual(child.C.value.shape, (4, 16, 64)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index d472c53629..146df93bfb 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -1056,6 +1056,12 @@ def test_main_offline_mode_skips_teacher_loading( mock_student_cfg.use_sft = False mock_student_cfg.enable_dropout = False + # LTI related attributes + mock_student_cfg.learn_to_init_mode = False + mock_student_cfg.distill_weights_copy_map = {} + mock_student_cfg.distill_student_weights_share_map = {} + mock_student_cfg.get_keys.return_value = {} + # Add scheduling attributes mock_student_cfg.distill_alpha_end = None mock_student_cfg.distill_alpha_schedule = "constant" @@ -1146,6 +1152,12 @@ def test_main_online_mode_loads_teacher( mock_student_cfg.use_sft = False mock_student_cfg.enable_dropout = False + # LTI-attributes + mock_student_cfg.learn_to_init_mode = False + mock_student_cfg.distill_weights_copy_map = {} + mock_student_cfg.distill_student_weights_share_map = {} + mock_student_cfg.get_keys.return_value = {} + # Add scheduling attributes mock_student_cfg.distill_alpha_end = None mock_student_cfg.distill_alpha_schedule = "constant" @@ -1170,7 +1182,8 @@ def test_main_online_mode_loads_teacher( mock_student_model = mock.Mock() mock_teacher_model = mock.Mock() - mock_get_model.side_effect = [mock_student_model, mock_teacher_model] + # The teacher is loaded before the student in online mode + mock_get_model.side_effect = [mock_teacher_model, mock_student_model] mock_build_tokenizer.return_value = mock.Mock(pad_id=0) mock_create_iterator.return_value = (mock.Mock(), mock.Mock())