Skip to content

MaxCode: Overhaul migration prompts with JAX best practices and MaxText support#26

Open
gvanica wants to merge 1 commit intomainfrom
split/4-prompts-overhaul
Open

MaxCode: Overhaul migration prompts with JAX best practices and MaxText support#26
gvanica wants to merge 1 commit intomainfrom
split/4-prompts-overhaul

Conversation

@gvanica
Copy link
Copy Markdown
Collaborator

@gvanica gvanica commented Apr 22, 2026

Summary

Rewrites and significantly expands the migration prompt library (prompts.py) from ~30 lines to ~800+ lines. The original prompts were minimal instructions; the new prompts encode detailed JAX/Flax best practices, common pitfalls, and conversion rules distilled from iterative testing on real model conversions.

What changed in prompts.py

New: JAX Best Practices block (14 rules enforced across all prompts)

  1. Use Flax Linen with @nn.compact (not setup() or NNX)
  2. KV cache via jax.lax.dynamic_update_slice (never jnp.concatenate)
  3. Causal conv1d with explicit left-padding (padding="VALID")
  4. Standalone imports only (jax, flax.linen, numpy — no torch)
  5. Static shapes for JIT compatibility
  6. Variable ordering — define before use
  7. Additive attention masking (not multiplicative boolean)
  8. RMS norm via jax.lax.rsqrt
  9. JAX activation function mappings
  10. Rotary embedding implementation pattern
  11. Triangular matrix inversion via solve_triangular
  12. Interleaved weight ordering preservation
  13. Hallucination prevention (no num_feature_axes)
  14. Flax scoping and unique naming in loops

New: Prompt variants

  • PYTORCH_TO_JAX_SINGLE_FILE_PROMPT — updated with best practices
  • PYTORCH_TO_JAX_MULTI_FILE_PROMPT — for merged multi-file conversions
  • MaxText-specific prompts for YAML overlay and custom layer generation

Updated: models.py

  • Minor additions to model configuration support

Files

File Lines changed Description
agents/migration/prompts/prompts.py +434 / -39 Complete prompt rewrite with best practices
models.py +5 / -2 Model config additions

Test plan

  • Verify prompt strings are well-formed (no unmatched braces or format placeholders)
  • Run a single-file conversion and confirm the output follows the new best practices (e.g., uses @nn.compact, no torch imports)
  • Confirm MaxText prompts produce valid YAML overlay structure

Split from #17 — PR 4 of 8

@google-cla
Copy link
Copy Markdown

google-cla Bot commented Apr 22, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Rewrites and expands the migration prompt library with:

- Comprehensive JAX/Flax best practices (KV cache, causal conv1d,
  static shapes, attention masking, rotary embeddings, etc.)
- Single-file and multi-file conversion prompts
- MaxText-specific conversion prompts (YAML overlay, custom layers)
- Hallucination prevention rules
- Flax scoping and naming conventions

Also updates models.py with additional model configuration support.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gvanica gvanica force-pushed the split/4-prompts-overhaul branch from 772f8c5 to 43d21eb Compare April 22, 2026 02:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant