|
| 1 | +# Copyright 2023–2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Tensor handling utility functions for checkpoint conversion.""" |
| 16 | + |
| 17 | +from functools import partial |
| 18 | +from typing import Any, Callable, List |
| 19 | +import jax.numpy as np |
| 20 | + |
| 21 | + |
| 22 | +def apply_hook_fns(weight, target_shape, hook_fns): |
| 23 | + """Apply hook functions, essential for to_maxtext and to_huggingface""" |
| 24 | + # If hook is unsepecified, use identity |
| 25 | + if hook_fns is None: |
| 26 | + return weight |
| 27 | + if not isinstance(hook_fns, list): |
| 28 | + hook_fns = [hook_fns] |
| 29 | + # Apply a list of hooks, be careful of order |
| 30 | + for hook_fn in hook_fns: |
| 31 | + weight = hook_fn(weight, target_shape) |
| 32 | + return weight |
| 33 | + |
| 34 | + |
| 35 | +def _build_multi_axis_stacked_tensor( |
| 36 | + hf_source_keys: List[List[str]], |
| 37 | + tensor_getter_fn: Callable[[str], np.ndarray], |
| 38 | + hook_fns: Any, |
| 39 | + target_shape: tuple, |
| 40 | + config, |
| 41 | +) -> np.ndarray: |
| 42 | + """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers). |
| 43 | +
|
| 44 | + This function handles the complex case for scanned MoE layers, producing a tensor |
| 45 | + with the shape (num_experts, num_layers, ...). |
| 46 | +
|
| 47 | + Args: |
| 48 | + hf_source_keys: A nested (2D) list of Hugging Face parameter names. |
| 49 | + Outer list iterates experts, inner list iterates layers. |
| 50 | + tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array). |
| 51 | + hook_fns: The hook function(s) to apply to each individual weight. |
| 52 | + target_shape: The final shape of the target MaxText tensor. |
| 53 | + config: The MaxText pyconfig object. |
| 54 | +
|
| 55 | + Returns: |
| 56 | + The final, assembled NumPy array for the MaxText parameter. |
| 57 | + """ |
| 58 | + all_expert_tensors = [] |
| 59 | + # The hook function needs the shape of an individual slice, not the full stacked tensor. |
| 60 | + # For multi-axis stacking (experts, layers, ...), the slice shape is target_shape[2:] |
| 61 | + mt_slice_shape = target_shape[2:] |
| 62 | + |
| 63 | + # Outer loop iterates through experts |
| 64 | + for layer_keys_for_expert in hf_source_keys: |
| 65 | + layer_tensors_for_expert = [] |
| 66 | + # Inner loop iterates through layers for the current expert |
| 67 | + for hf_key_single in layer_keys_for_expert: |
| 68 | + hf_tensor_numpy = tensor_getter_fn(hf_key_single) |
| 69 | + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) |
| 70 | + layer_tensors_for_expert.append(processed_hf_tensor) |
| 71 | + all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0)) |
| 72 | + return np.stack(all_expert_tensors, axis=0) |
| 73 | + |
| 74 | + |
| 75 | +def _build_single_axis_stacked_tensor( |
| 76 | + hf_source_keys: List[str], |
| 77 | + tensor_getter_fn: Callable[[str], np.ndarray], |
| 78 | + hook_fns: Any, |
| 79 | + target_shape: tuple, |
| 80 | + config, |
| 81 | +) -> np.ndarray: |
| 82 | + """Builds a MaxText tensor by stacking HF weights along a single axis. |
| 83 | +
|
| 84 | + This function handles both standard scanned layers (e.g., attention) and |
| 85 | + unscanned MoE layers (which are stacked along the expert axis). |
| 86 | +
|
| 87 | + Args: |
| 88 | + hf_source_keys: A 1D list of Hugging Face parameter names. |
| 89 | + tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array). |
| 90 | + hook_fns: The hook function(s) to apply to each individual weight. |
| 91 | + target_shape: The final shape of the target MaxText tensor. |
| 92 | + config: The MaxText pyconfig object. |
| 93 | +
|
| 94 | + Returns: |
| 95 | + The final, assembled NumPy array for the MaxText parameter. |
| 96 | + """ |
| 97 | + tensors_to_stack = [] |
| 98 | + |
| 99 | + if config.scan_layers: |
| 100 | + # If it's a standard scanned layer, we use the configured param_scan_axis. |
| 101 | + axis_to_stack = config.param_scan_axis |
| 102 | + else: |
| 103 | + # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0). |
| 104 | + axis_to_stack = 0 |
| 105 | + |
| 106 | + # The hook function needs the shape of an individual slice, not the full stacked tensor. |
| 107 | + # We calculate it by removing the stacking dimension from the final target shape. |
| 108 | + mt_slice_shape_list = list(target_shape) |
| 109 | + del mt_slice_shape_list[axis_to_stack] |
| 110 | + mt_slice_shape = tuple(mt_slice_shape_list) |
| 111 | + |
| 112 | + for hf_key_single in hf_source_keys: |
| 113 | + hf_tensor_numpy = tensor_getter_fn(hf_key_single) |
| 114 | + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) |
| 115 | + tensors_to_stack.append(processed_hf_tensor) |
| 116 | + |
| 117 | + # Stack all processed tensors along the determined axis. |
| 118 | + return np.stack(tensors_to_stack, axis=axis_to_stack) |
| 119 | + |
| 120 | + |
| 121 | +def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config): |
| 122 | + """Determine the loading function for HF keys. |
| 123 | + HF keys can take four forms: |
| 124 | + Case 1: Unscanned (single string) |
| 125 | + Case 2: Scanned (list of strings) |
| 126 | + Case 3: Unscanned with expert stacking (list of strings) |
| 127 | + Case 4: Scanned with expert stacking (nested list of strings) |
| 128 | + """ |
| 129 | + load_fn = None |
| 130 | + if not isinstance(hf_source_keys_or_key, list): |
| 131 | + # Case 1: Single hf key (str) |
| 132 | + def _loader(getter, key, shape, hook): |
| 133 | + return apply_hook_fns(getter(key), shape, hook) |
| 134 | + |
| 135 | + load_fn = partial( |
| 136 | + _loader, |
| 137 | + tensor_getter, |
| 138 | + hf_source_keys_or_key, |
| 139 | + mt_target_shape_or_shapes, |
| 140 | + hook_fn, |
| 141 | + ) |
| 142 | + # Stacked mapping |
| 143 | + elif not isinstance(hf_source_keys_or_key[0], list): |
| 144 | + # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list) |
| 145 | + load_fn = partial( |
| 146 | + _build_single_axis_stacked_tensor, |
| 147 | + hf_source_keys_or_key, |
| 148 | + tensor_getter, |
| 149 | + hook_fn, |
| 150 | + mt_target_shape_or_shapes, |
| 151 | + config, |
| 152 | + ) |
| 153 | + else: |
| 154 | + # isinstance(hf_source_keys_or_key[0], list) |
| 155 | + # Case 4: Multi-Axis Stacked hf keys (nested list) |
| 156 | + load_fn = partial( |
| 157 | + _build_multi_axis_stacked_tensor, |
| 158 | + hf_source_keys_or_key, |
| 159 | + tensor_getter, |
| 160 | + hook_fn, |
| 161 | + mt_target_shape_or_shapes, |
| 162 | + config, |
| 163 | + ) |
| 164 | + return load_fn |
0 commit comments