
Commit 5c6b247

Accelerator Agents Team authored and copybara-github committed
No public description
PiperOrigin-RevId: 903294263
1 parent ae1bf59 commit 5c6b247

1 file changed: MaxCode/agents/migration/prompts/prompts.py (94 additions & 2 deletions)
@@ -70,15 +70,41 @@
17. **BatchNorm Momentum**: JAX momentum is the decay factor for old statistics (`x_new = momentum * x_old + (1 - momentum) * x_batch`), but PyTorch uses `1 - decay`. To ensure parity, you MUST set JAX momentum to `1 - pytorch_momentum` (see the momentum sketch after this list).
18. **Data Layout**: Standardize on `NHWC` (Channels Last) for JAX performance, but include necessary `jnp.transpose` operations at input/output boundaries to match PyTorch's `NCHW` oracle outputs (see the layout sketch after this list).
19. **Activation Tracking**: To facilitate equivalence testing, you MUST instrument the JAX model to capture intermediate activations. For every major layer or block (e.g., after a Conv, Dense, or Attention block), use `self.sow('intermediates', 'name_of_activation', activation_tensor)` to record the output (see the `sow` sketch after this list).
20. **Weight Initialization**: Match PyTorch initialization exactly. When the source explicitly calls `nn.init.zeros_` on a layer, use `nn.initializers.zeros_init()`. When the source uses bare `nn.Linear()` with no explicit init, use the Flax default or `nn.initializers.normal(stddev=config.initializer_range)` (see the initialization sketch after this list).
    - Do not use `zeros_init` unless the source explicitly initializes to zeros.
    - RMSNorm (1+w): `nn.initializers.zeros_init()`.
    - RMSNorm (w): `nn.initializers.ones_init()`.
    - Check each `nn.Parameter` in the source and match its init.
21. **Train/Eval Mode**: Flax modules do NOT have a `.train` attribute or `.eval()` / `.train()` methods. NEVER write `model.train = True` or `model.train = False` -- this does nothing in Flax and silently produces incorrect behavior. Instead, pass `deterministic=False` for training and `deterministic=True` for evaluation as an argument to `__call__` / `model.apply()`. All stochastic layers (Dropout, router noise) must check the `deterministic` flag (see the deterministic-flag sketch after this list).
22. **Preserve ALL Source Components**: Convert EVERY class, function, and method from the source. Do NOT merge base classes into subclasses, do NOT drop utility classes or metric functions, and do NOT omit `get_config()` or serialization methods. If the source has `ExpertBase` and `FFNExpert`, convert both. If the source has a `MoEMetrics` class, convert it.
23. **Preserve Default Values Exactly**: All default parameter values in the JAX output must match the PyTorch source EXACTLY. Do NOT change any numeric default -- not capacity factors, not dropout rates, not epsilon values, not learning rates, not layer counts.
Even if you believe a different value is "better" or "more stable", use the source value. Changed defaults silently alter model behavior and break reproducibility.
24. **Preserve Exact Reduction Operations**: When the source uses `.mean()`, use `jnp.mean()`. When the source uses `.sum()`, use `jnp.sum()`. NEVER substitute one reduction for another.
`torch.mean(x, dim=N)` maps to `jnp.mean(x, axis=N)`.
`torch.sum(x, dim=N)` maps to `jnp.sum(x, axis=N)`.
The dim/axis integer stays the same.
25. **Preserve Method Placement**: If the source defines a method or attribute on a specific class, keep it on that class in the JAX output. Do NOT relocate methods between classes or replace instance methods with standalone functions unless the JAX idiom requires it.
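A minimal sketch of the momentum conversion from rule 17, assuming a PyTorch `nn.BatchNorm2d(..., momentum=0.1)` source layer; the concrete values are illustrative, not taken from any particular model.

```python
import flax.linen as nn

# PyTorch (assumed source): nn.BatchNorm2d(64, momentum=0.1, eps=1e-5)
#   running = (1 - 0.1) * running + 0.1 * batch
# Flax BatchNorm:
#   running = momentum * running + (1 - momentum) * batch
# so the Flax momentum must be 1 - 0.1 = 0.9 for identical statistics.
norm = nn.BatchNorm(
    use_running_average=True,  # eval; pass False while training
    momentum=0.9,              # = 1 - pytorch_momentum
    epsilon=1e-5,
)
```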
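A sketch of the layout boundary transposes from rule 18; `model_nhwc` is a placeholder for the converted network.

```python
import jax.numpy as jnp

def forward_nchw(model_nhwc, params, x_nchw):
    # The PyTorch oracle supplies NCHW; compute in NHWC and convert only at the edges.
    x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))
    y_nhwc = model_nhwc.apply(params, x_nhwc)
    return jnp.transpose(y_nhwc, (0, 3, 1, 2))  # back to NCHW for comparison
```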
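A sketch of the `sow`-based activation capture from rule 19; the block and the activation name are hypothetical.

```python
import flax.linen as nn

class AttentionBlock(nn.Module):  # hypothetical block
    num_heads: int = 8

    @nn.compact
    def __call__(self, x):
        y = nn.SelfAttention(num_heads=self.num_heads)(x)
        # Recorded so equivalence tests can diff it against the PyTorch module's output.
        self.sow('intermediates', 'attn_out', y)
        return y

# Reading the sown values:
#   out, state = model.apply(variables, x, mutable=['intermediates'])
#   state['intermediates']['attn_out']  # tuple of captured activations
```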
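A sketch of rule 20's initializer matching; the layer widths, the `initializer_range` default, and which projection the source zero-initializes are all assumptions.

```python
import flax.linen as nn

class Mlp(nn.Module):
    hidden: int = 3072
    initializer_range: float = 0.02  # stand-in for config.initializer_range

    @nn.compact
    def __call__(self, x):
        # Source used a bare nn.Linear(): keep the Flax default or the config normal init.
        x = nn.Dense(self.hidden,
                     kernel_init=nn.initializers.normal(stddev=self.initializer_range))(x)
        # Source explicitly called nn.init.zeros_ on this projection only.
        x = nn.Dense(self.hidden, kernel_init=nn.initializers.zeros_init())(x)
        return x

# RMSNorm scale, matching the source parameterization:
#   y = (1 + w) * norm(x)  ->  w = self.param('scale', nn.initializers.zeros_init(), (dim,))
#   y = w * norm(x)        ->  w = self.param('scale', nn.initializers.ones_init(), (dim,))
```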
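A sketch of rule 21's deterministic flag; `TinyModel` and its dropout rate are placeholders.

```python
import jax
import flax.linen as nn

class TinyModel(nn.Module):  # hypothetical module
    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        x = nn.Dense(128)(x)
        # Stochastic layers consume the flag; there is no model.train()/model.eval().
        return nn.Dropout(rate=0.1)(x, deterministic=deterministic)

model = TinyModel()
x = jax.numpy.ones((2, 16))
params = model.init(jax.random.PRNGKey(0), x)
train_out = model.apply(params, x, deterministic=False,
                        rngs={'dropout': jax.random.PRNGKey(1)})
eval_out = model.apply(params, x, deterministic=True)
```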
## CRITICAL: Faithfulness to Source Code
This is a TRANSLATION, not a redesign. The converted code must produce
IDENTICAL behavior to the source for the same inputs and weights.
NEVER simplify complex tensor reshaping, reordering, or algorithmic patterns
from the source code. If the PyTorch code uses a specific interleaved weight
layout, chunk-parallel algorithm, or multi-step computation, convert it
faithfully to JAX. The RAG context shows EXAMPLES of similar patterns -- use
them as guidance for JAX idioms, but always follow the ACTUAL source code's
logic and structure.

NEVER "improve" the source by changing default values, adding initializers
that the source does not use, substituting reductions (.sum vs .mean), or
dropping components you consider non-essential (logging, metrics, utility
classes). If the source has it, the output must have it.
"""

MIGRATE_MODULE_TO_JAX_PROMPT = """
@@ -356,7 +382,7 @@
2. If the source has a `fix_query_key_value_ordering` method or groups QKVZ
projections by key heads, convert it FAITHFULLY. Reshape to
[B, T, num_k_heads, ...] and split within each key-head group. Do NOT
replace it with a flat split -- that produces wrong tensors when
num_k_heads != num_v_heads.
3. If the source has a chunk-parallel delta rule with a for-loop computing a
Neumann series (WY representation), convert it using
@@ -366,7 +392,73 @@
linear attention, implement BOTH modes and dispatch based on sequence length.
5. Implement causal_conv1d as a standalone function with both prefill and
single-step decode paths.
6. **Mandatory Activation Parity**: The JAX model must be structured to allow
verification of intermediate results. Use Flax's `sow` mechanism to capture activations for every significant layer, using names that clearly correspond to the PyTorch module's attributes.
7. For causal operations with decode-time state (causal conv1d, linear
attention), implement SEPARATE prefill and decode functions. Do NOT use
a single unified function with conditional branching (see the causal_conv1d sketch after this list).
8. ALWAYS include a `@dataclasses.dataclass` Config class at the top of the
output file. Mirror ALL fields from the PyTorch configuration class with
their types and default values. Use `dataclasses.field(default_factory=...)`
for mutable defaults. Use the Config type (not `Any`) in module annotations
(see the Config sketch after this list).
9. The `load_balancing_loss` function MUST accept an optional `attention_mask`
parameter. When the mask is provided, broadcast it to match the concatenated
router logits shape and use it to exclude padding tokens from mean/sum
statistics. See the RAG context for the full pattern.
10. **MoE Experts: Capacity-Based Dispatch (MANDATORY)**. The Experts class MUST
use capacity-based dispatch with dispatch/combine tensors, NOT per-token
gather of expert weights. The correct pattern is (a sketch follows this list):
   a) Compute per-expert capacity: `capacity = ceil(T * K / E) * 1.5`
   b) Build dispatch tensor via `one_hot(selected_experts) -> cumsum -> positions
      -> one_hot(positions, capacity)` to get `dispatch: [T, E, C]`
   c) Build combine tensor: `combine = dispatch * routing_weights`
   d) Route tokens to expert buffers: `expert_in = einsum('tec,th->ech', dispatch, x)`
   e) Batched expert matmul: `expert_out = einsum('ech,ehi->eci', expert_in, W)`
   f) Scatter back: `output = einsum('tec,eci->ti', combine, expert_out)`
   Do NOT use `weight[flat_indices]` gather or `jax.vmap` over individual experts.
   Do NOT use `jnp.einsum('td,edh->teh')` computing all experts for all tokens.
   The capacity-based approach is 10-50x more efficient for large E (e.g. E=64).
11. **KV Cache: Pure Functional NamedTuple (MANDATORY)**. All KV caches MUST be
NamedTuple objects passed as function arguments and returned as outputs.
Do NOT use Flax mutable variables (`self.variable('cache', ...)`).
Do NOT use config dicts with init flags.
For encoder-decoder models: use SEPARATE self_attn_cache and cross_attn_cache
arguments per layer. Cross-attention caches are populated once from encoder
output and passed through unchanged on subsequent decode steps.
Provide an `init_kv_caches()` helper function that pre-allocates all layer
caches. This replaces PyTorch's `install_kv_cache_hooks()`.
See the RAG context for the full encoder-decoder cache pattern; a minimal
self-attention cache sketch also follows this list.
12. **Tied Output Projection**: When the PyTorch source computes logits via
`x @ self.token_embedding.weight.T`, convert it to
`(x @ token_embedding.embedding.T).astype(jnp.float32)`.
Do NOT use `token_embedding.attend(x)` -- that is for embedding lookup,
not linear projection, and may produce different results.
13. **Fused QKV Projection**: When the PyTorch source uses a single
`in_proj_weight` of shape [3*embed_dim, embed_dim] with sliced projection
methods (in_proj_qkv, in_proj_q, in_proj_kv), preserve this as a SINGLE
parameter with sliced access in JAX. Do NOT split into 3 separate nn.Dense
layers. Use `self.param('in_proj_weight', init, (3*D, D))` and slice it
for Q [0:D], K [D:2D], V [2D:3D]. Provide in_proj_qkv(), in_proj_q(),
in_proj_kv() methods matching the PyTorch API (see the fused-projection sketch after this list).
14. **Float32 Softmax Upcast (MANDATORY)**: When the PyTorch source uses
`.float()` or `dtype=torch.float32` before softmax, you MUST preserve this
in JAX: `jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1)` then
cast back with `.astype(q.dtype)`. This is critical for numerical stability
in bfloat16/float16. NEVER omit this upcast.
15. **Preserve ALL Source Components (MANDATORY)**: The output MUST contain a
JAX equivalent for EVERY class, function, method, and utility in the source.
Do NOT merge base classes into subclasses. Do NOT drop get_config() or
serialization methods. Do NOT omit utility classes (e.g., metrics classes)
or standalone functions (e.g., metric computation functions). If the source
has N classes and M functions, the output must have N classes and M functions.
16. **Preserve Default Values Exactly**: All constructor defaults, config
defaults, and hyperparameter defaults MUST match the PyTorch source exactly.
Do NOT change capacity_factor, dropout rates, noise epsilon, num_layers,
or any other default value -- even if you think a different value is better.
17. **Train/Eval Mode in Flax**: NEVER set `model.train = True/False` or call
`model.eval()` / `model.train()` in training loops. Flax has no such
attributes. Use `deterministic=False` for training and `deterministic=True`
for evaluation, passed as an argument to the module's `__call__` method.
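A sketch of the split prefill/decode shape from items 5 and 7, assuming a depthwise filter stored as `[K, D]`; the weight layout and state handling are assumptions to adapt to whatever the source uses.

```python
import jax.numpy as jnp

def causal_conv1d_prefill(x, w):
    """x: [B, T, D] inputs, w: [K, D] depthwise filter (layout assumed)."""
    K, T = w.shape[0], x.shape[1]
    x_pad = jnp.pad(x, ((0, 0), (K - 1, 0), (0, 0)))
    # out[:, t] = sum_k w[k] * x[:, t - (K - 1) + k], with zero-padded history
    out = sum(w[k] * x_pad[:, k:k + T] for k in range(K))
    state = x_pad[:, T:]                # last K-1 inputs, carried into decoding
    return out, state

def causal_conv1d_decode(x_step, state, w):
    """x_step: [B, 1, D] new token, state: [B, K-1, D] most recent inputs."""
    window = jnp.concatenate([state, x_step], axis=1)   # [B, K, D]
    out = jnp.einsum('kd,bkd->bd', w, window)[:, None]  # [B, 1, D]
    return out, window[:, 1:]                           # slide the state window
```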
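A sketch of the Config pattern from item 8; every field name and default here is a placeholder for the corresponding PyTorch config entry.

```python
import dataclasses
from typing import List

import flax.linen as nn

@dataclasses.dataclass
class ModelConfig:
    # Mirror every field of the PyTorch config with its exact default.
    hidden_size: int = 768
    num_layers: int = 12
    dropout_rate: float = 0.1
    # Mutable defaults go through default_factory.
    expert_hidden_sizes: List[int] = dataclasses.field(
        default_factory=lambda: [3072, 3072])

class Block(nn.Module):
    config: ModelConfig  # annotate with the concrete Config type, not Any

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        x = nn.Dense(self.config.hidden_size)(x)
        return nn.Dropout(self.config.dropout_rate)(x, deterministic=deterministic)
```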
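A compact sketch of the capacity-based dispatch from item 10, using the same index letters (T tokens, K choices, E experts, C capacity) and a per-expert weight of shape [E, H, I]; the 1.5 slack factor and the shapes are assumptions to adapt to the actual source.

```python
import math

import jax
import jax.numpy as jnp

def capacity_dispatch(x, selected_experts, routing_weights, expert_w):
    """x: [T, H], selected_experts: [T, K], routing_weights: [T, K], expert_w: [E, H, I]."""
    T, K = selected_experts.shape
    E = expert_w.shape[0]
    C = int(math.ceil(T * K / E) * 1.5)                      # per-expert capacity

    # [K, T, E]: first choices of all tokens claim buffer slots before second choices.
    choice_mask = jax.nn.one_hot(selected_experts.T, E)
    flat_mask = choice_mask.reshape(K * T, E)

    # Slot each assignment occupies inside its expert's buffer; overflow is dropped.
    position = (jnp.cumsum(flat_mask, axis=0) - 1.0).astype(jnp.int32)
    keep = flat_mask * (position < C)
    dispatch_k = (keep[..., None] * jax.nn.one_hot(position, C)).reshape(K, T, E, C)

    dispatch = dispatch_k.sum(axis=0)                        # [T, E, C], 0/1 entries
    combine = (dispatch_k * routing_weights.T[:, :, None, None]).sum(axis=0)

    expert_in = jnp.einsum('tec,th->ech', dispatch, x)       # route tokens to buffers
    expert_out = jnp.einsum('ech,ehi->eci', expert_in, expert_w)
    return jnp.einsum('tec,eci->ti', combine, expert_out)    # weighted scatter back
```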
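A sketch of the functional cache from item 11, covering only a per-layer self-attention cache; the field names, shapes, and dynamic-slice update are illustrative, and the cross-attention half described above is omitted.

```python
from typing import List, NamedTuple

import jax
import jax.numpy as jnp

class KVCache(NamedTuple):
    key: jax.Array     # [B, max_len, num_heads, head_dim]
    value: jax.Array
    index: jax.Array   # scalar int32: next write position

def init_kv_caches(num_layers, batch, max_len, num_heads, head_dim,
                   dtype=jnp.bfloat16) -> List[KVCache]:
    """Pre-allocate one self-attention cache per decoder layer
    (a functional stand-in for PyTorch's install_kv_cache_hooks())."""
    shape = (batch, max_len, num_heads, head_dim)
    return [KVCache(key=jnp.zeros(shape, dtype),
                    value=jnp.zeros(shape, dtype),
                    index=jnp.zeros((), jnp.int32))
            for _ in range(num_layers)]

def update_cache(cache: KVCache, k_step, v_step) -> KVCache:
    """Write one decode step ([B, 1, num_heads, head_dim]) and return a NEW cache."""
    k_step = k_step.astype(cache.key.dtype)
    v_step = v_step.astype(cache.value.dtype)
    key = jax.lax.dynamic_update_slice(cache.key, k_step, (0, cache.index, 0, 0))
    value = jax.lax.dynamic_update_slice(cache.value, v_step, (0, cache.index, 0, 0))
    return cache._replace(key=key, value=value, index=cache.index + 1)
```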
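A sketch of the fused projection from item 13; the module name and the Xavier initializer are placeholders for whatever the source actually defines, and bias handling is omitted.

```python
import flax.linen as nn
import jax.numpy as jnp

class FusedQKVProjection(nn.Module):  # hypothetical module name
    embed_dim: int

    def setup(self):
        d = self.embed_dim
        # One parameter of shape [3*D, D], mirroring the PyTorch in_proj_weight.
        self.in_proj_weight = self.param(
            'in_proj_weight', nn.initializers.xavier_uniform(), (3 * d, d))

    def in_proj_qkv(self, x):
        return jnp.split(x @ self.in_proj_weight.T, 3, axis=-1)   # Q, K, V

    def in_proj_q(self, x):
        return x @ self.in_proj_weight[: self.embed_dim].T        # rows [0:D]

    def in_proj_kv(self, x):
        kv = x @ self.in_proj_weight[self.embed_dim:].T           # rows [D:3D]
        return jnp.split(kv, 2, axis=-1)

    def __call__(self, x):
        return self.in_proj_qkv(x)
```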
Please think step by step about the conversion process before generating the code.
Then, provide the complete JAX equivalent of the entire file above.
