diff --git a/MaxCode/agents/migration/model_conversion_agent.py b/MaxCode/agents/migration/model_conversion_agent.py index e7759e1..1e18846 100644 --- a/MaxCode/agents/migration/model_conversion_agent.py +++ b/MaxCode/agents/migration/model_conversion_agent.py @@ -1,4 +1,4 @@ -"""Agent for converting a model from PyTorch to JAX.""" +"""Agent for converting a model from PyTorch to a JAX-family target.""" import re from typing import Any @@ -16,11 +16,23 @@ def _strip_markdown_formatting(text: str) -> str: code_block_match = _CODE_BLOCK_PATTERN.search(text) if code_block_match: return code_block_match.group(1).strip() + # Handle truncated responses: opening ``` present but closing ``` missing + stripped = text.strip() + if stripped.startswith("```"): + first_nl = stripped.find("\n") + if first_nl != -1: + stripped = stripped[first_nl + 1:] + if stripped.endswith("```"): + stripped = stripped[:-3] + return stripped.strip() + # Strip triple-quote wrappers the LLM may use instead of backticks. + if stripped.startswith('"""') and stripped.endswith('"""'): + return stripped[3:-3].strip() return text class ModelConversionAgent(base.Agent): - """Agent for converting a model from PyTorch to JAX. + """Agent for converting a model from PyTorch to JAX/Flax. This agent specializes in converting PyTorch torch.nn.Module class definitions into idiomatic JAX/Flax equivalents (flax.linen.Module). @@ -28,17 +40,31 @@ class ModelConversionAgent(base.Agent): distinct from general API syntax conversion. """ - def __init__(self, model: Any, rag_agent_instance: rag_agent.RAGAgent): - """Initializes the agent.""" + def __init__( + self, + model: Any, + rag_agent_instance: rag_agent.RAGAgent, + target: str = "jax", + ): + """Initializes the agent. + + Args: + model: The LLM model to use for generation. + rag_agent_instance: RAGAgent for retrieving reference snippets. + target: Conversion target ("jax" by default). MaxText conversions are + handled by `MaxTextConversionAgent` rather than this agent, but the + target is plumbed through for prompt selection symmetry. + """ super().__init__( model=model, agent_domain=utils.AgentDomain.MIGRATION, agent_type=utils.AgentType.MODEL_CONVERSION, ) self._rag_agent = rag_agent_instance + self._target = target def run(self, pytorch_model_code: str) -> str: - """Converts a model from PyTorch to JAX. + """Converts a model from PyTorch to JAX/Flax. Args: pytorch_model_code: The PyTorch model code to convert. @@ -46,16 +72,21 @@ def run(self, pytorch_model_code: str) -> str: Returns: The converted JAX code. """ - rag_context_list = self._rag_agent.retrieve_context( - pytorch_model_code, top_k=7 + rag_context_list = self._rag_agent.retrieve_per_component_context( + pytorch_model_code ) rag_context = "\n\n".join([ f"File: {c['file']}\n```python\n{c['text']}\n```" for c in rag_context_list ]) + prompt_template = prompts.get_prompt( + "MODEL_CONVERSION_PROMPT", self._target + ) + if prompt_template is None: + prompt_template = prompts.MODEL_CONVERSION_PROMPT return _strip_markdown_formatting( self.generate( - prompts.MODEL_CONVERSION_PROMPT, + prompt_template, { "pytorch_model_code": pytorch_model_code, "rag_context": rag_context, diff --git a/MaxCode/agents/migration/primary_agent.py b/MaxCode/agents/migration/primary_agent.py index 5d69906..9fdbe11 100644 --- a/MaxCode/agents/migration/primary_agent.py +++ b/MaxCode/agents/migration/primary_agent.py @@ -1,20 +1,25 @@ """Primary orchestration agent for repository migration.""" +import ast import logging import os import re import subprocess import tempfile +import textwrap from typing import Any, Tuple import models from agents import base from agents import utils +from agents.migration import maxtext_conversion_agent from agents.migration import model_conversion_agent from agents.migration import single_file_agent +from agents.migration import validation_agent from agents.migration.prompts import prompts from rag import rag_agent MAX_DEBUG_ITERATIONS = 10 +logger = logging.getLogger(__name__) def _strip_markdown_formatting(text: str) -> str: @@ -22,30 +27,210 @@ def _strip_markdown_formatting(text: str) -> str: code_block_match = re.search(r"```(?:python)?\n?(.*?)\n?```", text, re.DOTALL) if code_block_match: return code_block_match.group(1).strip() + # Handle truncated responses: opening ``` present but closing ``` missing + stripped = text.strip() + if stripped.startswith("```"): + first_nl = stripped.find("\n") + if first_nl != -1: + stripped = stripped[first_nl + 1:] + if stripped.endswith("```"): + stripped = stripped[:-3] + return stripped.strip() + # Strip triple-quote wrappers the LLM may use instead of backticks. + if stripped.startswith('"""') and stripped.endswith('"""'): + return stripped[3:-3].strip() return text +def _find_missing_components(pytorch_code: str, jax_code: str) -> list[str]: + """Returns names of top-level classes/functions in pytorch_code missing from jax_code.""" + try: + src_tree = ast.parse(pytorch_code) + except SyntaxError: + return [] + try: + out_tree = ast.parse(jax_code) + except SyntaxError: + return [] + + src_names = { + node.name for node in ast.iter_child_nodes(src_tree) + if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)) + } + out_names = { + node.name for node in ast.iter_child_nodes(out_tree) + if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)) + } + return sorted(src_names - out_names) + + +def _extract_component_source(source_code: str, component_name: str) -> str: + """Extracts the full source text of a top-level class or function.""" + try: + tree = ast.parse(source_code) + except SyntaxError: + return "" + lines = source_code.splitlines(keepends=True) + for node in ast.iter_child_nodes(tree): + if (isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == component_name): + start = node.lineno - 1 # ast is 1-indexed + end = node.end_lineno if node.end_lineno else len(lines) + return "".join(lines[start:end]) + return "" + + +def _is_stub_body(body: list[ast.stmt]) -> bool: + """Checks if a function body is a stub (pass, return None, ..., or docstring+pass).""" + stmts = body + # Strip leading docstring + if stmts and isinstance(stmts[0], ast.Expr) and isinstance(stmts[0].value, (ast.Constant, ast.Str)): + stmts = stmts[1:] + if not stmts: + return True + if len(stmts) == 1: + s = stmts[0] + # pass + if isinstance(s, ast.Pass): + return True + # ... (Ellipsis) + if isinstance(s, ast.Expr) and isinstance(s.value, ast.Constant) and s.value.value is ...: + return True + # return None + if isinstance(s, ast.Return) and (s.value is None or (isinstance(s.value, ast.Constant) and s.value.value is None)): + return True + # raise NotImplementedError(...) + if isinstance(s, ast.Raise) and isinstance(s.exc, ast.Call): + func = s.exc.func + if isinstance(func, ast.Name) and func.id == "NotImplementedError": + return True + return False + + +def _find_stub_implementations(code: str) -> list[dict]: + """Walks AST and returns stub functions/methods. + + Returns: + List of dicts with keys: name, kind ('function' or 'method'), parent_class (or None). + """ + try: + tree = ast.parse(code) + except SyntaxError: + return [] + stubs = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + if _is_stub_body(node.body): + stubs.append({"name": node.name, "kind": "function", "parent_class": None}) + elif isinstance(node, ast.ClassDef): + for child in ast.iter_child_nodes(node): + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + if _is_stub_body(child.body): + stubs.append({"name": child.name, "kind": "method", "parent_class": node.name}) + return stubs + + +def _find_missing_methods(pytorch_code: str, jax_code: str) -> list[dict]: + """Compares methods within matching classes and returns missing ones. + + Returns: + List of dicts with keys: class_name, method_name. + """ + try: + src_tree = ast.parse(pytorch_code) + out_tree = ast.parse(jax_code) + except SyntaxError: + return [] + + def _class_methods(tree: ast.Module) -> dict[str, set[str]]: + result = {} + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + methods = set() + for child in ast.iter_child_nodes(node): + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + methods.add(child.name) + result[node.name] = methods + return result + + src_classes = _class_methods(src_tree) + out_classes = _class_methods(out_tree) + + missing = [] + for cls_name, src_methods in src_classes.items(): + if cls_name in out_classes: + for method in sorted(src_methods - out_classes[cls_name]): + # Skip dunder methods other than __init__ and __call__ + if method.startswith("__") and method.endswith("__") and method not in ("__init__", "__call__"): + continue + missing.append({"class_name": cls_name, "method_name": method}) + return missing + + +def _extract_method_source(code: str, class_name: str, method_name: str) -> str: + """Extracts a method's source from within a class.""" + try: + tree = ast.parse(code) + except SyntaxError: + return "" + lines = code.splitlines(keepends=True) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + for child in ast.iter_child_nodes(node): + if (isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) + and child.name == method_name): + start = child.lineno - 1 + end = child.end_lineno if child.end_lineno else len(lines) + return "".join(lines[start:end]) + return "" + + class PrimaryAgent(base.Agent): """Primary orchestration agent for repository migration.""" - def __init__(self, model: Any, api_key: str | None = None): - """Initializes the agent.""" + def __init__(self, model: Any, api_key: str | None = None, + validate: bool = True, target: str = "jax"): + """Initializes the agent. + + Args: + model: The LLM model to use. + api_key: API key for embedding/auth (passed through to the RAG agent). + validate: Whether to run validation/repair after conversion. + target: Conversion target — "jax" (default) or "maxtext". When + target="maxtext", the model conversion stage is replaced by the + staged `MaxTextConversionAgent` and the per-file utility conversion + is skipped (MaxText supplies its own training/CLI utilities). + """ super().__init__( model=model, agent_domain=utils.AgentDomain.MIGRATION, agent_type=utils.AgentType.PRIMARY, ) + self._model_ref = model + self._validate = validate + self._target = target + self._validation_results: dict[str, dict] = {} + self._merge_result = None # Set when running on a directory + self._maxtext_run_result = None # Set when target=="maxtext" self._rag_agent = rag_agent.RAGAgent( model, embedding_model_name=models.EmbeddingModel.GEMINI_EMBEDDING_001, api_key=api_key, + target=target, ) - self._single_file_agent = single_file_agent.PytorchToJaxSingleFileAgent( - model, self._rag_agent - ) - self._model_conversion_agent = model_conversion_agent.ModelConversionAgent( - model, self._rag_agent + self._single_file_agent = single_file_agent.PytorchSingleFileAgent( + model, self._rag_agent, target=target ) + if target == "maxtext": + self._model_conversion_agent = ( + maxtext_conversion_agent.MaxTextConversionAgent( + model, self._rag_agent + ) + ) + else: + self._model_conversion_agent = model_conversion_agent.ModelConversionAgent( + model, self._rag_agent, target=target + ) def _convert_file(self, pytorch_code: str, file_path: str) -> str: """Routes a file to the appropriate conversion agent.""" @@ -53,6 +238,141 @@ def _convert_file(self, pytorch_code: str, file_path: str) -> str: return self._model_conversion_agent.run(pytorch_code) return self._single_file_agent.run(pytorch_code) + _FILL_PROMPT = textwrap.dedent("""\ + Convert the following PyTorch classes/functions to JAX/Flax. + Return ONLY valid Python code. No markdown, no explanation. + + {rag_section} + ## PyTorch components to convert: + ```python + {components_source} + ``` + """) + + _FILL_STUBS_PROMPT = textwrap.dedent("""\ + The following JAX/Flax code contains stub implementations (functions or + methods with placeholder bodies like `pass`, `return None`, `...`, or + `raise NotImplementedError`). Replace every stub with a complete, correct + implementation based on the original PyTorch source provided below. + + Return the COMPLETE JAX file with all stubs filled in. Do not remove any + existing non-stub code. Return ONLY valid Python code. No markdown, no + explanation. + + ## Original PyTorch source for reference: + ```python + {pytorch_source} + ``` + + ## Current JAX/Flax code (with stubs to fill): + ```python + {jax_code} + ``` + """) + + def _fill_missing_components(self, pytorch_code: str, + jax_code: str) -> str: + """Detects components missing from the JAX output and converts them. + + Also detects stub implementations and missing methods within classes, + and makes a targeted LLM call to replace them with real implementations. + """ + # --- Phase 1: Fill missing top-level components (existing logic) --- + missing = _find_missing_components(pytorch_code, jax_code) + if missing: + logger.info("Missing components detected: %s", missing) + + sources = [] + for name in missing: + src = _extract_component_source(pytorch_code, name) + if src: + sources.append(src) + + if sources: + components_source = "\n\n".join(sources) + rag_section = "" + if self._rag_agent: + query = "JAX Flax conversion " + " ".join(missing) + try: + docs = self._rag_agent.retrieve_context(query, top_k=10) + if docs: + rag_section = "\n## Reference Patterns (from RAG):\n" + for doc in docs: + rag_section += f"\n### {doc.get('name', 'unknown')}\n{doc.get('text', '')}\n" + except Exception: + pass + + prompt = self._FILL_PROMPT.format( + components_source=components_source, + rag_section=rag_section, + ) + response = self.generate(prompt) + converted = _strip_markdown_formatting(response) + if converted and len(converted.strip()) > 20: + jax_code = jax_code.rstrip() + "\n\n" + converted.strip() + "\n" + + # --- Phase 2: Fix stubs and missing methods --- + stubs = _find_stub_implementations(jax_code) + missing_methods = _find_missing_methods(pytorch_code, jax_code) + + if not stubs and not missing_methods: + return jax_code + + # Collect PyTorch source snippets for the problematic components + pytorch_snippets = [] + seen = set() + for stub in stubs: + if stub["parent_class"]: + key = (stub["parent_class"], stub["name"]) + if key not in seen: + seen.add(key) + src = _extract_method_source(pytorch_code, stub["parent_class"], stub["name"]) + if src: + pytorch_snippets.append(f"# {stub['parent_class']}.{stub['name']}\n{src}") + else: + key = (None, stub["name"]) + if key not in seen: + seen.add(key) + src = _extract_component_source(pytorch_code, stub["name"]) + if src: + pytorch_snippets.append(f"# {stub['name']}\n{src}") + + for mm in missing_methods: + key = (mm["class_name"], mm["method_name"]) + if key not in seen: + seen.add(key) + src = _extract_method_source(pytorch_code, mm["class_name"], mm["method_name"]) + if src: + pytorch_snippets.append(f"# {mm['class_name']}.{mm['method_name']}\n{src}") + + if not pytorch_snippets: + return jax_code + + stub_names = [ + f"{s['parent_class']}.{s['name']}" if s["parent_class"] else s["name"] + for s in stubs + ] + mm_names = [f"{m['class_name']}.{m['method_name']}" for m in missing_methods] + logger.info("Stub implementations found: %s", stub_names) + logger.info("Missing methods found: %s", mm_names) + + pytorch_source = "\n\n".join(pytorch_snippets) + prompt = self._FILL_STUBS_PROMPT.format( + pytorch_source=pytorch_source, + jax_code=jax_code, + ) + response = self.generate(prompt) + repaired = _strip_markdown_formatting(response) + + # Only accept if result is a reasonable-length complete file that parses + if repaired and len(repaired.strip()) > len(jax_code) * 0.5: + try: + ast.parse(repaired) + return repaired + except SyntaxError: + logger.warning("Stub-filled code has syntax errors, keeping original") + return jax_code + def _execute_test( self, pytorch_code: str, jax_code: str, test_code: str ) -> Tuple[bool, str]: @@ -82,113 +402,235 @@ def _execute_test( except subprocess.CalledProcessError as e: return False, e.stderr - def run(self, repo_path: str, context: str | None = None) -> dict[str, str]: - """Orchestrates the migration of a repository from PyTorch to JAX. + _MAX_REPAIR_ITERATIONS = 3 + + def _validate_and_repair(self, pytorch_code: str, converted_code: str, + file_path: str) -> str: + """Validates converted code and repairs deviations in a loop. + + Runs up to _MAX_REPAIR_ITERATIONS rounds of validate-then-repair. + Exits early if no deviations remain or if the deviation count does + not decrease (no progress). Args: - repo_path: The path to the repository file or directory. - context: Optional raw context to use instead of RAG retrieval. + pytorch_code: The original PyTorch source code. + converted_code: The converted JAX/Flax code. + file_path: The file path (used as key for storing results). Returns: - A dictionary mapping original file paths to converted JAX code. + The final code (repaired if deviations were found, original otherwise). + """ + validator = validation_agent.ValidationAgent( + self._model_ref, rag_agent_instance=self._rag_agent, + target=self._target, + ) + + current_code = converted_code + prev_count = float("inf") + initial_deviations = None + initial_count = 0 + iteration_history = [] + final_deviations = [] + + for iteration in range(1, self._MAX_REPAIR_ITERATIONS + 1): + deviations = validator.validate(pytorch_code, current_code) + count = len(deviations) + logger.info("Validation of %s (iteration %d): found %d deviations", + file_path, iteration, count) + + # Capture initial state for backward compat + if iteration == 1: + initial_deviations = deviations + initial_count = count + + iteration_history.append({ + "iteration": iteration, + "deviation_count": count, + }) + + # Clean — no deviations remain + if not deviations: + final_deviations = [] + break - Raises: - RuntimeError: If the code conversion and validation fails after - `MAX_DEBUG_ITERATIONS` attempts. + # No progress — deviation count did not decrease + if count >= prev_count: + logger.info("No progress on %s at iteration %d (prev=%d, cur=%d), " + "stopping repair loop", file_path, iteration, + prev_count, count) + final_deviations = deviations + break + + current_code = validator.repair( + current_code, deviations, pytorch_code=pytorch_code + ) + prev_count = count + final_deviations = deviations + else: + # Loop exhausted without break — run one final validation + final_check = validator.validate(pytorch_code, current_code) + final_deviations = final_check + iteration_history.append({ + "iteration": self._MAX_REPAIR_ITERATIONS + 1, + "deviation_count": len(final_check), + }) + logger.info("Final validation of %s: %d remaining deviations", + file_path, len(final_check)) + + result = { + "deviations_found": initial_count, + "deviations": initial_deviations or [], + "remaining_deviations_count": len(final_deviations), + "remaining_deviations": final_deviations, + "iterations": len([h for h in iteration_history + if h["iteration"] <= self._MAX_REPAIR_ITERATIONS]), + "iteration_history": iteration_history, + } + self._validation_results[file_path] = result + return current_code + + def get_validation_results(self) -> dict[str, dict]: + """Returns validation results for all processed files. + + Returns: + A dictionary mapping file paths to their validation results, each + containing deviations_found, deviations, remaining_deviations_count, + and remaining_deviations. + """ + return self._validation_results + + def get_merge_result(self): + """Returns the MergeResult from the last directory run, or None.""" + return self._merge_result + + def get_maxtext_result(self): + """Returns the MaxTextRunResult from the last MaxText run, or None.""" + return self._maxtext_run_result + + def _model_name_from_path(self, repo_path: str) -> str: + """Picks a sensible filename stem for MaxText artifacts.""" + base = os.path.basename(os.path.normpath(repo_path)) + if base.endswith(".py"): + base = base[:-3] + return re.sub(r"[^A-Za-z0-9_]+", "_", base) or "model" + + def _run_maxtext(self, repo_path: str) -> dict[str, str]: + """MaxText path: produce YAML + optional layers + optional ckpt converter. + + Returns an empty dict — MaxText artifacts are persisted via + `get_maxtext_result()` rather than the per-file converted-code map. """ if os.path.isfile(repo_path): with open(repo_path, "r", encoding="utf-8", errors="replace") as f: pytorch_code = f.read() + # Look for a companion utils file (e.g. merged_utils.py beside merged_model.py) + base, ext = os.path.splitext(repo_path) + for utils_candidate in [ + base.replace("_model", "_utils") + ext, + base + "_utils" + ext, + ]: + if utils_candidate != repo_path and os.path.isfile(utils_candidate): + with open(utils_candidate, "r", encoding="utf-8", errors="replace") as f: + pytorch_code += "\n\n" + f.read() + logger.info("MaxText: appended companion utils from %s", utils_candidate) + break + model_name = self._model_name_from_path(repo_path) + elif os.path.isdir(repo_path): + from agents.migration.merge_agent import MergeAgent - if context is None: - rag_context_list = self._rag_agent.retrieve_context( - pytorch_code, top_k=7 - ) - rag_context = "\n\n".join([ - f"File: {c['file']}\n```python\n{c['text']}\n```" - for c in rag_context_list - ]) - else: - rag_context = context + merger = MergeAgent() + merge_result = merger.run(repo_path) + self._merge_result = merge_result + pytorch_code = merge_result.model_code + if merge_result.utility_code: + pytorch_code += "\n\n" + merge_result.utility_code + model_name = self._model_name_from_path(repo_path) + logger.info("MaxText: merged model code from %d files (%d chars)", + len(merge_result.model_files), len(merge_result.model_code)) + else: + return { + repo_path: f"# Error: path {repo_path} is not a file or directory." + } - jax_code = _strip_markdown_formatting( - self.generate( - prompts.MIGRATE_MODULE_TO_JAX_PROMPT, - {"pytorch_code": pytorch_code, "rag_context": rag_context}, - ) + logger.info("Running MaxTextConversionAgent on %s ...", repo_path) + self._maxtext_run_result = self._model_conversion_agent.run( + pytorch_code, model_name=model_name + ) + + # Optionally run validation on the layers file (custom decoder block only). + if (self._validate and self._maxtext_run_result.layers_py + and self._maxtext_run_result.decoder_block == "custom"): + validated = self._validate_and_repair( + pytorch_code, + self._maxtext_run_result.layers_py, + f"MaxText/layers/{model_name}.py", ) + self._maxtext_run_result.layers_py = validated - for i in range(MAX_DEBUG_ITERATIONS): - logging.info("Starting testing iteration %d.", i) - test_code = _strip_markdown_formatting( - self.generate( - prompts.EVALUATE_CODE_PROMPT, - {"pytorch_code": pytorch_code, "jax_code": jax_code}, - ) - ) + return {} - if "NOTESTCASE" in test_code: - print( - "Test generation returned NOTESTCASE, assuming conversion is ok." - ) - return {repo_path: jax_code} + def run(self, repo_path: str) -> dict[str, str]: + """Orchestrates the migration of a repository from PyTorch to JAX. - success, output = self._execute_test(pytorch_code, jax_code, test_code) + Args: + repo_path: The path to the repository file or directory. - if success: - print(f"Validation successful after {i} debugging iterations.") - logging.info( - "Validation successful after %d debugging iterations.", i - ) - return {repo_path: jax_code} - else: - traceback = output - logging.error( - "Validation failed on iteration %d. Traceback:\n%s", i, traceback - ) - logging.info("Starting debug iteration %d.", i + 1) - bug_analysis = self.generate( - prompts.BUG_ANALYSIS_PROMPT, - { - "pytorch_code": pytorch_code, - "jax_code": jax_code, - "test_code": test_code, - "traceback": traceback, - }, - ) - print(f"Bug analysis:\n{bug_analysis}") - logging.info("Bug analysis:\n%s", bug_analysis) - jax_code = _strip_markdown_formatting( - self.generate( - prompts.SELF_DEBUGGING_PROMPT, - { - "pytorch_code": pytorch_code, - "jax_code": jax_code, - "test_code": test_code, - "traceback": traceback, - "bug_analysis": bug_analysis, - "rag_context": rag_context, - }, - ) - ) - print(f"Attempting fix with new JAX code for iteration {i+1}.") + Returns: + A dictionary mapping original file paths to converted code. For + `target="maxtext"` the dict is empty; the artifacts are accessed via + `get_maxtext_result()`. + """ + if self._target == "maxtext": + return self._run_maxtext(repo_path) - raise RuntimeError( - "Failed to convert and validate code after" - f" {MAX_DEBUG_ITERATIONS} iterations." + if os.path.isfile(repo_path): + with open(repo_path, "r", encoding="utf-8", errors="replace") as f: + pytorch_code = f.read() + logger.info("Converting %s ...", repo_path) + converted_code = self._convert_file(pytorch_code, repo_path) + converted_code = self._fill_missing_components( + pytorch_code, converted_code ) + if self._validate: + converted_code = self._validate_and_repair( + pytorch_code, converted_code, repo_path + ) + return {repo_path: converted_code} elif os.path.isdir(repo_path): - graph = utils.build_dependency_graph(repo_path) - ordered_files = utils.topological_sort(graph) - converted_files: dict[str, str] = {} - - for file_rel_path in ordered_files: - file_path = os.path.join(repo_path, file_rel_path) - with open(file_path, "r", encoding="utf-8", errors="replace") as f: - pytorch_code = f.read() - converted_code = self._convert_file(pytorch_code, file_path) - converted_files[file_path] = converted_code - return converted_files + from agents.migration.merge_agent import MergeAgent + + merger = MergeAgent() + merge_result = merger.run(repo_path) + self._merge_result = merge_result + results = {} + + # Convert model code + logger.info("Converting merged model code (%d files, %d chars)...", + len(merge_result.model_files), len(merge_result.model_code)) + model_jax = self._convert_file( + merge_result.model_code, "merged_model.py" + ) + model_jax = self._fill_missing_components( + merge_result.model_code, model_jax + ) + if self._validate: + model_jax = self._validate_and_repair( + merge_result.model_code, model_jax, "merged_model.py" + ) + results["model"] = model_jax + + # Convert utility code (if any) + if merge_result.utility_code: + logger.info("Converting merged utility code (%d files, %d chars)...", + len(merge_result.utility_files), + len(merge_result.utility_code)) + utils_jax = self._single_file_agent.run(merge_result.utility_code) + utils_jax = self._fill_missing_components( + merge_result.utility_code, utils_jax + ) + results["utils"] = utils_jax + + return results else: return { repo_path: f"# Error: path {repo_path} is not a file or directory." diff --git a/MaxCode/agents/migration/repo_agent.py b/MaxCode/agents/migration/repo_agent.py index 0688ef4..abe3667 100644 --- a/MaxCode/agents/migration/repo_agent.py +++ b/MaxCode/agents/migration/repo_agent.py @@ -43,7 +43,7 @@ def run(self, repo_path: str) -> Dict[str, str]: try: with open(file_path, "r") as f: pytorch_code = f.read() - rag_context_list = self._rag_agent.retrieve_context(pytorch_code) + rag_context_list = self._rag_agent.retrieve_per_component_context(pytorch_code) rag_context = "\\n\\n".join([ f"File: {c['file']}\\n```python\\n{c['text']}\\n```" for c in rag_context_list diff --git a/MaxCode/agents/migration/single_file_agent.py b/MaxCode/agents/migration/single_file_agent.py index 7bc991a..13853d0 100644 --- a/MaxCode/agents/migration/single_file_agent.py +++ b/MaxCode/agents/migration/single_file_agent.py @@ -1,4 +1,4 @@ -"""Agent for converting a single file from PyTorch to JAX.""" +"""Agent for converting a single file from PyTorch to a JAX-family target.""" import re from typing import Any @@ -9,24 +9,37 @@ from rag import rag_agent -class PytorchToJaxSingleFileAgent(base.Agent): - """Agent for converting a single file from PyTorch to JAX. +class PytorchSingleFileAgent(base.Agent): + """Agent for converting a single file from PyTorch to a JAX-family target. - This agent performs general-purpose conversion of PyTorch API calls to JAX - API calls within a given file. It is best suited for converting utility - functions, data loading pipelines, and training/evaluation loops. For - converting torch.nn.Module definitions to idiomatic Flax equivalents, - consider using the ModelConversionAgent. + This agent performs general-purpose conversion of PyTorch API calls to the + selected target's API calls within a given file. It is best suited for + converting utility functions, data loading pipelines, and training/eval + loops. For converting torch.nn.Module definitions to idiomatic Flax / + MaxText equivalents, consider using ModelConversionAgent or the + MaxTextConversionAgent. """ - def __init__(self, model: Any, rag_agent_instance: rag_agent.RAGAgent): - """Initializes the agent.""" + def __init__( + self, + model: Any, + rag_agent_instance: rag_agent.RAGAgent, + target: str = "jax", + ): + """Initializes the agent. + + Args: + model: The LLM model to use for generation. + rag_agent_instance: RAGAgent for retrieving reference snippets. + target: Conversion target ("jax" or "maxtext"). Selects the prompt. + """ super().__init__( model=model, agent_domain=utils.AgentDomain.MIGRATION, agent_type=utils.AgentType.PYTORCH_TO_JAX_SINGLE_FILE, ) self._rag_agent = rag_agent_instance + self._target = target def _strip_markdown_formatting(self, text: str) -> str: """Strips markdown and returns only the first python code block.""" @@ -35,24 +48,45 @@ def _strip_markdown_formatting(self, text: str) -> str: ) if code_block_match: return code_block_match.group(1).strip() + # Handle truncated responses: opening ``` present but closing ``` missing + stripped = text.strip() + if stripped.startswith("```"): + first_nl = stripped.find("\n") + if first_nl != -1: + stripped = stripped[first_nl + 1:] + if stripped.endswith("```"): + stripped = stripped[:-3] + return stripped.strip() + # Strip triple-quote wrappers the LLM may use instead of backticks. + if stripped.startswith('"""') and stripped.endswith('"""'): + return stripped[3:-3].strip() return text def run(self, pytorch_code: str) -> str: - """Converts a single file from PyTorch to JAX. + """Converts a single file from PyTorch to the selected target. Args: pytorch_code: The PyTorch code to convert. Returns: - The converted JAX code. + The converted code in the target framework. """ - rag_context_list = self._rag_agent.retrieve_context(pytorch_code, top_k=7) + rag_context_list = self._rag_agent.retrieve_per_component_context(pytorch_code) rag_context = "\n\n".join([ f"File: {c['file']}\n```python\n{c['text']}\n```" for c in rag_context_list ]) + prompt_template = prompts.get_prompt( + "MIGRATE_MODULE_TO_JAX_PROMPT", self._target + ) + if prompt_template is None: + prompt_template = prompts.MIGRATE_MODULE_TO_JAX_PROMPT generated_code = self.generate( - prompts.MIGRATE_MODULE_TO_JAX_PROMPT, + prompt_template, {"pytorch_code": pytorch_code, "rag_context": rag_context}, ) return self._strip_markdown_formatting(generated_code) + + +# Backwards-compatibility alias for one release. +PytorchToJaxSingleFileAgent = PytorchSingleFileAgent diff --git a/MaxCode/tools/migration_tool.py b/MaxCode/tools/migration_tool.py index a44f5a1..69a4440 100644 --- a/MaxCode/tools/migration_tool.py +++ b/MaxCode/tools/migration_tool.py @@ -1,29 +1,12 @@ -"""Migration tool for ADK.""" +"""Migration tool for ADK — thin adapter over interface/api.py.""" -import datetime +import dataclasses import json import logging -import os -import pathlib -import shutil -import models -from agents.migration import primary_agent from google.adk.tools.function_tool import FunctionTool - -MAPPING_FILE_NAME = "mapping.json" -ORIGINAL_SOURCE_DIR_NAME = "original_source" - - -def _write_artifact(output_path: pathlib.Path, code: str) -> None: - """Safely writes code to output_path, creating directories as needed.""" - if output_path.parent: - try: - output_path.parent.mkdir(parents=True) - except FileExistsError: - pass - output_path.write_text(code, encoding="utf-8") +from interface import api def convert_code( @@ -31,23 +14,27 @@ def convert_code( destination: str, api_key: str, model_name: str | None = None, + validate: bool = True, + target: str = "jax", ) -> str: - """Converts PyTorch code to JAX and saves it to the destination. + """Converts PyTorch code to a JAX-family target and saves it to disk. Args: source_path: The path to the Python file or directory to migrate. destination: The directory where the migrated files should be saved. api_key: The Google AI API key to use for migration. model_name: The Gemini model to use for migration. + validate: Whether to run validation and repair after conversion. + target: Conversion target — "jax" (default) or "maxtext". Returns: A JSON string containing the destination paths for subsequent steps. """ logging.info( - "convert_code called with source_path=%s, destination=%s, api_key=%s", + "convert_code called with source_path=%s, destination=%s, target=%s", source_path, destination, - api_key, + target, ) if source_path is None: return json.dumps({"error": "source_path is None"}) @@ -56,90 +43,33 @@ def convert_code( if api_key is None: return json.dumps({"error": "api_key is None"}) - workspace_dir = os.environ.get("BUILD_WORKSPACE_DIRECTORY") - abs_path = source_path - if not os.path.isabs(source_path) and workspace_dir: - abs_path = os.path.join(workspace_dir, source_path) - - logging.info("Attempting to convert %s to JAX...", abs_path) - - model_kwargs = {"api_key": api_key} - if model_name: - model_kwargs["model_name"] = model_name - model = models.GeminiTool(**model_kwargs) - agent = primary_agent.PrimaryAgent(model, api_key=api_key) - results = agent.run(abs_path) - - logging.info("Writing converted files to: %s", destination) - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - dest_path = pathlib.Path(destination) / timestamp - logging.info("Outputting to timestamped directory: %s", dest_path) - p = pathlib.Path(abs_path) - is_dir = p.is_dir() - - # Copy original source to destination for user reference and evaluation - source_copy_dir = dest_path / ORIGINAL_SOURCE_DIR_NAME try: - if is_dir: - shutil.copytree(abs_path, source_copy_dir, dirs_exist_ok=True) - else: - source_copy_dir.mkdir(parents=True, exist_ok=True) - shutil.copy2(abs_path, source_copy_dir / p.name) - except OSError as e: - logging.warning("Failed to copy source files to destination: %s", e) - return json.dumps({ - "error": f"Failed to copy source files to destination: {e}", - }) - - written_files = [] - mapping_log = [] - for file_path, code in results.items(): - if is_dir: - relative_path = pathlib.Path(file_path).relative_to(p) - else: - relative_path = pathlib.Path(file_path).name - output_path = dest_path / relative_path - _write_artifact(output_path, code) - written_files.append(output_path) - mapping_log.append({ - "source_file": file_path, - "generated_file": str(output_path), - "status": "success", - }) - - # Create __init__.py files for all directories containing migrated files. - dirs_in_results = set(f.parent for f in written_files) - init_paths_to_create = set() - for d in dirs_in_results: - current_d = d - while current_d and ( - current_d == dest_path or dest_path in current_d.parents - ): - init_py = current_d / "__init__.py" - init_paths_to_create.add(init_py) - if current_d == dest_path: - break - current_d = current_d.parent - - for init_py in init_paths_to_create: - if not init_py.exists(): - _write_artifact(init_py, "") - - # Ensure original_source is importable by adding __init__.py files - for dirpath, _, _ in os.walk(source_copy_dir): - init_py = pathlib.Path(dirpath) / "__init__.py" - if not init_py.exists(): - _write_artifact(init_py, "") - - mapping_path = dest_path / MAPPING_FILE_NAME - with mapping_path.open("w", encoding="utf-8") as f: - json.dump(mapping_log, f, indent=2) + config = api.ConvertConfig( + source_path=source_path, + destination=destination, + api_key=api_key, + model_name=model_name, + validate=validate, + target=target, + ) + result = api.convert(config) + except Exception as e: + logging.exception("Error in convert_code tool") + return json.dumps({"error": str(e)}) response = { - "dest_path": str(dest_path), - "mapping_path": str(mapping_path), - "original_source_dir": str(source_copy_dir), + "dest_path": result.dest_path, + "mapping_path": result.mapping_path, + "original_source_dir": result.original_source_dir, } + if result.validation_path: + response["validation_path"] = result.validation_path + if result.verification_scorecard_path: + response["verification_scorecard_path"] = result.verification_scorecard_path + response["verification_summary"] = result.verification_summary + if result.maxtext_artifacts is not None: + response["maxtext_artifacts"] = dataclasses.asdict(result.maxtext_artifacts) + return json.dumps(response)