|
21 | 21 | import jaxtyping |
22 | 22 | from typing import Any, Callable |
23 | 23 |
|
| 24 | +from flax import nnx |
| 25 | + |
24 | 26 | from maxtext.common.common_types import DecoderBlockType |
25 | 27 | from maxtext.inference.offline_engine import InputData |
26 | 28 | from maxtext.utils import max_logging |
@@ -112,6 +114,48 @@ def compute_log_probs( |
112 | 114 | return token_log_probs, intermediate_outputs |
113 | 115 |
|
114 | 116 |
|
| 117 | +def compute_log_probs_nnx( |
| 118 | + model, |
| 119 | + inputs, |
| 120 | + inputs_position, |
| 121 | + inputs_segmentation, |
| 122 | + completion_segmentation, |
| 123 | + config, |
| 124 | + is_train=False, |
| 125 | +): |
| 126 | + """`compute_log_probs` for the NNX path. |
| 127 | +
|
| 128 | + `model` is an `nnx.Module` (carries its own params + RNG state), so there's |
| 129 | + no `params` arg. Intermediates are pulled off the model after the forward |
| 130 | + via `nnx.state(model, nnx.Intermediate).to_pure_dict()`. |
| 131 | + """ |
| 132 | + logits = model( |
| 133 | + decoder_input_tokens=inputs, |
| 134 | + decoder_positions=inputs_position, |
| 135 | + decoder_segment_ids=inputs_segmentation, |
| 136 | + enable_dropout=(config.enable_dropout if is_train else False), |
| 137 | + ) |
| 138 | + intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() |
| 139 | + logits = logits / config.decode_sampling_temperature |
| 140 | + |
| 141 | + targets = inputs[:, 1:] |
| 142 | + shifted_completion_segmentation = jax.lax.dynamic_slice( |
| 143 | + completion_segmentation, (0, 1), (completion_segmentation.shape[0], completion_segmentation.shape[1] - 1) |
| 144 | + ) |
| 145 | + shifted_completion_segmentation = jnp.pad( |
| 146 | + shifted_completion_segmentation, ((0, 0), (0, 1)), mode="constant", constant_values=0 |
| 147 | + ) |
| 148 | + mask = shifted_completion_segmentation[..., None] |
| 149 | + mask = jnp.broadcast_to(mask, logits.shape) |
| 150 | + masked_logits = jnp.where(mask, logits, -jnp.inf) |
| 151 | + log_probs = jax.nn.log_softmax(masked_logits, axis=-1) |
| 152 | + log_probs = jnp.where(mask, log_probs, -0.0) |
| 153 | + log_probs = log_probs[:, :-1, :] |
| 154 | + token_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0] |
| 155 | + token_log_probs = token_log_probs * shifted_completion_segmentation[:, :-1] |
| 156 | + return token_log_probs, intermediate_outputs |
| 157 | + |
| 158 | + |
115 | 159 | def generate_offline_completions(config, tokenizer_model, inference_engine, data): |
116 | 160 | """Generates completions for a batch of prompts using an offline engine. |
117 | 161 |
|
@@ -175,6 +219,30 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data |
175 | 219 | return data |
176 | 220 |
|
177 | 221 |
|
| 222 | +def pathways_reshard_nnx( |
| 223 | + config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model |
| 224 | +): |
| 225 | + """`pathways_reshard` for the NNX path. |
| 226 | +
|
| 227 | + Reshard the policy params onto the inference mesh and push them into the |
| 228 | + inference engine. Requires `scan_layers=True` (no NNX-aware unscan helper yet). |
| 229 | + """ |
| 230 | + if not config.scan_layers: |
| 231 | + raise NotImplementedError( |
| 232 | + "GRPO + pure_nnx + scan_layers=False not supported yet. " "Use scan_layers=True or pure_nnx=False." |
| 233 | + ) |
| 234 | + _, policy_params, _ = nnx.split(policy_state_model, nnx.Param, ...) |
| 235 | + _, source_param_shardings, _ = nnx.split(source_shardings_model, nnx.Param, ...) |
| 236 | + _, dest_param_shardings, _ = nnx.split(destination_shardings_model, nnx.Param, ...) |
| 237 | + del source_param_shardings # already encoded on policy_params |
| 238 | + with ( |
| 239 | + jax.transfer_guard_device_to_host("disallow_explicit"), |
| 240 | + jax.transfer_guard_host_to_device("disallow_explicit"), |
| 241 | + ): |
| 242 | + resharded_params = reshard_pytree(policy_params, dest_param_shardings) |
| 243 | + inference_engine.update_params(resharded_params) |
| 244 | + |
| 245 | + |
178 | 246 | def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings): |
179 | 247 | """Reshards model parameters from training to inference sharding. |
180 | 248 |
|
|
0 commit comments