11"""Tool for performing retrieval augmented generation."""
22
3+ import ast
4+ import logging
35import os
46import sqlite3
5- from typing import Any , Dict , List
7+ from typing import Any , Dict , List , Optional
68
79import models
810from agents import base
1113from rag import vector_db
1214import numpy as np
1315
16+ logger = logging .getLogger (__name__ )
17+
1418
1519# We use a hardcoded character limit for the full code context to avoid
1620# exceeding the model's token limit. While the Gemini API does not provide a
1923# when considering that the prompt sends file content in two fields.
2024_MAX_CONTEXT_LENGTH = 100_000
2125
26+ # Corpus tags supported by the RAG layer. Files whose basename starts with
27+ # "maxtext_" are tagged "maxtext"; everything else falls back to "jax".
28+ _KNOWN_CORPORA = ("jax" , "maxtext" )
29+
30+
31+ def _query_prefix_for_target (target : str ) -> str :
32+ """Returns the human-readable prefix used in component-signature queries."""
33+ if target == "maxtext" :
34+ return "MaxText"
35+ return "JAX Flax"
36+
37+
38+ def _corpus_for_filename (filename : str ) -> str :
39+ """Auto-tags a file by its basename. `maxtext_*.py` -> 'maxtext', else 'jax'."""
40+ base = os .path .basename (filename )
41+ if base .startswith ("maxtext_" ):
42+ return "maxtext"
43+ return "jax"
44+
45+
46+ def _extract_component_signatures (code : str , target : str = "jax" ) -> list [str ]:
47+ """Extracts focused query strings per top-level class/function using AST.
48+
49+ For classes: "{prefix} {ClassName} {base_classes} {method_names} {init_params}"
50+ For functions: "{prefix} {func_name} {param_names}"
51+
52+ The prefix is "JAX Flax" for `target="jax"` and "MaxText" for
53+ `target="maxtext"`.
54+
55+ Args:
56+ code: Python source code to parse.
57+ target: Conversion target — selects the query prefix.
58+
59+ Returns:
60+ A list of query strings, one per top-level component.
61+ """
62+ try :
63+ tree = ast .parse (code )
64+ except SyntaxError :
65+ return []
66+
67+ prefix = _query_prefix_for_target (target )
68+ signatures = []
69+ for node in ast .iter_child_nodes (tree ):
70+ if isinstance (node , ast .ClassDef ):
71+ bases = [
72+ ast .unparse (b ) if hasattr (ast , "unparse" ) else getattr (b , "id" , "" )
73+ for b in node .bases
74+ ]
75+ methods = [
76+ n .name for n in ast .iter_child_nodes (node )
77+ if isinstance (n , (ast .FunctionDef , ast .AsyncFunctionDef ))
78+ ]
79+ init_params = []
80+ for n in ast .iter_child_nodes (node ):
81+ if isinstance (n , ast .FunctionDef ) and n .name == "__init__" :
82+ init_params = [
83+ a .arg for a in n .args .args if a .arg != "self"
84+ ]
85+ break
86+ parts = [prefix , node .name ] + bases + methods + init_params
87+ signatures .append (" " .join (parts ))
88+ elif isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef )):
89+ params = [a .arg for a in node .args .args if a .arg != "self" ]
90+ parts = [prefix , node .name ] + params
91+ signatures .append (" " .join (parts ))
92+ return signatures
93+
2294
2395class RAGAgent (base .Agent ):
2496 """Tool for performing retrieval augmented generation."""
@@ -29,6 +101,7 @@ def __init__(
29101 embedding_model_name : models .EmbeddingModel ,
30102 db_path : str = vector_db .RAG_DB_FILE ,
31103 api_key : str | None = None ,
104+ target : str = "jax" ,
32105 ):
33106 """Initializes the agent.
34107
@@ -37,32 +110,59 @@ def __init__(
37110 embedding_model_name: Name of the embedding model to use.
38111 db_path: Path to the RAG SQLite database.
39112 api_key: The API key for Google AI services.
113+ target: Conversion target ("jax" or "maxtext"). Selects which corpus
114+ the agent retrieves from.
40115 """
41116 super ().__init__ (model = model )
42117 self ._db_path = db_path
118+ self ._target = target
43119 self ._embedding_agent = embedding .EmbeddingAgent (
44120 model_name = embedding_model_name .value , api_key = api_key
45121 )
46122 vector_db .create_db (db_path )
47- (
48- self ._ids ,
49- self ._names ,
50- self ._texts ,
51- self ._files ,
52- self ._index ,
53- ) = vector_db .make_embedding_index (db_path )
123+ self ._index_by_corpus : Dict [str , Dict [str , Any ]] = {}
124+ self ._refresh_indexes ()
125+
def _refresh_indexes(self) -> None:
    """Reloads the embedding index of every known corpus from the database.

    Replaces `self._index_by_corpus` with a new mapping that holds, for each
    entry of `_KNOWN_CORPORA`, the "ids", "names", "texts", "files" and
    "index" values returned by `vector_db.make_embedding_index` for the
    database at `self._db_path`.
    """
    refreshed: Dict[str, Dict[str, Any]] = {}
    # The new mapping is published before it is filled, so the attribute
    # never keeps stale entries from a previous build.
    self._index_by_corpus = refreshed
    for corpus_name in _KNOWN_CORPORA:
        doc_ids, doc_names, doc_texts, doc_files, corpus_index = (
            vector_db.make_embedding_index(self._db_path, corpus=corpus_name)
        )
        refreshed[corpus_name] = dict(
            ids=doc_ids,
            names=doc_names,
            texts=doc_texts,
            files=doc_files,
            index=corpus_index,
        )
140+
141+ def _active_corpus (self ) -> Dict [str , Any ]:
142+ """Returns the corpus payload selected by `self._target`, with JAX fallback."""
143+ if self ._target in self ._index_by_corpus :
144+ return self ._index_by_corpus [self ._target ]
145+ return self ._index_by_corpus .get ("jax" , {
146+ "ids" : [], "names" : [], "texts" : [], "files" : [], "index" : None ,
147+ })
54148
55149 def build_from_directory (self , source_path : str ):
56- """Builds RAG database from files in a source directory."""
150+ """Builds RAG database from files in a source directory.
151+
152+ Each file is auto-tagged with a corpus based on its basename — files
153+ starting with `maxtext_` are stored in the 'maxtext' corpus, all others
154+ in 'jax'. Retrieval at query time is filtered by the agent's `target`.
155+ """
57156 for root , _ , files in os .walk (source_path ):
58157 for filename in files :
59158 if filename .endswith (".py" ):
60159 file_path = os .path .join (root , filename )
160+ corpus_tag = _corpus_for_filename (filename )
61161 try :
62162 with open (file_path , "r" , encoding = "utf-8" , errors = "replace" ) as f :
63163 content = f .read ()
64164 doc_name = os .path .relpath (file_path , source_path )
65- print (f"Adding { doc_name } to RAG database..." )
165+ print (f"Adding { doc_name } to RAG database (corpus= { corpus_tag } ) ..." )
66166 description = self .generate (
67167 prompts .CODE_DESCRIPTION ,
68168 {
@@ -78,20 +178,21 @@ def build_from_directory(self, source_path: str):
78178 file = file_path ,
79179 embedding = np .array (embedding_vector ),
80180 db_path = self ._db_path ,
181+ corpus = corpus_tag ,
81182 )
82183 except (OSError , sqlite3 .Error ) as e :
83184 print (f"Skipping { file_path } : { e } " )
84- # Refresh index
85- self ._ids , self ._names , self ._texts , self ._files , self ._index = (
86- vector_db .make_embedding_index (self ._db_path )
87- )
185+ # Refresh per-corpus indexes
186+ self ._refresh_indexes ()
88187 print ("Finished building RAG database." )
89188
def retrieve_context(
    self, query: str, top_k: int = 3
) -> List[Dict[str, Any]]:
    """Looks up the documents closest to `query` in the active corpus.

    The search is restricted to the corpus returned by
    `self._active_corpus()`, i.e. the one selected by the agent's `target`.

    Args:
      query: The query string to search for.
      top_k: The number of top results to request from the vector search.

    Returns:
      A list of dictionaries, each containing 'name', 'text', 'file', and
      'distance' for a retrieved document. Empty when the active corpus has
      no index.
    """
    corpus = self._active_corpus()
    corpus_index = corpus.get("index")
    if corpus_index is None:
        # Nothing has been indexed for this corpus yet.
        return []

    query_vector = np.array(self._embedding_agent.embed(query))
    hits = vector_db.search_embedding(
        query_vector, corpus_index, corpus["texts"], top_k=top_k
    )
    doc_names = corpus["names"]
    doc_files = corpus["files"]
    # Each hit carries the position of the document within the corpus lists.
    return [
        {
            "name": doc_names[position],
            "text": hit_text,
            "file": doc_files[position],
            "distance": hit_distance,
        }
        for hit_text, hit_distance, position in hits
    ]
118223
def retrieve_per_component_context(
    self,
    source_code: str,
    top_k_per_component: int = 3,
    max_total: int = 15,
) -> List[Dict[str, Any]]:
    """Retrieves RAG context using a hybrid full-file + per-component strategy.

    One query is issued with the full source code for broad context, plus one
    query per extracted top-level component signature (or per batch of
    signatures when there are many). Results are deduplicated by their 'file'
    value, keeping the entry with the smallest 'distance' for each file.

    Args:
      source_code: The full Python source code to retrieve context for.
      top_k_per_component: Number of results per component query.
      max_total: Number of results requested from the full-file query and
        maximum total results to return after deduplication.

    Returns:
      A deduplicated list of retrieved documents sorted by ascending
      'distance', truncated to `max_total`. When no component signature can
      be extracted, the result of a single full-source query is returned.
    """
    signatures = _extract_component_signatures(source_code, self._target)

    # Fall back to single-query if AST parsing yielded nothing
    if not signatures:
        logger.info("Per-component extraction failed, falling back to single query")
        return self.retrieve_context(source_code, top_k=max_total)

    best_by_file: Dict[str, Dict[str, Any]] = {}

    def _keep_best(docs: List[Dict[str, Any]]) -> None:
        """Merges `docs` into `best_by_file`, keeping the lowest distance per file."""
        for doc in docs:
            current = best_by_file.get(doc["file"])
            if current is None or doc["distance"] < current["distance"]:
                best_by_file[doc["file"]] = doc

    # Start with full-file query for broad domain context. It goes through
    # the same merge rule as the component queries so that a repeated file
    # can never replace a closer hit with a farther one.
    _keep_best(self.retrieve_context(source_code, top_k=max_total))

    # Above `batch_threshold` components, signatures are joined in groups of
    # `batch_size` to cap the number of embedding calls.
    batch_threshold = 12
    batch_size = 4
    if len(signatures) > batch_threshold:
        queries = [
            " ".join(signatures[start:start + batch_size])
            for start in range(0, len(signatures), batch_size)
        ]
    else:
        queries = signatures

    logger.info("Per-component RAG: %d queries from %d components (+ full-file)",
                len(queries), len(signatures))

    # Add per-component results, keeping best distance per file
    for query in queries:
        _keep_best(self.retrieve_context(query, top_k=top_k_per_component))

    # Sort by distance, truncate to max_total
    ranked = sorted(best_by_file.values(), key=lambda d: d["distance"])
    return ranked[:max_total]
280+
def run(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Agent entry point: retrieves context for `query`.

    Args:
      query: The query string to search for.
      top_k: The number of top results to return.

    Returns:
      Whatever `retrieve_context` returns for the same arguments.
    """
    context = self.retrieve_context(query, top_k)
    return context
0 commit comments