Skip to content

Commit 53a3c6d

Browse files
gvanica and claude committed
Add RAG knowledge base: JAX/Flax reference docs and targeted migration rules
Adds 55 new RAG source documents to improve migration quality: Generic sources (22 files): - JAX/Flax documentation: module API, layers API, setup vs compact, gotchas, lax primitives, attention patterns - Flash Linear Attention (FLA) library references: gated delta net layers/models, l2norm, layernorm gated, rotary, short conv, ops - MaxText reference implementations: attentions, embeddings, linears, normalizations, Qwen3/DeepSeek model layers Targeted sources (30 files): - Migration-specific rules covering: dtype fidelity, causal conv1d, config dataclasses, stop_gradient, mixed precision, KV cache, encoder-decoder cache, Flax checkpoint API, train/eval mode, float32 softmax upcast, fused QKV projection, weight init patterns, class hierarchy preservation, source faithfulness, triangular masking, WY representation, and more Also updates rag_agent.py and vector_db.py to support the expanded corpus with targeted RAG retrieval. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a998c34 commit 53a3c6d

57 files changed

Lines changed: 15592 additions & 24 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

MaxCode/rag/rag_agent.py

Lines changed: 180 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Tool for performing retrieval augmented generation."""
22

3+
import ast
4+
import logging
35
import os
46
import sqlite3
5-
from typing import Any, Dict, List
7+
from typing import Any, Dict, List, Optional
68

79
import models
810
from agents import base
@@ -11,6 +13,8 @@
1113
from rag import vector_db
1214
import numpy as np
1315

16+
logger = logging.getLogger(__name__)  # Module-scoped logger (stdlib convention).
17+
1418

1519
# We use a hardcoded character limit for the full code context to avoid
1620
# exceeding the model's token limit. While the Gemini API does not provide a
@@ -19,6 +23,74 @@
1923
# when considering that the prompt sends file content in two fields.
2024
_MAX_CONTEXT_LENGTH = 100_000  # Character cap on full-code context; see rationale in the comment above.
2125

26+
# Corpus tags supported by the RAG layer. Files whose basename starts with
27+
# "maxtext_" are tagged "maxtext"; everything else falls back to "jax".
28+
_KNOWN_CORPORA = ("jax", "maxtext")
29+
30+
31+
def _query_prefix_for_target(target: str) -> str:
32+
"""Returns the human-readable prefix used in component-signature queries."""
33+
if target == "maxtext":
34+
return "MaxText"
35+
return "JAX Flax"
36+
37+
38+
def _corpus_for_filename(filename: str) -> str:
39+
"""Auto-tags a file by its basename. `maxtext_*.py` -> 'maxtext', else 'jax'."""
40+
base = os.path.basename(filename)
41+
if base.startswith("maxtext_"):
42+
return "maxtext"
43+
return "jax"
44+
45+
46+
def _extract_component_signatures(code: str, target: str = "jax") -> list[str]:
47+
"""Extracts focused query strings per top-level class/function using AST.
48+
49+
For classes: "{prefix} {ClassName} {base_classes} {method_names} {init_params}"
50+
For functions: "{prefix} {func_name} {param_names}"
51+
52+
The prefix is "JAX Flax" for `target="jax"` and "MaxText" for
53+
`target="maxtext"`.
54+
55+
Args:
56+
code: Python source code to parse.
57+
target: Conversion target — selects the query prefix.
58+
59+
Returns:
60+
A list of query strings, one per top-level component.
61+
"""
62+
try:
63+
tree = ast.parse(code)
64+
except SyntaxError:
65+
return []
66+
67+
prefix = _query_prefix_for_target(target)
68+
signatures = []
69+
for node in ast.iter_child_nodes(tree):
70+
if isinstance(node, ast.ClassDef):
71+
bases = [
72+
ast.unparse(b) if hasattr(ast, "unparse") else getattr(b, "id", "")
73+
for b in node.bases
74+
]
75+
methods = [
76+
n.name for n in ast.iter_child_nodes(node)
77+
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
78+
]
79+
init_params = []
80+
for n in ast.iter_child_nodes(node):
81+
if isinstance(n, ast.FunctionDef) and n.name == "__init__":
82+
init_params = [
83+
a.arg for a in n.args.args if a.arg != "self"
84+
]
85+
break
86+
parts = [prefix, node.name] + bases + methods + init_params
87+
signatures.append(" ".join(parts))
88+
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
89+
params = [a.arg for a in node.args.args if a.arg != "self"]
90+
parts = [prefix, node.name] + params
91+
signatures.append(" ".join(parts))
92+
return signatures
93+
2294

2395
class RAGAgent(base.Agent):
2496
"""Tool for performing retrieval augmented generation."""
@@ -29,6 +101,7 @@ def __init__(
29101
embedding_model_name: models.EmbeddingModel,
30102
db_path: str = vector_db.RAG_DB_FILE,
31103
api_key: str | None = None,
104+
target: str = "jax",
32105
):
33106
"""Initializes the agent.
34107
@@ -37,32 +110,59 @@ def __init__(
37110
embedding_model_name: Name of the embedding model to use.
38111
db_path: Path to the RAG SQLite database.
39112
api_key: The API key for Google AI services.
113+
target: Conversion target ("jax" or "maxtext"). Selects which corpus
114+
the agent retrieves from.
40115
"""
41116
super().__init__(model=model)
42117
self._db_path = db_path
118+
self._target = target
43119
self._embedding_agent = embedding.EmbeddingAgent(
44120
model_name=embedding_model_name.value, api_key=api_key
45121
)
46122
vector_db.create_db(db_path)
47-
(
48-
self._ids,
49-
self._names,
50-
self._texts,
51-
self._files,
52-
self._index,
53-
) = vector_db.make_embedding_index(db_path)
123+
self._index_by_corpus: Dict[str, Dict[str, Any]] = {}
124+
self._refresh_indexes()
125+
126+
def _refresh_indexes(self) -> None:
    """Rebuilds the per-corpus embedding indexes from the database.

    One index payload is loaded per tag in `_KNOWN_CORPORA`; previously
    cached payloads are discarded first.
    """
    fields = ("ids", "names", "texts", "files", "index")
    self._index_by_corpus = {}
    for corpus in _KNOWN_CORPORA:
        payload = vector_db.make_embedding_index(self._db_path, corpus=corpus)
        self._index_by_corpus[corpus] = dict(zip(fields, payload))
140+
141+
def _active_corpus(self) -> Dict[str, Any]:
    """Returns the index payload selected by `self._target`.

    Falls back to the 'jax' corpus when the target corpus is not loaded;
    when neither is available, returns an empty payload (with a `None`
    index) so retrieval degrades to "no results" instead of raising.
    """
    payload = self._index_by_corpus.get(self._target)
    if payload is not None:
        return payload
    empty = {"ids": [], "names": [], "texts": [], "files": [], "index": None}
    return self._index_by_corpus.get("jax", empty)
54148

55149
def build_from_directory(self, source_path: str):
56-
"""Builds RAG database from files in a source directory."""
150+
"""Builds RAG database from files in a source directory.
151+
152+
Each file is auto-tagged with a corpus based on its basename — files
153+
starting with `maxtext_` are stored in the 'maxtext' corpus, all others
154+
in 'jax'. Retrieval at query time is filtered by the agent's `target`.
155+
"""
57156
for root, _, files in os.walk(source_path):
58157
for filename in files:
59158
if filename.endswith(".py"):
60159
file_path = os.path.join(root, filename)
160+
corpus_tag = _corpus_for_filename(filename)
61161
try:
62162
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
63163
content = f.read()
64164
doc_name = os.path.relpath(file_path, source_path)
65-
print(f"Adding {doc_name} to RAG database...")
165+
print(f"Adding {doc_name} to RAG database (corpus={corpus_tag})...")
66166
description = self.generate(
67167
prompts.CODE_DESCRIPTION,
68168
{
@@ -78,20 +178,21 @@ def build_from_directory(self, source_path: str):
78178
file=file_path,
79179
embedding=np.array(embedding_vector),
80180
db_path=self._db_path,
181+
corpus=corpus_tag,
81182
)
82183
except (OSError, sqlite3.Error) as e:
83184
print(f"Skipping {file_path}: {e}")
84-
# Refresh index
85-
self._ids, self._names, self._texts, self._files, self._index = (
86-
vector_db.make_embedding_index(self._db_path)
87-
)
185+
# Refresh per-corpus indexes
186+
self._refresh_indexes()
88187
print("Finished building RAG database.")
89188

90189
def retrieve_context(
    self, query: str, top_k: int = 3
) -> List[Dict[str, Any]]:
    """Retrieves relevant context from the vector DB based on the query.

    Only the corpus selected by the agent's `target` is searched; when
    that corpus has no index yet, an empty list is returned.

    Args:
      query: The query string to search for.
      top_k: The number of top results to return.

    Returns:
      A list of dictionaries, each containing 'name', 'text', 'file',
      and 'distance' for a retrieved document.
    """
    corpus = self._active_corpus()
    ann_index = corpus.get("index")
    if ann_index is None:
        # No documents indexed for this corpus yet.
        return []
    embedded = np.array(self._embedding_agent.embed(query))
    hits = vector_db.search_embedding(
        embedded, ann_index, corpus["texts"], top_k=top_k
    )
    return [
        {
            "name": corpus["names"][i],
            "text": text,
            "file": corpus["files"][i],
            "distance": distance,
        }
        for text, distance, i in hits
    ]
118223

224+
def retrieve_per_component_context(
    self,
    source_code: str,
    top_k_per_component: int = 3,
    max_total: int = 15,
) -> List[Dict[str, Any]]:
    """Retrieves RAG context using a hybrid full-file + per-component strategy.

    A broad query over the whole source file supplies domain-level
    context, and one focused query per extracted component supplies
    targeted examples. Results are deduplicated by file path, keeping
    the smallest distance seen for each file.

    Args:
      source_code: The full Python source code to retrieve context for.
      top_k_per_component: Number of results per component query.
      max_total: Maximum total results to return after deduplication.

    Returns:
      A deduplicated, distance-sorted list of retrieved documents.
    """
    signatures = _extract_component_signatures(source_code, self._target)

    # Nothing parseable -- degrade to a single whole-file query.
    if not signatures:
        logger.info("Per-component extraction failed, falling back to single query")
        return self.retrieve_context(source_code, top_k=max_total)

    # Seed the dedup map with broad, full-file results.
    best_by_file: Dict[str, Dict[str, Any]] = {
        doc["file"]: doc
        for doc in self.retrieve_context(source_code, top_k=max_total)
    }

    # Cap embedding calls: with many components, merge signatures into
    # batches of four per query.
    if len(signatures) > 12:
        queries = [
            " ".join(signatures[start:start + 4])
            for start in range(0, len(signatures), 4)
        ]
    else:
        queries = list(signatures)

    logger.info("Per-component RAG: %d queries from %d components (+ full-file)",
                len(queries), len(signatures))

    # Merge per-component hits, keeping the best distance per file.
    for query in queries:
        for doc in self.retrieve_context(query, top_k=top_k_per_component):
            prior = best_by_file.get(doc["file"])
            if prior is None or doc["distance"] < prior["distance"]:
                best_by_file[doc["file"]] = doc

    ranked = sorted(best_by_file.values(), key=lambda d: d["distance"])
    return ranked[:max_total]
280+
119281
def run(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Runs RAG to retrieve context for a query.

    Thin public entry point; delegates to `retrieve_context`, which
    filters retrieval to the agent's configured target corpus.
    """
    return self.retrieve_context(query, top_k=top_k)

0 commit comments

Comments
 (0)