|
70 | 70 | 17. **BatchNorm Momentum**: JAX momentum is the decay factor for old statistics (`x_new = momentum * x_old + (1 - momentum) * x_batch`), whereas PyTorch's momentum is the weight on the new batch statistic (i.e. `1 - decay`). To ensure parity, you MUST set JAX momentum to `1 - pytorch_momentum` (see the sketch after this list).
71 | 71 | 18. **Data Layout**: Standardize on `NHWC` (Channels Last) for JAX performance, but include necessary `jnp.transpose` operations at input/output boundaries to match PyTorch's `NCHW` oracle outputs. |
72 | 72 | 19. **Activation Tracking**: To facilitate equivalence testing, you MUST instrument the JAX model to capture intermediate activations. For every major layer or block (e.g., after a Conv, Dense, or Attention block), use `self.sow('intermediates', 'name_of_activation', activation_tensor)` to record the output (a minimal sketch appears after this list).
| 73 | +20. **Weight Initialization**: Match PyTorch initialization exactly. When the source explicitly calls `nn.init.zeros_` on a layer, use `nn.initializers.zeros_init()`. When the source uses bare `nn.Linear()` with no explicit init, use the Flax default or `nn.initializers.normal(stddev=config.initializer_range)` (see the sketch after this list).
| 74 | + - Do not use zeros_init unless the source explicitly initializes to zeros. |
| 75 | +    - RMSNorm that applies the scale as `(1 + w)`: `nn.initializers.zeros_init()`.
| 76 | +    - RMSNorm that applies the scale as `w`: `nn.initializers.ones_init()`.
| 77 | +    - Check each `nn.Parameter` in the source and match its init.
| 78 | +21. **Train/Eval Mode**: Flax modules do NOT have a `.train` attribute or `.eval()` / `.train()` methods. NEVER write `model.train = True` or `model.train = False`.
| 79 | +    - This does nothing in Flax and silently produces incorrect behavior. Instead, pass `deterministic=False` for training and `deterministic=True` for evaluation as an argument to `__call__` / `model.apply()`. All stochastic layers (Dropout, router noise) must check the `deterministic` flag (see the sketch after this list).
| 80 | +22. **Preserve ALL Source Components**: Convert EVERY class, function, and method from the source. Do NOT merge base classes into subclasses, do NOT drop utility classes or metric functions, and do NOT omit `get_config()` or serialization methods. If the source has `ExpertBase` and `FFNExpert`, convert both. If the source has a `MoEMetrics` class, convert it. |
| 81 | +23. **Preserve Default Values Exactly**: All default parameter values in the JAX output must match the PyTorch source EXACTLY. Do NOT change any numeric default:
| 82 | + - not capacity factors, not dropout rates, not epsilon values, not learning rates, not layer counts. |
| 83 | + Even if you believe a different value is "better" or "more stable", use the source value. Changed defaults silently alter model behavior and break reproducibility. |
| 84 | +24. **Preserve Exact Reduction Operations**: When the source uses `.mean()`, use `jnp.mean()`. When the source uses `.sum()`, use `jnp.sum()`. NEVER substitute one reduction for another.
| 85 | +    - `torch.mean(x, dim=N)` maps to `jnp.mean(x, axis=N)`.
| 86 | +    - `torch.sum(x, dim=N)` maps to `jnp.sum(x, axis=N)`.
| 87 | +    - The dim/axis integer stays the same.
| 88 | +25. **Preserve Method Placement**: If the source defines a method or attribute on a specific class, keep it on that class in the JAX output. Do NOT relocate methods between classes or replace instance methods with standalone functions unless the JAX idiom requires it. |
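
A minimal sketch of the momentum mapping in rule 17; the module, feature count, and the `pytorch_momentum=0.1` value are illustrative, not taken from any particular source:

```python
import flax.linen as nn

class ConvBlock(nn.Module):  # illustrative block only
    features: int
    pytorch_momentum: float = 0.1  # copy this value from the PyTorch BatchNorm

    @nn.compact
    def __call__(self, x, train: bool = False):
        x = nn.Conv(self.features, (3, 3), padding='SAME')(x)
        # PyTorch: new = (1 - m_pt) * old + m_pt * batch
        # Flax:    new = m_jax * old + (1 - m_jax) * batch  =>  m_jax = 1 - m_pt
        x = nn.BatchNorm(use_running_average=not train,
                         momentum=1.0 - self.pytorch_momentum)(x)
        return nn.relu(x)
```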
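
A minimal sketch of the `sow` instrumentation in rule 19; the module and the `'dense_out'` name are illustrative:

```python
import flax.linen as nn

class Block(nn.Module):  # illustrative
    features: int

    @nn.compact
    def __call__(self, x):
        h = nn.Dense(self.features)(x)
        # Recorded under variables['intermediates']['dense_out'] when apply()
        # is called with mutable=['intermediates'].
        self.sow('intermediates', 'dense_out', h)
        return nn.relu(h)
```

At verification time, `out, state = model.apply(variables, x, mutable=['intermediates'])` exposes the captured tensors for comparison against PyTorch forward hooks.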
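
A minimal sketch of the initializer mapping in rule 20; the feature count and the `initializer_range` value are placeholders to be replaced by the source config's values:

```python
import flax.linen as nn

initializer_range = 0.02  # placeholder; take the real value from the source config

# Source explicitly calls nn.init.zeros_ on the bias -> mirror it; the kernel
# follows the config's initializer_range (or the Flax default if the source
# leaves the layer uninitialized).
dense = nn.Dense(
    features=512,  # placeholder
    kernel_init=nn.initializers.normal(stddev=initializer_range),
    bias_init=nn.initializers.zeros_init(),
)

# Source RMSNorm that applies the scale as (1 + w): start w at zero.
rmsnorm_scale_init = nn.initializers.zeros_init()
# Source RMSNorm that applies the scale as w: start w at one.
rmsnorm_plain_scale_init = nn.initializers.ones_init()
```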
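
A minimal sketch of rule 21's `deterministic` flag; the module, sizes, and dropout rate are illustrative:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):  # illustrative module with a stochastic layer
    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        x = nn.Dense(256)(x)
        x = nn.Dropout(rate=0.1, deterministic=deterministic)(x)
        return nn.Dense(256)(x)

model = MLP()
x = jnp.ones((2, 64))
params = model.init(jax.random.PRNGKey(0), x)
# Training: stochastic layers active, dropout RNG supplied.
train_out = model.apply(params, x, deterministic=False,
                        rngs={'dropout': jax.random.PRNGKey(1)})
# Evaluation: no .eval() / .train(); just flip the flag.
eval_out = model.apply(params, x, deterministic=True)
```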
73 | 89 |
|
74 | 90 | ## CRITICAL: Faithfulness to Source Code |
75 | 91 |
|
| 92 | +This is a TRANSLATION, not a redesign. The converted code must produce |
| 93 | +IDENTICAL behavior to the source for the same inputs and weights. |
| 95 | +
|
76 | 96 | NEVER simplify complex tensor reshaping, reordering, or algorithmic patterns |
77 | 97 | from the source code. If the PyTorch code uses a specific interleaved weight |
78 | 98 | layout, chunk-parallel algorithm, or multi-step computation, convert it |
79 | 99 | faithfully to JAX. The RAG context shows EXAMPLES of similar patterns -- use |
80 | 100 | them as guidance for JAX idioms, but always follow the ACTUAL source code's |
81 | 101 | logic and structure. |
| 102 | +
|
| 104 | +NEVER "improve" the source by changing default values, adding initializers |
| 105 | +that the source does not use, substituting reductions (.sum vs .mean), or |
| 106 | +dropping components you consider non-essential (logging, metrics, utility |
| 107 | +classes). If the source has it, the output must have it. |
82 | 108 | """ |
83 | 109 |
|
84 | 110 | MIGRATE_MODULE_TO_JAX_PROMPT = """ |
|
356 | 382 | 2. If the source has a `fix_query_key_value_ordering` method or groups QKVZ |
357 | 383 | projections by key heads, convert it FAITHFULLY. Reshape to |
358 | 384 | [B, T, num_k_heads, ...] and split within each key-head group. Do NOT |
359 | | - replace it with a flat split -- that produces wrong tensors when |
| 385 | +   replace it with a flat split, which produces wrong tensors when
360 | 386 | num_k_heads != num_v_heads. |
361 | 387 | 3. If the source has a chunk-parallel delta rule with a for-loop computing a |
362 | 388 | Neumann series (WY representation), convert it using |
|
366 | 392 | linear attention, implement BOTH modes and dispatch based on sequence length. |
367 | 393 | 5. Implement causal_conv1d as a standalone function with both prefill and |
368 | 394 | single-step decode paths. |
369 | | -6. **Mandatory Activation Parity**: The JAX model must be structured to allow verification of intermediate results. Use Flax's `sow` mechanism to capture activations for every significant layer, using names that clearly correspond to the PyTorch module's attributes. |
| 395 | +6. **Mandatory Activation Parity**: The JAX model must be structured to allow |
| 396 | + verification of intermediate results. Use Flax's `sow` mechanism to capture activations for every significant layer, using names that clearly correspond to the PyTorch module's attributes. |
| 397 | +7. For causal operations with decode-time state (causal conv1d, linear |
| 398 | + attention), implement SEPARATE prefill and decode functions. Do NOT use |
| 399 | + a single unified function with conditional branching. |
| 400 | +8. ALWAYS include a `@dataclasses.dataclass` Config class at the top of the |
| 401 | + output file. Mirror ALL fields from the PyTorch configuration class with |
| 402 | + their types and default values. Use `dataclasses.field(default_factory=...)` |
| 403 | +    for mutable defaults. Use the Config type (not `Any`) in module annotations.
| 403 | +    A sketch appears after this list.
| 404 | +9. The `load_balancing_loss` function MUST accept an optional `attention_mask` |
| 405 | + parameter. When the mask is provided, broadcast it to match the concatenated |
| 406 | + router logits shape and use it to exclude padding tokens from mean/sum |
| 407 | +    statistics. See the RAG context for the full pattern. A hedged sketch also
| 407 | +    appears after this list.
| 408 | +10. **MoE Experts: Capacity-Based Dispatch (MANDATORY)**. The Experts class MUST |
| 409 | + use capacity-based dispatch with dispatch/combine tensors, NOT per-token |
| 410 | + gather of expert weights. The correct pattern is: |
| 411 | + a) Compute per-expert capacity: `capacity = ceil(T * K / E) * 1.5` |
| 412 | + b) Build dispatch tensor via `one_hot(selected_experts) -> cumsum -> positions |
| 413 | + -> one_hot(positions, capacity)` to get `dispatch: [T, E, C]` |
| 414 | + c) Build combine tensor: `combine = dispatch * routing_weights` |
| 415 | + d) Route tokens to expert buffers: `expert_in = einsum('tec,th->ech', dispatch, x)` |
| 416 | + e) Batched expert matmul: `expert_out = einsum('ech,ehi->eci', expert_in, W)` |
| 417 | + f) Scatter back: `output = einsum('tec,ech->th', combine, expert_out)` |
| 418 | + Do NOT use `weight[flat_indices]` gather or `jax.vmap` over individual experts. |
| 419 | + Do NOT use `jnp.einsum('td,edh->teh')` computing all experts for all tokens. |
| 420 | +    The capacity-based approach is 10-50x more efficient for large E (e.g. E=64).
| 420 | +    A condensed sketch appears after this list.
| 421 | +11. **KV Cache: Pure Functional NamedTuple (MANDATORY)**. All KV caches MUST be |
| 422 | + NamedTuple objects passed as function arguments and returned as outputs. |
| 423 | + Do NOT use Flax mutable variables (`self.variable('cache', ...)`). |
| 424 | + Do NOT use config dicts with init flags. |
| 425 | + For encoder-decoder models: use SEPARATE self_attn_cache and cross_attn_cache |
| 426 | + arguments per layer. Cross-attention caches are populated once from encoder |
| 427 | + output and passed through unchanged on subsequent decode steps. |
| 428 | + Provide an `init_kv_caches()` helper function that pre-allocates all layer |
| 429 | + caches. This replaces PyTorch's `install_kv_cache_hooks()`. |
| 430 | +    See the RAG context for the full encoder-decoder cache pattern. A minimal
| 430 | +    cache sketch appears after this list.
| 431 | +12. **Tied Output Projection**: When the PyTorch source computes logits via |
| 432 | + `x @ self.token_embedding.weight.T`, convert it to |
| 433 | + `(x @ token_embedding.embedding.T).astype(jnp.float32)`. |
| 434 | + Do NOT use `token_embedding.attend(x)` -- that is for embedding lookup, |
| 435 | + not linear projection, and may produce different results. |
| 436 | +13. **Fused QKV Projection**: When the PyTorch source uses a single |
| 437 | + `in_proj_weight` of shape [3*embed_dim, embed_dim] with sliced projection |
| 438 | + methods (in_proj_qkv, in_proj_q, in_proj_kv), preserve this as a SINGLE |
| 439 | + parameter with sliced access in JAX. Do NOT split into 3 separate nn.Dense |
| 440 | + layers. Use `self.param('in_proj_weight', init, (3*D, D))` and slice it |
| 441 | + for Q [0:D], K [D:2D], V [2D:3D]. Provide in_proj_qkv(), in_proj_q(), |
| 442 | +    in_proj_kv() methods matching the PyTorch API. A skeleton sketch appears
| 442 | +    after this list.
| 443 | +14. **Float32 Softmax Upcast (MANDATORY)**: When the PyTorch source uses |
| 444 | + `.float()` or `dtype=torch.float32` before softmax, you MUST preserve this |
| 445 | + in JAX: `jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1)` then |
| 446 | + cast back with `.astype(q.dtype)`. This is critical for numerical stability |
| 447 | + in bfloat16/float16. NEVER omit this upcast. |
| 448 | +15. **Preserve ALL Source Components (MANDATORY)**: The output MUST contain a |
| 449 | + JAX equivalent for EVERY class, function, method, and utility in the source. |
| 450 | + Do NOT merge base classes into subclasses. Do NOT drop get_config() or |
| 451 | + serialization methods. Do NOT omit utility classes (e.g., metrics classes) |
| 452 | + or standalone functions (e.g., metric computation functions). If the source |
| 453 | + has N classes and M functions, the output must have N classes and M functions. |
| 454 | +16. **Preserve Default Values Exactly**: All constructor defaults, config |
| 455 | + defaults, and hyperparameter defaults MUST match the PyTorch source exactly. |
| 456 | + Do NOT change capacity_factor, dropout rates, noise epsilon, num_layers, |
| 457 | + or any other default value -- even if you think a different value is better. |
| 458 | +17. **Train/Eval Mode in Flax**: NEVER set `model.train = True/False` or call |
| 459 | + `model.eval()` / `model.train()` in training loops. Flax has no such |
| 460 | + attributes. Use `deterministic=False` for training and `deterministic=True` |
| 461 | + for evaluation, passed as an argument to the module's `__call__` method. |
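
A minimal sketch of the Config dataclass in item 8; every field name and default shown here is a placeholder -- copy the real fields and defaults from the PyTorch config class verbatim:

```python
import dataclasses
from typing import List
import flax.linen as nn

@dataclasses.dataclass
class ModelConfig:  # mirror every field of the PyTorch config class
    hidden_size: int = 768               # placeholder default
    num_layers: int = 12                 # placeholder default
    dropout_rate: float = 0.1            # placeholder default
    capacity_factor: float = 1.5         # placeholder default
    layer_types: List[str] = dataclasses.field(
        default_factory=lambda: ["attention", "mlp"])  # mutable default via default_factory

class Attention(nn.Module):
    config: ModelConfig  # concrete Config type, not Any
```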
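
A hedged sketch of the masked load-balancing loss in item 9, loosely following a Mixtral-style auxiliary loss; the exact statistics, broadcasting, and scaling must be taken from the source, not from this sketch:

```python
import jax
import jax.numpy as jnp

def load_balancing_loss(router_logits, num_experts, top_k, attention_mask=None):
    # router_logits: [N, E], already concatenated over layers (N = layers * batch * seq)
    probs = jax.nn.softmax(router_logits, axis=-1)
    _, selected = jax.lax.top_k(probs, top_k)
    expert_mask = jax.nn.one_hot(selected, num_experts)              # [N, K, E]
    if attention_mask is None:
        tokens_per_expert = jnp.mean(expert_mask, axis=0)            # [K, E]
        prob_per_expert = jnp.mean(probs, axis=0)                    # [E]
    else:
        # Broadcast the padding mask over layers to the concatenated token axis
        # and exclude padded positions from both statistics.
        num_layers = router_logits.shape[0] // attention_mask.size
        m = jnp.tile(attention_mask.reshape(-1), num_layers).astype(probs.dtype)
        tokens_per_expert = jnp.sum(expert_mask * m[:, None, None], axis=0) / jnp.sum(m)
        prob_per_expert = jnp.sum(probs * m[:, None], axis=0) / jnp.sum(m)
    return jnp.sum(tokens_per_expert * prob_per_expert[None, :]) * num_experts
```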
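
A condensed sketch of the dispatch/combine pattern in item 10. For brevity the top-k slots are flattened into the token axis, and the GELU activation and the stacked `w_in`/`w_out` weights are placeholders; the real expert MLP and capacity rule must follow the source exactly:

```python
import math
import jax
import jax.numpy as jnp

def moe_layer(x, routing_weights, selected_experts, w_in, w_out, capacity_factor=1.5):
    # x: [T, H]; routing_weights, selected_experts: [T, K]
    # w_in: [E, H, I]; w_out: [E, I, H]  (per-expert weight stacks, placeholders)
    T, H = x.shape
    K = selected_experts.shape[-1]
    E = w_in.shape[0]
    C = int(math.ceil(T * K / E) * capacity_factor)

    # Flatten the K routing slots into the token axis (slot-major order).
    flat_experts = selected_experts.transpose(1, 0).reshape(K * T)
    flat_weights = routing_weights.transpose(1, 0).reshape(K * T)
    expert_onehot = jax.nn.one_hot(flat_experts, E)                       # [K*T, E]

    # Position of each assignment inside its expert's buffer (0-based);
    # positions >= C one-hot to all zeros, i.e. overflow tokens are dropped.
    positions = jnp.cumsum(expert_onehot, axis=0) - 1
    positions = (positions * expert_onehot).sum(-1).astype(jnp.int32)     # [K*T]

    dispatch = expert_onehot[:, :, None] * jax.nn.one_hot(positions, C)[:, None, :]
    combine = dispatch * flat_weights[:, None, None]                      # [K*T, E, C]

    x_rep = jnp.tile(x, (K, 1))                                           # [K*T, H]
    expert_in = jnp.einsum('tec,th->ech', dispatch, x_rep)                # [E, C, H]
    hidden = jax.nn.gelu(jnp.einsum('ech,ehi->eci', expert_in, w_in))     # placeholder activation
    expert_out = jnp.einsum('eci,eih->ech', hidden, w_out)                # [E, C, H]
    y_rep = jnp.einsum('tec,ech->th', combine, expert_out)                # [K*T, H]
    return y_rep.reshape(K, T, H).sum(0)                                  # [T, H]
```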
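
A minimal sketch of the functional cache in item 11 for decoder self-attention; helper names, shapes, and the single-step update are illustrative:

```python
from typing import List, NamedTuple
import jax
import jax.numpy as jnp

class KVCache(NamedTuple):
    key: jnp.ndarray    # [B, max_len, num_heads, head_dim]
    value: jnp.ndarray  # [B, max_len, num_heads, head_dim]
    index: jnp.ndarray  # scalar int32: next write position

def init_kv_caches(num_layers, batch, max_len, num_heads, head_dim,
                   dtype=jnp.float32) -> List[KVCache]:
    shape = (batch, max_len, num_heads, head_dim)
    return [KVCache(key=jnp.zeros(shape, dtype),
                    value=jnp.zeros(shape, dtype),
                    index=jnp.zeros((), jnp.int32))
            for _ in range(num_layers)]

def update_cache(cache: KVCache, k_step, v_step) -> KVCache:
    # k_step, v_step: [B, 1, num_heads, head_dim] for one decode step.
    key = jax.lax.dynamic_update_slice(cache.key, k_step, (0, cache.index, 0, 0))
    value = jax.lax.dynamic_update_slice(cache.value, v_step, (0, cache.index, 0, 0))
    return KVCache(key=key, value=value, index=cache.index + 1)
```

Each attention layer then takes its cache as an explicit argument and returns `(output, new_cache)`; cross-attention caches are filled once from the encoder output and passed through unchanged on later decode steps.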
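
A skeleton sketch of the fused projection in item 13; the `xavier_uniform` initializer is a placeholder -- use whatever the source specifies:

```python
import flax.linen as nn
import jax.numpy as jnp

class MultiheadAttention(nn.Module):  # illustrative skeleton, projection methods only
    embed_dim: int

    def setup(self):
        D = self.embed_dim
        # Single fused parameter, as in the PyTorch source: [3*D, D]
        self.in_proj_weight = self.param(
            'in_proj_weight', nn.initializers.xavier_uniform(), (3 * D, D))

    def in_proj_qkv(self, x):
        return jnp.split(x @ self.in_proj_weight.T, 3, axis=-1)

    def in_proj_q(self, x):
        D = self.embed_dim
        return x @ self.in_proj_weight[0:D].T

    def in_proj_kv(self, x):
        D = self.embed_dim
        return jnp.split(x @ self.in_proj_weight[D:].T, 2, axis=-1)
```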
370 | 462 |
|
371 | 463 | Please think step by step about the conversion process before generating the code. |
372 | 464 | Then, provide the complete JAX equivalent of the entire file above. |
|