diff --git a/MaxCode/ARCHITECTURE.md b/MaxCode/ARCHITECTURE.md index 297910f..94095b7 100644 --- a/MaxCode/ARCHITECTURE.md +++ b/MaxCode/ARCHITECTURE.md @@ -33,29 +33,50 @@ execute the `migration_agent` and `evaluation_agent`, respectively. ### 4. ADK Tools -`tools/migration_tool.py` and `tools/evaluation_tool.py` define ADK -`FunctionTool`s that wrap specific Python functions for code conversion, -config generation, data generation, and testing. +`tools/migration_tool.py`, `tools/evaluation_tool.py`, and +`tools/verification_tool.py` define ADK `FunctionTool`s that wrap specific +Python functions for code conversion, quality verification, config generation, +data generation, and testing. -### 5. Migration and Validation Logic +### 5. Migration Pipeline + +For **directory inputs**, `PrimaryAgent` uses `MergeAgent` +(`agents/migration/merge_agent.py`) to preprocess the repository before +conversion. The merge step: +- Discovers all nn.Module files and builds an import dependency graph +- Filters infrastructure files (fused kernels, CUDA wrappers, etc.) +- Merges model files into a single file in topological order +- Discovers and merges utility files separately +- Filters infrastructure classes from merged output + +For **single-file inputs**, the existing direct conversion path is used. + +After conversion, `migration_tool.convert_code` automatically runs +`VerificationAgent` (`agents/migration/verification_agent.py`) to produce +a completeness scorecard (AST-based, no LLM). The verification tool is +also available standalone via `tools/verification_tool.py`. + +### 6. ADK Agent Orchestration The `migration_agent` orchestrates the end-to-end migration and validation workflow by calling tools in sequence: -1. **`migration_tool.convert_code`**: Converts PyTorch code to JAX using - `agents.migration.primary_agent.PrimaryAgent`, copies the original source - code, and saves the results to a timestamped output directory. 
Returns - paths to the migrated code, original code, and mapping file. -2. **`evaluation_tool.generate_model_configs`**: Generates configuration +1. **`migration_tool.convert_code`**: Merges, converts, and verifies + PyTorch code to JAX using `PrimaryAgent` (which delegates to + `MergeAgent` for directories). Copies the original source code and + saves results to a timestamped output directory. +2. **`verification_tool.verify_conversion`** (optional): Standalone + quality verification with completeness and correctness scores. +3. **`evaluation_tool.generate_model_configs`**: Generates configuration files from the original PyTorch code. -3. **`evaluation_tool.generate_oracle_data`**: Generates oracle data +4. **`evaluation_tool.generate_oracle_data`**: Generates oracle data (.pkl files) from the PyTorch code using the generated configurations. -4. **`evaluation_tool.run_equivalence_tests`**: Generates test scripts +5. **`evaluation_tool.run_equivalence_tests`**: Generates test scripts that compare JAX outputs against PyTorch oracle data, and then runs these tests using `subprocess`. The result is a destination directory containing the migrated JAX code, a -`mapping.json` file, and an `evaluation` subdirectory with configurations, -oracle data, and test scripts. +`mapping.json` file, a `verification_scorecard.json`, and an `evaluation` +subdirectory with configurations, oracle data, and test scripts. ## Summary @@ -63,10 +84,22 @@ The overall flow for migration is: ``` Gemini CLI -> mcp_server:primary_agent_server -> adk_agents:migration_agent -> - 1. tools:migration_tool:convert_code (Migration) - 2. tools:evaluation_tool:generate_model_configs (Config Gen) - 3. tools:evaluation_tool:generate_oracle_data (Data Gen) - 4. tools:evaluation_tool:run_equivalence_tests (Test Gen & Run) + 1. tools:migration_tool:convert_code + (Merge -> Convert -> Validate/Repair -> Verify) + 2. tools:verification_tool:verify_conversion (optional, standalone) + 3. 
tools:evaluation_tool:generate_model_configs (Config Gen) + 4. tools:evaluation_tool:generate_oracle_data (Data Gen) + 5. tools:evaluation_tool:run_equivalence_tests (Test Gen & Run) +``` + +The internal flow within `convert_code` for directory inputs: + +``` +MergeAgent.run(repo_dir) # Preprocessing: discover, filter, merge + -> PrimaryAgent._convert_file() # LLM conversion (model + utils) + -> PrimaryAgent._fill_missing() # Gap-filling pass + -> PrimaryAgent._validate() # Validation + repair loop + -> VerificationAgent.verify() # Quality scorecard ``` ## Agent Structure and Extension @@ -74,8 +107,8 @@ Gemini CLI -> mcp_server:primary_agent_server -> adk_agents:migration_agent -> The project separates agent implementation logic from ADK agent/tool definitions: -* **`agents//`**: Contains agent classes with core implementation logic (e.g., `agents/migration/primary_agent.py`). -* **`tools/`**: Contains ADK `FunctionTool` wrappers that call agent logic or other Python functions (e.g., `tools/migration_tool.py`). +* **`agents//`**: Contains agent classes with core implementation logic (e.g., `agents/migration/primary_agent.py`, `agents/migration/merge_agent.py`, `agents/migration/verification_agent.py`). +* **`tools/`**: Contains ADK `FunctionTool` wrappers that call agent logic or other Python functions (e.g., `tools/migration_tool.py`, `tools/verification_tool.py`). * **`mcp_server/adk_agents.py`**: Defines the ADK agent hierarchy, instructions, and tool mappings. ### How to Add a New Capability diff --git a/MaxCode/README.md b/MaxCode/README.md index b40bd0b..36bbdba 100644 --- a/MaxCode/README.md +++ b/MaxCode/README.md @@ -3,6 +3,26 @@ This extension provides development tools for the MaxCode project, including tools for AI-powered code migration between ML frameworks. +## Quick Start + +Want to try MaxCode without the full Gemini CLI setup? 
The standalone demo +converts a PyTorch repo to JAX in five steps: + +```bash +cd MaxCode/examples/demo +pip install -r requirements.txt +export GOOGLE_API_KEY= # Windows CMD: set GOOGLE_API_KEY= + +python step1_clone_repo.py # Clone a PyTorch repo from GitHub +python step2_populate_rag.py # Build the RAG reference database +python step3_merge.py # Merge model + utility files (MergeAgent) +python step4_convert.py # Convert to JAX with validation + repair +python step5_verify.py # Verify conversion quality (VerificationAgent) +``` + +See [examples/demo/README.md](examples/demo/README.md) for full setup +instructions and details on what each step does. + ## Prerequisites This extension uses the Google AI API, which requires an API key. You can get @@ -196,6 +216,15 @@ dev-server run_evaluation_workflow --prompt "Run equivalence tests for migration ## Architecture -Agents are organized by domain (e.g., migration, kernel) within the `agents/` -directory. For more details on the project architecture and agent structure, see +The migration pipeline: **Clone -> Index -> Merge -> Convert -> Verify**. + +Key agents in `agents/migration/`: +- **MergeAgent** — Pure-logic preprocessing: file discovery, filtering, import + graph analysis, and merging (no LLM calls). +- **PrimaryAgent** — Orchestrates conversion: routes to model or utility + conversion agents, fills missing components, validates and repairs. +- **VerificationAgent** — Post-processing quality scoring: AST-based + completeness + optional LLM-based correctness. + +For more details on the project architecture and agent structure, see [ARCHITECTURE.md](ARCHITECTURE.md). diff --git a/MaxCode/agents/migration/merge_agent.py b/MaxCode/agents/migration/merge_agent.py new file mode 100644 index 0000000..a62e766 --- /dev/null +++ b/MaxCode/agents/migration/merge_agent.py @@ -0,0 +1,740 @@ +"""Merge agent for combining model and utility files before conversion. + +This is a pure-logic agent (no LLM calls). 
It encapsulates the file +discovery, filtering, import-graph analysis, and merge logic that was +previously in examples/demo/step3_merge.py. +""" + +import ast +import fnmatch +import os +from collections import deque +from dataclasses import dataclass, field + + +@dataclass +class MergeResult: + """Result of merging a repository's model and utility files.""" + model_code: str # merged model code + model_files: list[str] # files included in model merge + utility_code: str | None # merged utility code (None if no utils found) + utility_files: list[str] # files included in utility merge + excluded_files: list[tuple[str, str]] = field(default_factory=list) # (path, reason) + excluded_classes: list[tuple[str, str]] = field(default_factory=list) # (class_name, reason) + utility_categories: dict[str, str] = field(default_factory=dict) # file -> category + + +# --------------------------------------------------------------------------- +# Infrastructure detection constants +# --------------------------------------------------------------------------- + +_INFRA_PACKAGES = { + "apex", + "transformer_engine", "te", + "deepspeed.pipe", "deepspeed.runtime", +} + +_INFRA_BASES = { + "torch.autograd.Function", + "autograd.Function", + "PipelineModule", + "enum.Enum", + "Enum", +} + + +# --------------------------------------------------------------------------- +# AST helpers +# --------------------------------------------------------------------------- + +def _base_to_str(base_node): + """Convert an AST base-class node to a dotted string.""" + if isinstance(base_node, ast.Name): + return base_node.id + if isinstance(base_node, ast.Attribute): + parts = [] + node = base_node + while isinstance(node, ast.Attribute): + parts.append(node.attr) + node = node.value + if isinstance(node, ast.Name): + parts.append(node.id) + return ".".join(reversed(parts)) + return "" + + +def _is_local_import(line, repo_dir): + """Check if an import line resolves to a file within the repo.""" + 
stripped = line.strip() + if stripped.startswith("from .") or stripped.startswith("from .."): + return True + if stripped.startswith("from "): + parts = stripped.split() + if len(parts) >= 2: + module = parts[1] + module_path = module.replace(".", os.sep) + if os.path.isfile(os.path.join(repo_dir, module_path + ".py")): + return True + if os.path.isfile(os.path.join(repo_dir, module_path, "__init__.py")): + return True + return False + + +def _fix_empty_blocks(code): + """Insert ``pass`` into blocks left empty after import removal.""" + lines = code.split("\n") + result = [] + block_starters = ( + "if ", "elif ", "else:", "else :", + "try:", "try :", "except:", "except ", + "finally:", "finally :", + "for ", "while ", "with ", "def ", "class ", + ) + i = 0 + while i < len(lines): + result.append(lines[i]) + stripped = lines[i].strip() + if stripped.endswith(":") and any(stripped.startswith(kw) for kw in block_starters): + indent = lines[i][: len(lines[i]) - len(lines[i].lstrip())] + body_indent = indent + " " + j = i + 1 + while j < len(lines) and lines[j].strip() == "": + j += 1 + if j >= len(lines): + result.append(body_indent + "pass") + else: + next_indent = lines[j][: len(lines[j]) - len(lines[j].lstrip())] + next_stripped = lines[j].lstrip() + if len(next_indent) <= len(indent) and next_stripped: + result.append(body_indent + "pass") + i += 1 + return "\n".join(result) + + +def _count_module_classes(code): + """Count nn.Module subclasses in source code.""" + try: + tree = ast.parse(code) + except SyntaxError: + return -1 + count = 0 + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + for base in node.bases: + base_str = _base_to_str(base) + if base_str in ("nn.Module", "Module") or base_str.endswith(".Module"): + count += 1 + break + return count + + +# --------------------------------------------------------------------------- +# Infrastructure detection helpers +# --------------------------------------------------------------------------- + 
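The dotted base-class resolution above (`_base_to_str`) is what the later nn.Module detection and infrastructure filtering build on. A minimal standalone sketch of the same idea, outside the patch (the sample source and variable names here are illustrative, not part of the module):

```python
import ast

def base_to_str(base_node):
    """Turn an AST base-class node into a dotted string, e.g. nn.Module."""
    if isinstance(base_node, ast.Name):
        return base_node.id
    if isinstance(base_node, ast.Attribute):
        parts = []
        node = base_node
        while isinstance(node, ast.Attribute):
            parts.append(node.attr)
            node = node.value
        if isinstance(node, ast.Name):
            parts.append(node.id)
        return ".".join(reversed(parts))
    return ""

sample = """
import torch.nn as nn

class Encoder(nn.Module):
    pass

class Helper:
    pass
"""

tree = ast.parse(sample)
bases = [
    base_to_str(b)
    for node in ast.walk(tree)
    if isinstance(node, ast.ClassDef)
    for b in node.bases
]
print(bases)  # ['nn.Module']
```

Classes with no bases (like `Helper`) simply contribute nothing, which is why the real helpers can test `base_str in ("nn.Module", "Module")` or `.endswith(".Module")` without a separate no-base branch.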
+def detect_infrastructure_imports(file_path): + """Return set of known infrastructure package names imported by *file_path*.""" + try: + with open(file_path, "r", encoding="utf-8-sig", errors="replace") as f: + tree = ast.parse(f.read()) + except SyntaxError: + return set() + + found = set() + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + top = alias.name.split(".")[0] + if alias.name in _INFRA_PACKAGES or top in _INFRA_PACKAGES: + found.add(top) + elif isinstance(node, ast.ImportFrom): + if node.module: + top = node.module.split(".")[0] + if node.module in _INFRA_PACKAGES or top in _INFRA_PACKAGES: + found.add(top) + return found + + +def _is_infra_base(base_str): + """Return True if *base_str* is a known infrastructure base class.""" + if base_str in _INFRA_BASES: + return True + if base_str.startswith("te.pytorch.") or base_str.startswith("transformer_engine.pytorch."): + return True + return False + + +def classify_file_classes(file_path): + """Return list of class info dicts for every ClassDef in *file_path*.""" + try: + with open(file_path, "r", encoding="utf-8-sig", errors="replace") as f: + tree = ast.parse(f.read()) + except SyntaxError: + return [] + + classes = [] + for node in ast.iter_child_nodes(tree): + if not isinstance(node, ast.ClassDef): + continue + bases = [_base_to_str(b) for b in node.bases] + is_infra = bool(bases) and all(_is_infra_base(b) for b in bases) + classes.append({"name": node.name, "bases": bases, "is_infra": is_infra}) + return classes + + +def should_exclude_class(node, exclude_patterns): + """Check if a ClassDef *node* should be excluded from the merged output.""" + bases = [_base_to_str(b) for b in node.bases] + + for pat in exclude_patterns: + if fnmatch.fnmatch(node.name, pat): + return True, f"matches exclude pattern '{pat}'" + + for b in bases: + if b in ("torch.autograd.Function", "autograd.Function"): + return True, "autograd.Function subclass" + + if "PipelineModule" in 
bases: + return True, "PipelineModule subclass" + + for b in bases: + if b.startswith("te.pytorch.") or b.startswith("transformer_engine.pytorch."): + return True, "TransformerEngine wrapper" + + if node.name.endswith("Pipe"): + return True, "pipeline wrapper -- name ends with Pipe" + + for b in bases: + if b in ("enum.Enum", "Enum"): + return True, "enum.Enum subclass" + + return False, "" + + +# --------------------------------------------------------------------------- +# Utility classification +# --------------------------------------------------------------------------- + +def classify_utility_file(file_path, repo_dir): + """Classify a utility file into a category. + + Returns one of: "init_reexport", "cuda_kernel", "torch_autograd", + "torch_utility", "pure_python". + """ + basename = os.path.basename(file_path) + try: + with open(file_path, "r", encoding="utf-8-sig", errors="replace") as f: + code = f.read() + tree = ast.parse(code) + except SyntaxError: + return "pure_python" + + if basename == "__init__.py": + body_types = set(type(n).__name__ for n in ast.iter_child_nodes(tree)) + reexport_types = {"Import", "ImportFrom", "Assign", "Expr"} + if body_types <= reexport_types: + return "init_reexport" + + has_cu_ref = ".cu" in code or ".cpp" in code + has_load_call = False + for node in ast.walk(tree): + if isinstance(node, ast.Call): + func = node.func + if isinstance(func, ast.Name) and func.id in ("load", "load_inline"): + has_load_call = True + elif isinstance(func, ast.Attribute) and func.attr in ("load", "load_inline"): + has_load_call = True + if has_cu_ref and has_load_call: + return "cuda_kernel" + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + for base in node.bases: + base_str = _base_to_str(base) + if base_str in ("torch.autograd.Function", "autograd.Function"): + return "torch_autograd" + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name == "torch" or 
alias.name.startswith("torch."): + return "torch_utility" + elif isinstance(node, ast.ImportFrom): + if node.module and (node.module == "torch" or node.module.startswith("torch.")): + return "torch_utility" + + return "pure_python" + + +# --------------------------------------------------------------------------- +# MergeAgent +# --------------------------------------------------------------------------- + +class MergeAgent: + """Merges a repository's model and utility files for conversion. + + This is a pure-logic agent (no LLM calls). It handles: + - Model file discovery (nn.Module detection) + - File-level and class-level filtering + - Import graph construction and topological sorting + - File merging with import deduplication + - Utility file discovery and classification + """ + + @staticmethod + def find_model_files(repo_dir): + """Walk the repo and return paths of files containing nn.Module classes.""" + model_files = [] + for root, _, files in os.walk(repo_dir): + for f in sorted(files): + if not f.endswith(".py"): + continue + full = os.path.join(root, f) + if MergeAgent._is_model_file(full): + model_files.append(full) + return model_files + + @staticmethod + def _is_model_file(file_path): + """Detect if a Python file defines any nn.Module subclass.""" + try: + with open(file_path, "r", encoding="utf-8-sig", errors="replace") as f: + code = f.read() + tree = ast.parse(code) + except SyntaxError: + return False + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + for base in node.bases: + if isinstance(base, ast.Attribute) and base.attr == "Module": + return True + if isinstance(base, ast.Name) and base.id == "Module": + return True + return False + + @staticmethod + def get_local_imports(file_path, repo_dir): + """Parse a file's AST and return resolved paths of local imports.""" + try: + with open(file_path, "r", encoding="utf-8-sig", errors="replace") as f: + code = f.read() + tree = ast.parse(code) + except SyntaxError: + return set() + 
+ resolved = set() + file_dir = os.path.dirname(file_path) + + for node in ast.walk(tree): + if not isinstance(node, ast.ImportFrom): + continue + module = node.module + if module is None: + continue + + module_path = module.replace(".", os.sep) + + if node.level > 0: + base = file_dir + for _ in range(node.level - 1): + base = os.path.dirname(base) + candidates = [ + os.path.join(base, module_path + ".py"), + os.path.join(base, module_path, "__init__.py"), + ] + else: + candidates = [ + os.path.join(repo_dir, module_path + ".py"), + os.path.join(repo_dir, module_path, "__init__.py"), + ] + + for candidate in candidates: + candidate = os.path.normpath(candidate) + if os.path.isfile(candidate): + resolved.add(candidate) + break + + return resolved + + @staticmethod + def build_model_import_graph(model_files, repo_dir): + """Build a directed graph of imports between model files.""" + model_set = set(os.path.normpath(f) for f in model_files) + graph = {} + for f in model_files: + f_norm = os.path.normpath(f) + all_imports = MergeAgent.get_local_imports(f, repo_dir) + graph[f_norm] = {imp for imp in all_imports if imp in model_set} + return graph + + @staticmethod + def find_entry_points(model_files, import_graph): + """Find model files at the top of the dependency tree.""" + imported_by_someone = set() + for deps in import_graph.values(): + imported_by_someone.update(deps) + + entries = [] + for f in model_files: + f_norm = os.path.normpath(f) + has_deps = bool(import_graph.get(f_norm)) + is_imported = f_norm in imported_by_someone + if not is_imported and has_deps: + entries.append(f_norm) + + if not entries: + entries = [os.path.normpath(f) for f in model_files] + + return entries + + @staticmethod + def trace_dependencies(entry_points, import_graph): + """BFS from entry points, then topological sort (DFS post-order).""" + visited = set() + order = [] + + reachable = set() + queue = deque(entry_points) + reachable.update(entry_points) + while queue: + node = 
queue.popleft() + for dep in import_graph.get(node, set()): + if dep not in reachable: + reachable.add(dep) + queue.append(dep) + + def dfs(node): + if node in visited: + return + visited.add(node) + for dep in import_graph.get(node, set()): + if dep in reachable: + dfs(dep) + order.append(node) + + for ep in sorted(entry_points): + dfs(ep) + + return order + + @staticmethod + def merge_files(file_paths, repo_dir): + """Merge files into a single string with imports de-duplicated. + + Returns the merged code string (no file I/O for output). + """ + import_lines = set() + code_sections = [] + + for full_path in file_paths: + rel = os.path.relpath(full_path, repo_dir) + with open(full_path, "r", encoding="utf-8-sig") as f: + content = f.read() + + section_lines = [] + in_docstring = False + skipping_multiline_import = False + for line in content.split("\n"): + stripped = line.strip() + triple_count = stripped.count('"""') + stripped.count("'''") + if triple_count % 2 == 1: + in_docstring = not in_docstring + if in_docstring or triple_count > 0: + section_lines.append(line) + continue + if skipping_multiline_import: + if ")" in stripped: + skipping_multiline_import = False + continue + if _is_local_import(line, repo_dir): + if "(" in stripped and ")" not in stripped: + skipping_multiline_import = True + continue + if not line[:1].isspace() and ( + stripped.startswith("import ") or stripped.startswith("from ") + ): + import_lines.add(line) + else: + section_lines.append(line) + + code_sections.append( + f"\n# {'=' * 70}\n# From {rel}\n# {'=' * 70}\n" + + "\n".join(section_lines) + ) + + fixed_sections = [] + for section in code_sections: + fixed_sections.append(_fix_empty_blocks(section)) + code_sections = fixed_sections + + header = '"""\nMerged model file - auto-generated by MergeAgent\n' + header += f"Source: {repo_dir}\n" + header += f"Files: {len(file_paths)} files detected\n" + header += '"""\n\n' + + merged = header + "\n".join(sorted(import_lines)) + "\n" + 
"\n".join(code_sections) + return merged + + @staticmethod + def filter_files(model_files, repo_dir, exclude_paths=None): + """Apply file-level filters to the raw model file list. + + Returns (kept_files, [(removed_path, reason), ...]). + """ + if exclude_paths is None: + exclude_paths = [] + + kept = [] + removed = [] + + for full_path in model_files: + rel = os.path.relpath(full_path, repo_dir).replace("\\", "/") + basename = os.path.basename(full_path) + + excluded = False + for pat in exclude_paths: + if fnmatch.fnmatch(rel, pat): + removed.append((full_path, f"matches exclude pattern '{pat}'")) + excluded = True + break + if excluded: + continue + + if fnmatch.fnmatch(basename, "fused_*.py"): + removed.append((full_path, "fused kernel file")) + continue + + classes = classify_file_classes(full_path) + infra_imports = detect_infrastructure_imports(full_path) + if classes and all(c["is_infra"] for c in classes) and infra_imports: + pkg_names = ", ".join(sorted(infra_imports)) + removed.append((full_path, f"all classes are {pkg_names} wrappers")) + continue + + kept.append(full_path) + + return kept, removed + + @staticmethod + def filter_classes_from_code(code, exclude_patterns=None): + """Remove infrastructure classes from merged source code. + + Returns (filtered_code, [(class_name, reason), ...]). 
+ """ + if exclude_patterns is None: + exclude_patterns = [] + + try: + tree = ast.parse(code) + except SyntaxError: + return code, [] + + lines = code.split("\n") + ranges_to_remove = [] + removed_classes = [] + + top_level_nodes = list(ast.iter_child_nodes(tree)) + for i, node in enumerate(top_level_nodes): + if not isinstance(node, ast.ClassDef): + continue + exclude, reason = should_exclude_class(node, exclude_patterns) + if not exclude: + continue + + start = node.lineno + end = node.end_lineno + + if node.decorator_list: + start = min(d.lineno for d in node.decorator_list) + + next_start = None + for j in range(i + 1, len(top_level_nodes)): + nxt = top_level_nodes[j] + if hasattr(nxt, "lineno"): + next_start = nxt.lineno + break + if next_start is not None: + while end + 1 < next_start and lines[end].strip() == "": + end += 1 + + ranges_to_remove.append((start, end)) + removed_classes.append((node.name, reason)) + + if not ranges_to_remove: + return code, [] + + remove_set = set() + for start, end in ranges_to_remove: + for ln in range(start - 1, end): + remove_set.add(ln) + + filtered_lines = [line for idx, line in enumerate(lines) if idx not in remove_set] + return "\n".join(filtered_lines), removed_classes + + @staticmethod + def find_all_local_dependencies(model_files, repo_dir): + """BFS from model files through ALL local imports. + + Returns the set of utility files (non-model files that are + transitively imported by model files). 
+ """ + model_set = set(os.path.normpath(f) for f in model_files) + visited = set(model_set) + queue = deque(model_set) + + while queue: + current = queue.popleft() + for dep in MergeAgent.get_local_imports(current, repo_dir): + dep_norm = os.path.normpath(dep) + if dep_norm not in visited: + visited.add(dep_norm) + queue.append(dep_norm) + + return visited - model_set + + @staticmethod + def filter_utility_files(utility_files, repo_dir, exclude_patterns=None): + """Apply exclusion patterns and classification to utility files. + + Returns (kept, removed_with_reasons, category_map). + """ + if exclude_patterns is None: + exclude_patterns = [] + + kept = [] + removed = [] + category_map = {} + + for full_path in utility_files: + rel = os.path.relpath(full_path, repo_dir).replace("\\", "/") + + excluded = False + for pat in exclude_patterns: + if fnmatch.fnmatch(rel, pat) or fnmatch.fnmatch(os.path.basename(full_path), pat): + removed.append((full_path, f"matches exclude pattern '{pat}'")) + excluded = True + break + if excluded: + continue + + category = classify_utility_file(full_path, repo_dir) + category_map[full_path] = category + + if category == "init_reexport": + removed.append((full_path, "re-export __init__.py (inlined by merge)")) + elif category == "cuda_kernel": + removed.append((full_path, "CUDA kernel loader (no JAX equivalent)")) + else: + kept.append(full_path) + + return kept, removed, category_map + + @staticmethod + def order_utility_files(utility_files, repo_dir): + """Topologically sort utility files by their import dependencies.""" + file_set = set(os.path.normpath(f) for f in utility_files) + graph = {} + for f in utility_files: + f_norm = os.path.normpath(f) + all_imports = MergeAgent.get_local_imports(f, repo_dir) + graph[f_norm] = {imp for imp in all_imports if imp in file_set} + + visited = set() + order = [] + + def dfs(node): + if node in visited: + return + visited.add(node) + for dep in graph.get(node, set()): + dfs(dep) + 
order.append(node) + + for f in sorted(file_set): + dfs(f) + + return order + + def run(self, repo_dir, exclude_paths=None, exclude_classes=None, + exclude_utils=None): + """Run the full merge pipeline on a repository directory. + + Args: + repo_dir: Path to the repository root. + exclude_paths: Glob patterns for files to exclude from merge. + exclude_classes: Class name patterns to exclude from merged output. + exclude_utils: Glob patterns for utility files to exclude. + + Returns: + MergeResult with merged model code, utility code, and metadata. + """ + if exclude_paths is None: + exclude_paths = [] + if exclude_classes is None: + exclude_classes = [] + if exclude_utils is None: + exclude_utils = [] + + all_excluded_files = [] + all_excluded_classes = [] + + # 1. Find model files + model_files = self.find_model_files(repo_dir) + + # 2. File-level filtering + model_files, removed_files = self.filter_files( + model_files, repo_dir, exclude_paths + ) + all_excluded_files.extend(removed_files) + + # 3. Build import graph and trace dependencies + graph = self.build_model_import_graph(model_files, repo_dir) + entries = self.find_entry_points(model_files, graph) + required = self.trace_dependencies(entries, graph) + + # Track files excluded by graph analysis + required_set = set(required) + for f in model_files: + f_norm = os.path.normpath(f) + if f_norm not in required_set: + all_excluded_files.append( + (f, "not imported by any entry-point model file") + ) + + # 4. Merge model files + model_code = self.merge_files(required, repo_dir) + + # 5. Class-level filtering + model_code, removed_classes = self.filter_classes_from_code( + model_code, exclude_classes + ) + all_excluded_classes.extend(removed_classes) + + # 6. 
Discover and merge utility files + utility_code = None + utility_files_kept = [] + utility_categories = {} + + util_files = self.find_all_local_dependencies(required, repo_dir) + if util_files: + kept_utils, removed_utils, cat_map = self.filter_utility_files( + sorted(util_files), repo_dir, exclude_utils + ) + all_excluded_files.extend(removed_utils) + utility_categories = cat_map + + if kept_utils: + ordered_utils = self.order_utility_files(kept_utils, repo_dir) + utility_code = self.merge_files(ordered_utils, repo_dir) + utility_files_kept = ordered_utils + + return MergeResult( + model_code=model_code, + model_files=required, + utility_code=utility_code, + utility_files=utility_files_kept, + excluded_files=all_excluded_files, + excluded_classes=all_excluded_classes, + utility_categories=utility_categories, + ) diff --git a/MaxCode/agents/migration/model_conversion_agent.py b/MaxCode/agents/migration/model_conversion_agent.py index e7759e1..92977bc 100644 --- a/MaxCode/agents/migration/model_conversion_agent.py +++ b/MaxCode/agents/migration/model_conversion_agent.py @@ -16,6 +16,18 @@ def _strip_markdown_formatting(text: str) -> str: code_block_match = _CODE_BLOCK_PATTERN.search(text) if code_block_match: return code_block_match.group(1).strip() + # Handle truncated responses: opening ``` present but closing ``` missing + stripped = text.strip() + if stripped.startswith("```"): + first_nl = stripped.find("\n") + if first_nl != -1: + stripped = stripped[first_nl + 1:] + if stripped.endswith("```"): + stripped = stripped[:-3] + return stripped.strip() + # Strip triple-quote wrappers the LLM may use instead of backticks. + if stripped.startswith('"""') and stripped.endswith('"""'): + return stripped[3:-3].strip() return text @@ -46,8 +58,8 @@ def run(self, pytorch_model_code: str) -> str: Returns: The converted JAX code. 
""" - rag_context_list = self._rag_agent.retrieve_context( - pytorch_model_code, top_k=7 + rag_context_list = self._rag_agent.retrieve_per_component_context( + pytorch_model_code ) rag_context = "\n\n".join([ f"File: {c['file']}\n```python\n{c['text']}\n```" diff --git a/MaxCode/agents/migration/primary_agent.py b/MaxCode/agents/migration/primary_agent.py index 5d69906..2a61631 100644 --- a/MaxCode/agents/migration/primary_agent.py +++ b/MaxCode/agents/migration/primary_agent.py @@ -1,9 +1,11 @@ """Primary orchestration agent for repository migration.""" +import ast import logging import os import re import subprocess import tempfile +import textwrap from typing import Any, Tuple import models @@ -11,10 +13,12 @@ from agents import utils from agents.migration import model_conversion_agent from agents.migration import single_file_agent +from agents.migration import validation_agent from agents.migration.prompts import prompts from rag import rag_agent MAX_DEBUG_ITERATIONS = 10 +logger = logging.getLogger(__name__) def _strip_markdown_formatting(text: str) -> str: @@ -22,19 +26,179 @@ def _strip_markdown_formatting(text: str) -> str: code_block_match = re.search(r"```(?:python)?\n?(.*?)\n?```", text, re.DOTALL) if code_block_match: return code_block_match.group(1).strip() + # Handle truncated responses: opening ``` present but closing ``` missing + stripped = text.strip() + if stripped.startswith("```"): + first_nl = stripped.find("\n") + if first_nl != -1: + stripped = stripped[first_nl + 1:] + if stripped.endswith("```"): + stripped = stripped[:-3] + return stripped.strip() + # Strip triple-quote wrappers the LLM may use instead of backticks. 
+ if stripped.startswith('"""') and stripped.endswith('"""'):
+ return stripped[3:-3].strip()
 return text
 
 
+def _find_missing_components(pytorch_code: str, jax_code: str) -> list[str]:
+ """Returns names of top-level classes/functions in pytorch_code missing from jax_code."""
+ try:
+ src_tree = ast.parse(pytorch_code)
+ except SyntaxError:
+ return []
+ try:
+ out_tree = ast.parse(jax_code)
+ except SyntaxError:
+ return []
+
+ src_names = {
+ node.name for node in ast.iter_child_nodes(src_tree)
+ if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef))
+ }
+ out_names = {
+ node.name for node in ast.iter_child_nodes(out_tree)
+ if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef))
+ }
+ return sorted(src_names - out_names)
+
+
+def _extract_component_source(source_code: str, component_name: str) -> str:
+ """Extracts the full source text of a top-level class or function."""
+ try:
+ tree = ast.parse(source_code)
+ except SyntaxError:
+ return ""
+ lines = source_code.splitlines(keepends=True)
+ for node in ast.iter_child_nodes(tree):
+ if (isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef))
+ and node.name == component_name):
+ start = node.lineno - 1  # ast line numbers are 1-indexed
+ end = node.end_lineno if node.end_lineno else len(lines)
+ return "".join(lines[start:end])
+ return ""
+
+
+def _is_stub_body(body: list[ast.stmt]) -> bool:
+ """Checks if a function body is a stub (pass, `...`, return None, raise NotImplementedError, or docstring only)."""
+ stmts = body
+ # Strip a leading docstring. String constants parse as ast.Constant on all
+ # supported versions (ast.Str is deprecated and removed in Python 3.12).
+ if stmts and isinstance(stmts[0], ast.Expr) and isinstance(stmts[0].value, ast.Constant):
+ stmts = stmts[1:]
+ if not stmts:
+ return True
+ if len(stmts) == 1:
+ s = stmts[0]
+ # pass
+ if isinstance(s, ast.Pass):
+ return True
+ # ... (Ellipsis)
+ if isinstance(s, ast.Expr) and isinstance(s.value, ast.Constant) and s.value.value is ...:
+ return True
+ # return None
+ if isinstance(s, ast.Return) and (s.value is None or (isinstance(s.value, ast.Constant) and s.value.value is None)):
+ return True
+ # raise NotImplementedError / raise NotImplementedError(...)
+ if isinstance(s, ast.Raise):
+ exc = s.exc.func if isinstance(s.exc, ast.Call) else s.exc
+ if isinstance(exc, ast.Name) and exc.id == "NotImplementedError":
+ return True
+ return False
+
+
+def _find_stub_implementations(code: str) -> list[dict]:
+ """Walks the AST and returns stub functions/methods.
+
+ Returns:
+ List of dicts with keys: name, kind ('function' or 'method'), parent_class (or None).
+ """
+ try:
+ tree = ast.parse(code)
+ except SyntaxError:
+ return []
+ stubs = []
+ for node in ast.iter_child_nodes(tree):
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ if _is_stub_body(node.body):
+ stubs.append({"name": node.name, "kind": "function", "parent_class": None})
+ elif isinstance(node, ast.ClassDef):
+ for child in ast.iter_child_nodes(node):
+ if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ if _is_stub_body(child.body):
+ stubs.append({"name": child.name, "kind": "method", "parent_class": node.name})
+ return stubs
+
+
+def _find_missing_methods(pytorch_code: str, jax_code: str) -> list[dict]:
+ """Compares methods within matching classes and returns missing ones.
+
+ Returns:
+ List of dicts with keys: class_name, method_name.
+ """ + try: + src_tree = ast.parse(pytorch_code) + out_tree = ast.parse(jax_code) + except SyntaxError: + return [] + + def _class_methods(tree: ast.Module) -> dict[str, set[str]]: + result = {} + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + methods = set() + for child in ast.iter_child_nodes(node): + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + methods.add(child.name) + result[node.name] = methods + return result + + src_classes = _class_methods(src_tree) + out_classes = _class_methods(out_tree) + + missing = [] + for cls_name, src_methods in src_classes.items(): + if cls_name in out_classes: + for method in sorted(src_methods - out_classes[cls_name]): + # Skip dunder methods other than __init__ and __call__ + if method.startswith("__") and method.endswith("__") and method not in ("__init__", "__call__"): + continue + missing.append({"class_name": cls_name, "method_name": method}) + return missing + + +def _extract_method_source(code: str, class_name: str, method_name: str) -> str: + """Extracts a method's source from within a class.""" + try: + tree = ast.parse(code) + except SyntaxError: + return "" + lines = code.splitlines(keepends=True) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + for child in ast.iter_child_nodes(node): + if (isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) + and child.name == method_name): + start = child.lineno - 1 + end = child.end_lineno if child.end_lineno else len(lines) + return "".join(lines[start:end]) + return "" + + class PrimaryAgent(base.Agent): """Primary orchestration agent for repository migration.""" - def __init__(self, model: Any, api_key: str | None = None): + def __init__(self, model: Any, api_key: str | None = None, + validate: bool = True): """Initializes the agent.""" super().__init__( model=model, agent_domain=utils.AgentDomain.MIGRATION, agent_type=utils.AgentType.PRIMARY, ) + 
self._model_ref = model + self._validate = validate + self._validation_results: dict[str, dict] = {} + self._merge_result = None # Set when running on a directory self._rag_agent = rag_agent.RAGAgent( model, embedding_model_name=models.EmbeddingModel.GEMINI_EMBEDDING_001, @@ -53,6 +217,141 @@ def _convert_file(self, pytorch_code: str, file_path: str) -> str: return self._model_conversion_agent.run(pytorch_code) return self._single_file_agent.run(pytorch_code) + _FILL_PROMPT = textwrap.dedent("""\ + Convert the following PyTorch classes/functions to JAX/Flax. + Return ONLY valid Python code. No markdown, no explanation. + + {rag_section} + ## PyTorch components to convert: + ```python + {components_source} + ``` + """) + + _FILL_STUBS_PROMPT = textwrap.dedent("""\ + The following JAX/Flax code contains stub implementations (functions or + methods with placeholder bodies like `pass`, `return None`, `...`, or + `raise NotImplementedError`). Replace every stub with a complete, correct + implementation based on the original PyTorch source provided below. + + Return the COMPLETE JAX file with all stubs filled in. Do not remove any + existing non-stub code. Return ONLY valid Python code. No markdown, no + explanation. + + ## Original PyTorch source for reference: + ```python + {pytorch_source} + ``` + + ## Current JAX/Flax code (with stubs to fill): + ```python + {jax_code} + ``` + """) + + def _fill_missing_components(self, pytorch_code: str, + jax_code: str) -> str: + """Detects components missing from the JAX output and converts them. + + Also detects stub implementations and missing methods within classes, + and makes a targeted LLM call to replace them with real implementations. 
+ """ + # --- Phase 1: Fill missing top-level components (existing logic) --- + missing = _find_missing_components(pytorch_code, jax_code) + if missing: + logger.info("Missing components detected: %s", missing) + + sources = [] + for name in missing: + src = _extract_component_source(pytorch_code, name) + if src: + sources.append(src) + + if sources: + components_source = "\n\n".join(sources) + rag_section = "" + if self._rag_agent: + query = "JAX Flax conversion " + " ".join(missing) + try: + docs = self._rag_agent.retrieve_context(query, top_k=10) + if docs: + rag_section = "\n## Reference Patterns (from RAG):\n" + for doc in docs: + rag_section += f"\n### {doc.get('name', 'unknown')}\n{doc.get('text', '')}\n" + except Exception: + pass + + prompt = self._FILL_PROMPT.format( + components_source=components_source, + rag_section=rag_section, + ) + response = self.generate(prompt) + converted = _strip_markdown_formatting(response) + if converted and len(converted.strip()) > 20: + jax_code = jax_code.rstrip() + "\n\n" + converted.strip() + "\n" + + # --- Phase 2: Fix stubs and missing methods --- + stubs = _find_stub_implementations(jax_code) + missing_methods = _find_missing_methods(pytorch_code, jax_code) + + if not stubs and not missing_methods: + return jax_code + + # Collect PyTorch source snippets for the problematic components + pytorch_snippets = [] + seen = set() + for stub in stubs: + if stub["parent_class"]: + key = (stub["parent_class"], stub["name"]) + if key not in seen: + seen.add(key) + src = _extract_method_source(pytorch_code, stub["parent_class"], stub["name"]) + if src: + pytorch_snippets.append(f"# {stub['parent_class']}.{stub['name']}\n{src}") + else: + key = (None, stub["name"]) + if key not in seen: + seen.add(key) + src = _extract_component_source(pytorch_code, stub["name"]) + if src: + pytorch_snippets.append(f"# {stub['name']}\n{src}") + + for mm in missing_methods: + key = (mm["class_name"], mm["method_name"]) + if key not in seen: + 
seen.add(key) + src = _extract_method_source(pytorch_code, mm["class_name"], mm["method_name"]) + if src: + pytorch_snippets.append(f"# {mm['class_name']}.{mm['method_name']}\n{src}") + + if not pytorch_snippets: + return jax_code + + stub_names = [ + f"{s['parent_class']}.{s['name']}" if s["parent_class"] else s["name"] + for s in stubs + ] + mm_names = [f"{m['class_name']}.{m['method_name']}" for m in missing_methods] + logger.info("Stub implementations found: %s", stub_names) + logger.info("Missing methods found: %s", mm_names) + + pytorch_source = "\n\n".join(pytorch_snippets) + prompt = self._FILL_STUBS_PROMPT.format( + pytorch_source=pytorch_source, + jax_code=jax_code, + ) + response = self.generate(prompt) + repaired = _strip_markdown_formatting(response) + + # Only accept if result is a reasonable-length complete file that parses + if repaired and len(repaired.strip()) > len(jax_code) * 0.5: + try: + ast.parse(repaired) + return repaired + except SyntaxError: + logger.warning("Stub-filled code has syntax errors, keeping original") + return jax_code + def _execute_test( self, pytorch_code: str, jax_code: str, test_code: str ) -> Tuple[bool, str]: @@ -82,113 +381,163 @@ def _execute_test( except subprocess.CalledProcessError as e: return False, e.stderr - def run(self, repo_path: str, context: str | None = None) -> dict[str, str]: + _MAX_REPAIR_ITERATIONS = 3 + + def _validate_and_repair(self, pytorch_code: str, converted_code: str, + file_path: str) -> str: + """Validates converted code and repairs deviations in a loop. + + Runs up to _MAX_REPAIR_ITERATIONS rounds of validate-then-repair. + Exits early if no deviations remain or if the deviation count does + not decrease (no progress). + + Args: + pytorch_code: The original PyTorch source code. + converted_code: The converted JAX/Flax code. + file_path: The file path (used as key for storing results). + + Returns: + The final code (repaired if deviations were found, original otherwise). 
+ """ + validator = validation_agent.ValidationAgent( + self._model_ref, rag_agent_instance=self._rag_agent + ) + + current_code = converted_code + prev_count = float("inf") + initial_deviations = None + initial_count = 0 + iteration_history = [] + final_deviations = [] + + for iteration in range(1, self._MAX_REPAIR_ITERATIONS + 1): + deviations = validator.validate(pytorch_code, current_code) + count = len(deviations) + logger.info("Validation of %s (iteration %d): found %d deviations", + file_path, iteration, count) + + # Capture initial state for backward compat + if iteration == 1: + initial_deviations = deviations + initial_count = count + + iteration_history.append({ + "iteration": iteration, + "deviation_count": count, + }) + + # Clean — no deviations remain + if not deviations: + final_deviations = [] + break + + # No progress — deviation count did not decrease + if count >= prev_count: + logger.info("No progress on %s at iteration %d (prev=%d, cur=%d), " + "stopping repair loop", file_path, iteration, + prev_count, count) + final_deviations = deviations + break + + current_code = validator.repair( + current_code, deviations, pytorch_code=pytorch_code + ) + prev_count = count + final_deviations = deviations + else: + # Loop exhausted without break — run one final validation + final_check = validator.validate(pytorch_code, current_code) + final_deviations = final_check + iteration_history.append({ + "iteration": self._MAX_REPAIR_ITERATIONS + 1, + "deviation_count": len(final_check), + }) + logger.info("Final validation of %s: %d remaining deviations", + file_path, len(final_check)) + + result = { + "deviations_found": initial_count, + "deviations": initial_deviations or [], + "remaining_deviations_count": len(final_deviations), + "remaining_deviations": final_deviations, + "iterations": len([h for h in iteration_history + if h["iteration"] <= self._MAX_REPAIR_ITERATIONS]), + "iteration_history": iteration_history, + } + self._validation_results[file_path] = 
result + return current_code + + def get_validation_results(self) -> dict[str, dict]: + """Returns validation results for all processed files. + + Returns: + A dictionary mapping file paths to their validation results, each + containing deviations_found, deviations, remaining_deviations_count, + and remaining_deviations. + """ + return self._validation_results + + def get_merge_result(self): + """Returns the MergeResult from the last directory run, or None.""" + return self._merge_result + + def run(self, repo_path: str) -> dict[str, str]: """Orchestrates the migration of a repository from PyTorch to JAX. Args: repo_path: The path to the repository file or directory. - context: Optional raw context to use instead of RAG retrieval. Returns: A dictionary mapping original file paths to converted JAX code. - - Raises: - RuntimeError: If the code conversion and validation fails after - `MAX_DEBUG_ITERATIONS` attempts. """ if os.path.isfile(repo_path): with open(repo_path, "r", encoding="utf-8", errors="replace") as f: pytorch_code = f.read() - - if context is None: - rag_context_list = self._rag_agent.retrieve_context( - pytorch_code, top_k=7 + logger.info("Converting %s ...", repo_path) + converted_code = self._convert_file(pytorch_code, repo_path) + converted_code = self._fill_missing_components( + pytorch_code, converted_code + ) + if self._validate: + converted_code = self._validate_and_repair( + pytorch_code, converted_code, repo_path ) - rag_context = "\n\n".join([ - f"File: {c['file']}\n```python\n{c['text']}\n```" - for c in rag_context_list - ]) - else: - rag_context = context + return {repo_path: converted_code} + elif os.path.isdir(repo_path): + from agents.migration.merge_agent import MergeAgent + + merger = MergeAgent() + merge_result = merger.run(repo_path) + self._merge_result = merge_result + results = {} - jax_code = _strip_markdown_formatting( - self.generate( - prompts.MIGRATE_MODULE_TO_JAX_PROMPT, - {"pytorch_code": pytorch_code, "rag_context": 
rag_context}, - ) + # Convert model code + logger.info("Converting merged model code (%d files, %d chars)...", + len(merge_result.model_files), len(merge_result.model_code)) + model_jax = self._convert_file( + merge_result.model_code, "merged_model.py" ) + model_jax = self._fill_missing_components( + merge_result.model_code, model_jax + ) + if self._validate: + model_jax = self._validate_and_repair( + merge_result.model_code, model_jax, "merged_model.py" + ) + results["model"] = model_jax - for i in range(MAX_DEBUG_ITERATIONS): - logging.info("Starting testing iteration %d.", i) - test_code = _strip_markdown_formatting( - self.generate( - prompts.EVALUATE_CODE_PROMPT, - {"pytorch_code": pytorch_code, "jax_code": jax_code}, - ) + # Convert utility code (if any) + if merge_result.utility_code: + logger.info("Converting merged utility code (%d files, %d chars)...", + len(merge_result.utility_files), + len(merge_result.utility_code)) + utils_jax = self._single_file_agent.run(merge_result.utility_code) + utils_jax = self._fill_missing_components( + merge_result.utility_code, utils_jax ) + results["utils"] = utils_jax - if "NOTESTCASE" in test_code: - print( - "Test generation returned NOTESTCASE, assuming conversion is ok." - ) - return {repo_path: jax_code} - - success, output = self._execute_test(pytorch_code, jax_code, test_code) - - if success: - print(f"Validation successful after {i} debugging iterations.") - logging.info( - "Validation successful after %d debugging iterations.", i - ) - return {repo_path: jax_code} - else: - traceback = output - logging.error( - "Validation failed on iteration %d. 
Traceback:\n%s", i, traceback - ) - logging.info("Starting debug iteration %d.", i + 1) - bug_analysis = self.generate( - prompts.BUG_ANALYSIS_PROMPT, - { - "pytorch_code": pytorch_code, - "jax_code": jax_code, - "test_code": test_code, - "traceback": traceback, - }, - ) - print(f"Bug analysis:\n{bug_analysis}") - logging.info("Bug analysis:\n%s", bug_analysis) - jax_code = _strip_markdown_formatting( - self.generate( - prompts.SELF_DEBUGGING_PROMPT, - { - "pytorch_code": pytorch_code, - "jax_code": jax_code, - "test_code": test_code, - "traceback": traceback, - "bug_analysis": bug_analysis, - "rag_context": rag_context, - }, - ) - ) - print(f"Attempting fix with new JAX code for iteration {i+1}.") - - raise RuntimeError( - "Failed to convert and validate code after" - f" {MAX_DEBUG_ITERATIONS} iterations." - ) - elif os.path.isdir(repo_path): - graph = utils.build_dependency_graph(repo_path) - ordered_files = utils.topological_sort(graph) - converted_files: dict[str, str] = {} - - for file_rel_path in ordered_files: - file_path = os.path.join(repo_path, file_rel_path) - with open(file_path, "r", encoding="utf-8", errors="replace") as f: - pytorch_code = f.read() - converted_code = self._convert_file(pytorch_code, file_path) - converted_files[file_path] = converted_code - return converted_files + return results else: return { repo_path: f"# Error: path {repo_path} is not a file or directory." diff --git a/MaxCode/agents/migration/prompts/prompts.py b/MaxCode/agents/migration/prompts/prompts.py index d5b32dc..a8f6933 100644 --- a/MaxCode/agents/migration/prompts/prompts.py +++ b/MaxCode/agents/migration/prompts/prompts.py @@ -69,15 +69,59 @@ - Every layer explicitly sets `use_bias=True` or `use_bias=False` to exactly match the PyTorch layer. 17. **BatchNorm Momentum**: JAX momentum is the decay factor for old statistics (`x_new = momentum * x_old + (1 - momentum) * x_batch`), but PyTorch uses `1 - decay`. 
To ensure parity, you MUST set JAX momentum to `1 - pytorch_momentum`. 18. **Data Layout**: Standardize on `NHWC` (Channels Last) for JAX performance, but include necessary `jnp.transpose` operations at input/output boundaries to match PyTorch's `NCHW` oracle outputs. +19. **Weight Initialization**: Match PyTorch initialization exactly. + When the source explicitly calls `nn.init.zeros_` on a layer, use + `nn.initializers.zeros_init()`. When the source uses bare `nn.Linear()` + with no explicit init, use the Flax default (lecun_normal) or + `nn.initializers.normal(stddev=config.initializer_range)` -- do NOT use + zeros_init unless the source explicitly initializes to zeros. + RMSNorm (1+w): `nn.initializers.zeros_init()`. + RMSNorm (w): `nn.initializers.ones_init()`. + Check each nn.Parameter in the source and match its init. +20. **Train/Eval Mode**: Flax modules do NOT have a `.train` attribute or + `.eval()` / `.train()` methods. NEVER write `model.train = True` or + `model.train = False` -- this does nothing in Flax and silently produces + incorrect behavior. Instead, pass `deterministic=False` for training and + `deterministic=True` for evaluation as an argument to `__call__` / + `model.apply()`. All stochastic layers (Dropout, router noise) must + check the `deterministic` flag. +21. **Preserve ALL Source Components**: Convert EVERY class, function, and + method from the source. Do NOT merge base classes into subclasses, do NOT + drop utility classes or metric functions, and do NOT omit `get_config()` + or serialization methods. If the source has `ExpertBase` and `FFNExpert`, + convert both. If the source has a `MoEMetrics` class, convert it. +22. **Preserve Default Values Exactly**: All default parameter values in the + JAX output must match the PyTorch source EXACTLY. Do NOT change any numeric + default -- not capacity factors, not dropout rates, not epsilon values, not + learning rates, not layer counts. 
Even if you believe a different value is + "better" or "more stable", use the source value. Changed defaults silently + alter model behavior and break reproducibility. +23. **Preserve Exact Reduction Operations**: When the source uses `.mean()`, + use `jnp.mean()`. When the source uses `.sum()`, use `jnp.sum()`. NEVER + substitute one reduction for another. `torch.mean(x, dim=N)` maps to + `jnp.mean(x, axis=N)`. `torch.sum(x, dim=N)` maps to `jnp.sum(x, axis=N)`. + The dim/axis integer stays the same. +24. **Preserve Method Placement**: If the source defines a method or attribute + on a specific class, keep it on that class in the JAX output. Do NOT + relocate methods between classes or replace instance methods with + standalone functions unless the JAX idiom requires it. ## CRITICAL: Faithfulness to Source Code +This is a TRANSLATION, not a redesign. The converted code must produce +IDENTICAL behavior to the source for the same inputs and weights. + NEVER simplify complex tensor reshaping, reordering, or algorithmic patterns from the source code. If the PyTorch code uses a specific interleaved weight layout, chunk-parallel algorithm, or multi-step computation, convert it faithfully to JAX. The RAG context shows EXAMPLES of similar patterns -- use them as guidance for JAX idioms, but always follow the ACTUAL source code's logic and structure. + +NEVER "improve" the source by changing default values, adding initializers +that the source does not use, substituting reductions (.sum vs .mean), or +dropping components you consider non-essential (logging, metrics, utility +classes). If the source has it, the output must have it. """ MIGRATE_MODULE_TO_JAX_PROMPT = """ @@ -357,6 +401,73 @@ linear attention, implement BOTH modes and dispatch based on sequence length. 5. Implement causal_conv1d as a standalone function with both prefill and single-step decode paths. +6. 
For causal operations with decode-time state (causal conv1d, linear + attention), implement SEPARATE prefill and decode functions. Do NOT use + a single unified function with conditional branching. +7. ALWAYS include a `@dataclasses.dataclass` Config class at the top of the + output file. Mirror ALL fields from the PyTorch configuration class with + their types and default values. Use `dataclasses.field(default_factory=...)` + for mutable defaults. Use the Config type (not `Any`) in module annotations. +8. The `load_balancing_loss` function MUST accept an optional `attention_mask` + parameter. When the mask is provided, broadcast it to match the concatenated + router logits shape and use it to exclude padding tokens from mean/sum + statistics. See the RAG context for the full pattern. +9. **MoE Experts: Capacity-Based Dispatch (MANDATORY)**. The Experts class MUST + use capacity-based dispatch with dispatch/combine tensors -- NOT per-token + gather of expert weights. The correct pattern is: + a) Compute per-expert capacity: `capacity = ceil(T * K / E) * 1.5` + b) Build dispatch tensor via `one_hot(selected_experts) -> cumsum -> positions + -> one_hot(positions, capacity)` to get `dispatch: [T, E, C]` + c) Build combine tensor: `combine = dispatch * routing_weights` + d) Route tokens to expert buffers: `expert_in = einsum('tec,th->ech', dispatch, x)` + e) Batched expert matmul: `expert_out = einsum('ech,ehi->eci', expert_in, W)` + f) Scatter back: `output = einsum('tec,ech->th', combine, expert_out)` + Do NOT use `weight[flat_indices]` gather or `jax.vmap` over individual experts. + Do NOT use `jnp.einsum('td,edh->teh')` computing all experts for all tokens. + The capacity-based approach is 10-50x more efficient for large E (e.g. E=64). + See the RAG context file `targeted_moe_capacity_routing_jax.py` for the full + implementation with WRONG/CORRECT examples. +10. **KV Cache: Pure Functional NamedTuple (MANDATORY)**. 
All KV caches MUST be + NamedTuple objects passed as function arguments and returned as outputs. + Do NOT use Flax mutable variables (`self.variable('cache', ...)`). + Do NOT use config dicts with init flags. + For encoder-decoder models: use SEPARATE self_attn_cache and cross_attn_cache + arguments per layer. Cross-attention caches are populated once from encoder + output and passed through unchanged on subsequent decode steps. + Provide an `init_kv_caches()` helper function that pre-allocates all layer + caches. This replaces PyTorch's `install_kv_cache_hooks()`. + See the RAG context for the full encoder-decoder cache pattern. +11. **Tied Output Projection**: When the PyTorch source computes logits via + `x @ self.token_embedding.weight.T`, convert it to + `(x @ token_embedding.embedding.T).astype(jnp.float32)`. + Do NOT use `token_embedding.attend(x)` -- that is for embedding lookup, + not linear projection, and may produce different results. +12. **Fused QKV Projection**: When the PyTorch source uses a single + `in_proj_weight` of shape [3*embed_dim, embed_dim] with sliced projection + methods (in_proj_qkv, in_proj_q, in_proj_kv), preserve this as a SINGLE + parameter with sliced access in JAX. Do NOT split into 3 separate nn.Dense + layers. Use `self.param('in_proj_weight', init, (3*D, D))` and slice it + for Q [0:D], K [D:2D], V [2D:3D]. Provide in_proj_qkv(), in_proj_q(), + in_proj_kv() methods matching the PyTorch API. +13. **Float32 Softmax Upcast (MANDATORY)**: When the PyTorch source uses + `.float()` or `dtype=torch.float32` before softmax, you MUST preserve this + in JAX: `jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1)` then + cast back with `.astype(q.dtype)`. This is critical for numerical stability + in bfloat16/float16. NEVER omit this upcast. +14. **Preserve ALL Source Components (MANDATORY)**: The output MUST contain a + JAX equivalent for EVERY class, function, method, and utility in the source. 
+ Do NOT merge base classes into subclasses. Do NOT drop get_config() or + serialization methods. Do NOT omit utility classes (e.g., metrics classes) + or standalone functions (e.g., metric computation functions). If the source + has N classes and M functions, the output must have N classes and M functions. +15. **Preserve Default Values Exactly**: All constructor defaults, config + defaults, and hyperparameter defaults MUST match the PyTorch source exactly. + Do NOT change capacity_factor, dropout rates, noise epsilon, num_layers, + or any other default value -- even if you think a different value is better. +16. **Train/Eval Mode in Flax**: NEVER set `model.train = True/False` or call + `model.eval()` / `model.train()` in training loops. Flax has no such + attributes. Use `deterministic=False` for training and `deterministic=True` + for evaluation, passed as an argument to the module's `__call__` method. Please think step by step about the conversion process before generating the code. Then, provide the complete JAX equivalent of the entire file above. 
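The capacity-based MoE dispatch recipe in rule 9 is the densest step sequence in these prompts, so a small runnable sketch may help reviewers check the einsum shapes. This is an illustrative top-1-routing version using NumPy in place of jax.numpy (the einsum semantics are identical); the function name `capacity_dispatch` and the drop-on-overflow behavior are assumptions for the example, not part of the prompt text.

```python
import numpy as np

def capacity_dispatch(x, selected_experts, routing_weights, num_experts, capacity):
    """Builds the [T, E, C] dispatch/combine tensors from rule 9 (top-1 routing).

    x: [T, H] token activations; selected_experts: [T] expert id per token;
    routing_weights: [T] gate weight for each token's chosen expert.
    """
    num_tokens = x.shape[0]
    # Steps (a)/(b): one-hot the expert choice, then cumsum to assign each
    # token a slot in its expert's buffer; tokens past `capacity` are dropped.
    expert_onehot = np.eye(num_experts)[selected_experts]                  # [T, E]
    slot = ((np.cumsum(expert_onehot, axis=0) - 1) * expert_onehot).sum(axis=1).astype(int)
    keep = slot < capacity
    dispatch = np.zeros((num_tokens, num_experts, capacity))               # [T, E, C]
    t = np.arange(num_tokens)[keep]
    dispatch[t, selected_experts[keep], slot[keep]] = 1.0
    # Step (c): combine tensor carries the router weight back to each token.
    combine = dispatch * routing_weights[:, None, None]
    # Step (d): route tokens into fixed-size per-expert buffers: [E, C, H].
    expert_in = np.einsum('tec,th->ech', dispatch, x)
    return dispatch, combine, expert_in
```

After the batched expert matmul of step (e), step (f) scatters outputs back with `np.einsum('tec,ech->th', combine, expert_out)`; with identity experts and unit router weights this round-trip reproduces the input tokens, which is a quick sanity check for the dispatch/combine construction.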
diff --git a/MaxCode/agents/migration/repo_agent.py b/MaxCode/agents/migration/repo_agent.py index 0688ef4..abe3667 100644 --- a/MaxCode/agents/migration/repo_agent.py +++ b/MaxCode/agents/migration/repo_agent.py @@ -43,7 +43,7 @@ def run(self, repo_path: str) -> Dict[str, str]: try: with open(file_path, "r") as f: pytorch_code = f.read() - rag_context_list = self._rag_agent.retrieve_context(pytorch_code) + rag_context_list = self._rag_agent.retrieve_per_component_context(pytorch_code) rag_context = "\\n\\n".join([ f"File: {c['file']}\\n```python\\n{c['text']}\\n```" for c in rag_context_list diff --git a/MaxCode/agents/migration/single_file_agent.py b/MaxCode/agents/migration/single_file_agent.py index 7bc991a..aa84e13 100644 --- a/MaxCode/agents/migration/single_file_agent.py +++ b/MaxCode/agents/migration/single_file_agent.py @@ -35,6 +35,18 @@ def _strip_markdown_formatting(self, text: str) -> str: ) if code_block_match: return code_block_match.group(1).strip() + # Handle truncated responses: opening ``` present but closing ``` missing + stripped = text.strip() + if stripped.startswith("```"): + first_nl = stripped.find("\n") + if first_nl != -1: + stripped = stripped[first_nl + 1:] + if stripped.endswith("```"): + stripped = stripped[:-3] + return stripped.strip() + # Strip triple-quote wrappers the LLM may use instead of backticks. + if stripped.startswith('"""') and stripped.endswith('"""'): + return stripped[3:-3].strip() return text def run(self, pytorch_code: str) -> str: @@ -46,7 +58,7 @@ def run(self, pytorch_code: str) -> str: Returns: The converted JAX code. 
""" - rag_context_list = self._rag_agent.retrieve_context(pytorch_code, top_k=7) + rag_context_list = self._rag_agent.retrieve_per_component_context(pytorch_code) rag_context = "\n\n".join([ f"File: {c['file']}\n```python\n{c['text']}\n```" for c in rag_context_list diff --git a/MaxCode/agents/migration/validation_agent.py b/MaxCode/agents/migration/validation_agent.py new file mode 100644 index 0000000..e55a59c --- /dev/null +++ b/MaxCode/agents/migration/validation_agent.py @@ -0,0 +1,352 @@ +"""Agent for validating faithfulness of PyTorch-to-JAX conversions.""" + +import json +import re +from typing import Any + +from agents import base +from agents import utils + + +VALIDATION_PROMPT = """You are an expert code reviewer specializing in PyTorch-to-JAX +conversions. Your task is to compare the ORIGINAL PyTorch source code with the +CONVERTED JAX/Flax output and identify every FAITHFULNESS DEVIATION. + +A faithfulness deviation is any place where the JAX output CHANGES the behavior, +defaults, structure, or semantics of the original PyTorch code. You should NOT +flag intentional JAX idiom changes (e.g., torch.Tensor -> jnp.ndarray, +nn.Module -> nn.Module with @nn.compact, self.training -> deterministic flag). + +## Original PyTorch Source: +```python +{pytorch_code} +``` + +## Converted JAX Output: +```python +{jax_code} +``` + +## Check each of the following categories: + +### 1. Default Values +Compare every constructor parameter default in the source vs the output. +Flag any changed numeric value (e.g., capacity_factor=1.0 changed to 1.25). + +### 2. Weight Initialization +For each nn.Linear/nn.Dense in the source: +- If the source uses bare `nn.Linear(...)` with NO explicit init call + (no nn.init.zeros_, nn.init.normal_, etc.), the JAX output should use + the Flax default initializer (no kernel_init argument). +- If the source EXPLICITLY calls an init (e.g., nn.init.zeros_), the JAX + output should use the matching Flax initializer. 
+Flag any case where an initializer was added or changed. + +### 3. Missing Components +List every class, function, method, or constant in the source that has +NO equivalent in the JAX output. Include: +- Base classes that were merged into subclasses +- get_config() or serialization methods +- Utility functions (metrics, logging helpers, etc.) +- Utility classes (e.g., metrics aggregation classes) +- Lambda attributes or property methods + +### 4. Reduction Operations +Flag any place where .mean() was changed to .sum() or vice versa, +or where a reduction axis was changed. + +### 5. Method Placement +Flag any method/attribute that was moved from one class to another, +or converted from an instance method to a standalone function when +the source has it as a method. + +### 6. Dropped Features +Flag any feature present in the source that was removed in the output +(e.g., TensorBoard logging, checkpoint saving, progress bars, etc.) + +## IMPORTANT: Use Exact Code Snippets +When reporting deviations, copy the relevant lines VERBATIM from the code +above. Do NOT paraphrase or describe the code in English. Use the actual +source and output lines so that a repair tool can find-and-replace them. + +## Output Format + +Return a JSON array of deviations. Each deviation must have: +- "category": one of "default_value", "initialization", "missing_component", + "reduction_op", "method_placement", "dropped_feature" +- "severity": "high" (changes model output), "medium" (changes training behavior), + or "low" (cosmetic or minor) +- "source_snippet": copy the exact line(s) verbatim from the PyTorch source + (max 3 lines). For missing components, show the class/function signature. +- "output_snippet": copy the exact line(s) verbatim from the JAX output + (max 3 lines). Use "MISSING" if the component does not exist. +- "corrected_snippet": the exact replacement code that should replace + output_snippet to fix the deviation. 
Use "ADD" for missing components + (and put the new code in the fix field). +- "fix": specific instruction for how to fix the deviation + +If there are NO deviations, return an empty array: [] + +Return ONLY the JSON array, no markdown formatting, no explanation. +""" + + +REPAIR_PROMPT = """You are an expert JAX/Flax developer. You have been given a +JAX/Flax code file that was converted from PyTorch, along with a list of +faithfulness deviations that need to be fixed. + +## Original PyTorch Source (for reference): +```python +{pytorch_code} +``` + +## Current JAX Code: +```python +{jax_code} +``` +{rag_section} +## Deviations to Fix: +{deviations_text} + +## CRITICAL RULES: +1. For each deviation, find the EXACT output_snippet in the JAX code and + replace it with the corrected_snippet. If the snippets are not exact + matches (whitespace differences, etc.), locate the closest match and + apply the fix described in the instruction. +2. NEVER remove an existing class, function, method, or import -- even if it + seems unused or redundant. If the current JAX code has a class (e.g., + MoETrainer, MoEMetrics), it MUST remain in the output. +3. NEVER convert a class into standalone functions or vice versa. +4. NEVER remove a training loop, epoch loop, or any training utility code. +5. If a deviation's instruction says the current behavior is acceptable, + desirable, or "not recommended" to change, SKIP that deviation entirely. +6. Preserve ALL existing code structure -- only change what the deviation + specifically asks you to change. +7. The output must have the SAME number of classes and functions (or more) + as the input JAX code. + +Return ONLY the complete fixed Python code. No markdown formatting, no +explanation, no ```python blocks. 
+""" + + +_CODE_BLOCK_PATTERN = re.compile(r"```(?:python)?\n?(.*?)\n?```", re.DOTALL) + + +def _strip_markdown_formatting(text: str) -> str: + """Strips markdown and returns only the first Python code block.""" + code_block_match = _CODE_BLOCK_PATTERN.search(text) + if code_block_match: + return code_block_match.group(1).strip() + # Strip triple-quote wrappers the LLM may use instead of backticks. + stripped = text.strip() + if stripped.startswith('"""') and stripped.endswith('"""'): + return stripped[3:-3].strip() + return text + + +def _parse_json_response(text: str) -> list: + """Parse JSON from LLM response, handling markdown wrapping.""" + text = text.strip() + # Strip markdown code blocks if present + json_match = re.search(r"```(?:json)?\n?(.*?)\n?```", text, re.DOTALL) + if json_match: + text = json_match.group(1).strip() + try: + return json.loads(text) + except json.JSONDecodeError: + # Try to find a JSON array in the text + array_match = re.search(r'\[.*\]', text, re.DOTALL) + if array_match: + try: + return json.loads(array_match.group(0)) + except json.JSONDecodeError: + pass + return [] + + +class ValidationAgent(base.Agent): + """Agent for validating faithfulness of PyTorch-to-JAX conversions. + + This agent takes the original PyTorch source and the converted JAX output, + identifies faithfulness deviations (changed defaults, wrong init, missing + components, altered semantics), and optionally repairs them. + """ + + def __init__(self, model: Any, rag_agent_instance=None): + """Initializes the agent. + + Args: + model: The LLM model to use for generation. + rag_agent_instance: Optional RAGAgent for retrieving context + during repair. If None, repair runs without RAG context. 
+ """ + super().__init__( + model=model, + agent_domain=utils.AgentDomain.MIGRATION, + agent_type=utils.AgentType.PRIMARY, + ) + self._rag_agent = rag_agent_instance + + def validate(self, pytorch_code: str, jax_code: str) -> list: + """Validates the JAX output against the PyTorch source. + + Args: + pytorch_code: The original PyTorch source code. + jax_code: The converted JAX/Flax code. + + Returns: + A list of deviation dicts, each with category, severity, + source_line, output_line, and fix fields. + """ + response = self.generate( + VALIDATION_PROMPT, + {"pytorch_code": pytorch_code, "jax_code": jax_code}, + ) + return _parse_json_response(response) + + @staticmethod + def _filter_actionable(deviations: list) -> list: + """Filter out deviations that explicitly say not to fix.""" + skip_phrases = [ + "not recommended", + "desirable deviation", + "correct and desirable", + "overly complex", + "acceptable deviation", + ] + actionable = [] + for d in deviations: + fix_text = d.get("fix", "").lower() + if any(phrase in fix_text for phrase in skip_phrases): + continue + actionable.append(d) + return actionable + + @staticmethod + def _format_deviations_for_repair(deviations: list) -> str: + """Formats deviations as numbered find/replace blocks for repair. + + Falls back to old source_line/output_line fields if the new + source_snippet/output_snippet fields are absent. + + Args: + deviations: List of deviation dicts from validate(). + + Returns: + A formatted string with numbered find/replace blocks. 
+ """ + blocks = [] + for i, d in enumerate(deviations, 1): + severity = d.get("severity", "medium") + category = d.get("category", "unknown") + source = d.get("source_snippet", d.get("source_line", "N/A")) + output = d.get("output_snippet", d.get("output_line", "N/A")) + corrected = d.get("corrected_snippet", "") + fix = d.get("fix", "") + + block = f"### Deviation {i} [{severity}] - {category}\n" + block += f"Source (PyTorch): {source}\n" + block += f"Find in JAX: {output}\n" + if output == "MISSING": + block += f"Source to convert: {source}\n" + if corrected and corrected not in ("ADD", "MISSING"): + block += f"Replace with: {corrected}\n" + block += f"Instruction: {fix}" + blocks.append(block) + return "\n\n".join(blocks) + + def _get_repair_rag_context(self, deviations: list) -> str: + """Retrieves RAG context relevant to the repair deviations. + + Builds a query from deviation categories and fix text, retrieves + top-k documents, and returns a formatted string for the prompt. + + Args: + deviations: List of deviation dicts from validate(). + + Returns: + A formatted RAG context string, or "" if no RAG agent. + """ + if not self._rag_agent: + return "" + + # Build query from deviation categories and fix descriptions + query_parts = [] + for d in deviations: + category = d.get("category", "") + fix = d.get("fix", "") + if category: + query_parts.append(category.replace("_", " ")) + if fix: + query_parts.append(fix) + query = " ".join(query_parts) + if not query.strip(): + return "" + + try: + docs = self._rag_agent.retrieve_context(query, top_k=15) + except Exception: + return "" + + if not docs: + return "" + + section = "\n## Reference Patterns (from RAG):\n" + for doc in docs: + name = doc.get("name", "unknown") + text = doc.get("text", "") + section += f"\n### {name}\n{text}\n" + return section + + def repair(self, jax_code: str, deviations: list, + pytorch_code: str = "") -> str: + """Repairs the JAX code based on identified deviations. 
+ + Args: + jax_code: The converted JAX/Flax code to repair. + deviations: List of deviation dicts from validate(). + pytorch_code: The original PyTorch source for reference. + + Returns: + The repaired JAX code. + """ + # Filter to only actionable deviations + actionable = self._filter_actionable(deviations) + if not actionable: + return jax_code + + deviations_text = self._format_deviations_for_repair(actionable) + rag_section = self._get_repair_rag_context(actionable) + response = self.generate( + REPAIR_PROMPT, + { + "jax_code": jax_code, + "deviations_text": deviations_text, + "rag_section": rag_section, + "pytorch_code": pytorch_code, + }, + ) + repaired = _strip_markdown_formatting(response) + # If the repair returned empty or very short, return original + if len(repaired) < len(jax_code) * 0.5: + return jax_code + return repaired + + def run(self, pytorch_code: str, jax_code: str) -> tuple: + """Validates and optionally repairs the conversion. + + Args: + pytorch_code: The original PyTorch source code. + jax_code: The converted JAX/Flax code. + + Returns: + Tuple of (repaired_code, deviations_list). + """ + deviations = self.validate(pytorch_code, jax_code) + if deviations: + repaired_code = self.repair( + jax_code, deviations, pytorch_code=pytorch_code + ) + return repaired_code, deviations + return jax_code, [] diff --git a/MaxCode/agents/migration/verification_agent.py b/MaxCode/agents/migration/verification_agent.py new file mode 100644 index 0000000..133ffcf --- /dev/null +++ b/MaxCode/agents/migration/verification_agent.py @@ -0,0 +1,272 @@ +"""Verification agent for scoring PyTorch-to-JAX conversion quality. + +Produces a scorecard with two metrics: + - Completeness (AST-based, no LLM): compares classes, methods, and + standalone functions by name. + - Correctness (LLM-based, requires API key): runs ValidationAgent to + detect deviations and scores them with weighted penalties. 
+""" + +import ast +from dataclasses import dataclass, field + + +@dataclass +class VerificationResult: + """Result of verifying a conversion.""" + completeness: dict = field(default_factory=dict) # score, total, found, classes, methods, functions + correctness: dict | None = None # score, deviations, by_category, by_severity (None if no api_key) + overall: float = 0.0 + + +# Standard PyTorch -> JAX/Flax method renames. +METHOD_RENAMES = { + "__init__": {"setup", "__call__"}, + "forward": {"__call__"}, +} + +# Methods always inlined during conversion. +ALWAYS_INLINED = { + "reset_parameters", +} + +# Severity weights for correctness scoring. +SEVERITY_WEIGHTS = {"high": 5, "medium": 3, "low": 1} + +# Known false-positive (category, severity) pairs. +FALSE_POSITIVE_RULES = { + ("method_placement", "low"), + ("missing_component", "low"), + ("dropped_feature", "low"), +} + + +class VerificationAgent: + """Scores the quality of a PyTorch-to-JAX conversion. + + The completeness check is pure AST (no LLM). The correctness check + delegates to ValidationAgent for deviation detection and applies + weighted scoring. + """ + + def __init__(self, model=None): + """Initialize the verification agent. + + Args: + model: Optional LLM model instance for correctness checks. + If None, correctness scoring is skipped. + """ + self._model = model + + @staticmethod + def extract_components(code): + """Parse Python code and return its classes, methods, and functions. + + Args: + code: Python source code string. + + Returns: + dict with keys "classes" (name -> [methods]) and "functions" (list). 
+ """ + tree = ast.parse(code) + classes = {} + functions = [] + + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + methods = [ + n.name + for n in ast.iter_child_nodes(node) + if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) + ] + classes[node.name] = methods + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + functions.append(node.name) + + return {"classes": classes, "functions": functions} + + @staticmethod + def compute_completeness(source_components, output_components): + """Compare source and output components and return a completeness report. + + Returns: + dict with score, total, found, classes, methods, functions breakdown. + """ + src_classes = source_components["classes"] + out_classes = output_components["classes"] + + src_class_names = set(src_classes.keys()) + out_class_names = set(out_classes.keys()) + matched_classes = src_class_names & out_class_names + missing_classes = sorted(src_class_names - out_class_names) + + total_methods = 0 + found_methods = 0 + missing_methods = [] + + for cls in src_classes: + src_methods = set(src_classes[cls]) + total_methods += len(src_methods) + if cls in out_classes: + out_methods = set(out_classes[cls]) + has_call = "__call__" in out_methods + for m in sorted(src_methods): + if m in out_methods: + found_methods += 1 + elif m in METHOD_RENAMES and METHOD_RENAMES[m] & out_methods: + found_methods += 1 + elif m in ALWAYS_INLINED: + found_methods += 1 + elif has_call and m not in ("__init__", "forward"): + found_methods += 1 + else: + missing_methods.append(f"{cls}.{m}") + else: + for m in sorted(src_methods): + missing_methods.append(f"{cls}.{m}") + + src_funcs = set(source_components["functions"]) + out_funcs = set(output_components["functions"]) + matched_funcs = src_funcs & out_funcs + for f in src_funcs - matched_funcs: + if f in out_class_names: + matched_funcs = matched_funcs | {f} + missing_funcs = sorted(src_funcs - matched_funcs) + + total = 
len(src_class_names) + total_methods + len(src_funcs) + found = len(matched_classes) + found_methods + len(matched_funcs) + score = (found / total * 100) if total > 0 else 100.0 + + return { + "score": round(score, 1), + "total": total, + "found": found, + "classes": { + "total": len(src_class_names), + "found": len(matched_classes), + "missing": missing_classes, + }, + "methods": { + "total": total_methods, + "found": found_methods, + "missing": missing_methods, + }, + "functions": { + "total": len(src_funcs), + "found": len(matched_funcs), + "missing": missing_funcs, + }, + } + + @staticmethod + def compute_correctness(source_code, output_code, api_key, + total_components=0, model=None): + """Run ValidationAgent and score the output. + + Args: + source_code: The PyTorch source code. + output_code: The converted JAX output code. + api_key: Google API key for the LLM. + total_components: Number of source components for budget scaling. + model: Optional pre-configured LLM model instance. If None, + creates a new GeminiTool with the given api_key. + + Returns: + dict with score, deviation_count, deviations, filtered_deviations, + by_category, by_severity. 
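Worked example of the severity-weighted scoring applied here (deviation counts are made up):

```python
SEVERITY_WEIGHTS = {"high": 5, "medium": 3, "low": 1}

# Three hypothetical deviations against a source with four top-level components.
deviations = [{"severity": "high"}, {"severity": "low"}, {"severity": "low"}]
total_components = 4  # top-level classes + functions in the source

penalty = sum(SEVERITY_WEIGHTS.get(d["severity"], 1) for d in deviations)  # 5 + 1 + 1 = 7
budget = total_components * SEVERITY_WEIGHTS["medium"]  # 4 * 3 = 12
score = max(0.0, (1.0 - penalty / budget) * 100.0)  # ~41.7
```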
+ """ + import models + from agents.migration.validation_agent import ValidationAgent + + if model is None: + model = models.GeminiTool( + model_name=models.GeminiModel.GEMINI_3_1_PRO_PREVIEW, + api_key=api_key, + ) + validator = ValidationAgent(model=model) + all_deviations = validator.validate(source_code, output_code) + + if not isinstance(all_deviations, list): + all_deviations = [] + + real = [] + filtered = [] + for d in all_deviations: + sev = d.get("severity", "low").lower() + cat = d.get("category", "unknown") + if (cat, sev) in FALSE_POSITIVE_RULES: + filtered.append(d) + else: + real.append(d) + + by_severity = {} + by_category = {} + penalty = 0 + + for d in real: + sev = d.get("severity", "low").lower() + cat = d.get("category", "unknown") + by_severity[sev] = by_severity.get(sev, 0) + 1 + by_category[cat] = by_category.get(cat, 0) + 1 + penalty += SEVERITY_WEIGHTS.get(sev, 1) + + if total_components <= 0: + try: + tree = ast.parse(source_code) + total_components = sum( + 1 for n in ast.iter_child_nodes(tree) + if isinstance(n, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)) + ) + except SyntaxError: + total_components = 0 + + budget = total_components * SEVERITY_WEIGHTS["medium"] + if budget > 0: + score = max(0.0, (1.0 - penalty / budget) * 100.0) + else: + score = 100.0 if penalty == 0 else 0.0 + + return { + "score": round(score, 1), + "deviation_count": len(real), + "deviations": real, + "filtered_deviations": filtered, + "by_category": by_category, + "by_severity": by_severity, + } + + def verify(self, source_code, output_code, api_key=None): + """Run full verification (completeness + optional correctness). + + Args: + source_code: The PyTorch source code string. + output_code: The converted JAX output code string. + api_key: Optional Google API key. If provided (or if self._model + is set), runs correctness check. + + Returns: + VerificationResult with completeness, correctness, and overall score. 
+ """ + src_components = self.extract_components(source_code) + out_components = self.extract_components(output_code) + completeness = self.compute_completeness(src_components, out_components) + + correctness = None + if api_key or self._model: + correctness = self.compute_correctness( + source_code, output_code, + api_key=api_key, + total_components=completeness["total"], + model=self._model, + ) + + if correctness is not None: + overall = round((completeness["score"] + correctness["score"]) / 2, 1) + else: + overall = completeness["score"] + + return VerificationResult( + completeness=completeness, + correctness=correctness, + overall=overall, + ) diff --git a/MaxCode/mcp_server/adk_agents.py b/MaxCode/mcp_server/adk_agents.py index 31eeb45..d0a307d 100644 --- a/MaxCode/mcp_server/adk_agents.py +++ b/MaxCode/mcp_server/adk_agents.py @@ -3,6 +3,7 @@ import models from tools import evaluation_tool from tools import migration_tool +from tools import verification_tool from google.adk.agents.llm_agent import LlmAgent as Agent from google.adk.models.google_llm import Gemini @@ -39,6 +40,7 @@ Always wait for a tool to succeed before moving to the next step. If a step fails, report the error immediately and stop.""", tools=[ migration_tool.convert_code_tool, + verification_tool.verify_conversion_tool, evaluation_tool.generate_model_configs_tool, evaluation_tool.generate_oracle_data_tool, evaluation_tool.run_equivalence_tests_tool, @@ -66,5 +68,6 @@ evaluation_tool.generate_oracle_data_tool, evaluation_tool.generate_equivalence_tests_tool, evaluation_tool.run_equivalence_tests_tool, + verification_tool.verify_conversion_tool, ], ) diff --git a/MaxCode/models.py b/MaxCode/models.py index 240c934..d38133e 100644 --- a/MaxCode/models.py +++ b/MaxCode/models.py @@ -69,7 +69,12 @@ def __call__(self, user_prompt: str): str: The generated text response from the Gemini API. 
""" headers = {"Content-Type": "application/json"} - payload = {"contents": [{"parts": [{"text": user_prompt}], "role": "user"}]} + payload = { + "contents": [{"parts": [{"text": user_prompt}], "role": "user"}], + "generationConfig": { + "maxOutputTokens": 65536, + }, + } if self.system_instruction: payload["system_instruction"] = { "parts": [{"text": self.system_instruction}] diff --git a/MaxCode/rag/rag_agent.py b/MaxCode/rag/rag_agent.py index 45d651e..adf7ad1 100644 --- a/MaxCode/rag/rag_agent.py +++ b/MaxCode/rag/rag_agent.py @@ -1,5 +1,7 @@ """Tool for performing retrieval augmented generation.""" +import ast +import logging import os import sqlite3 from typing import Any, Dict, List @@ -11,6 +13,8 @@ from rag import vector_db import numpy as np +logger = logging.getLogger(__name__) + # We use a hardcoded character limit for the full code context to avoid # exceeding the model's token limit. While the Gemini API does not provide a @@ -20,6 +24,50 @@ _MAX_CONTEXT_LENGTH = 100_000 +def _extract_component_signatures(code: str) -> list[str]: + """Extracts focused query strings per top-level class/function using AST. + + For classes: "JAX Flax {ClassName} {base_classes} {method_names} {init_params}" + For functions: "JAX Flax {func_name} {param_names}" + + Args: + code: Python source code to parse. + + Returns: + A list of query strings, one per top-level component. 
+ """ + try: + tree = ast.parse(code) + except SyntaxError: + return [] + + signatures = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + bases = [ + ast.unparse(b) if hasattr(ast, "unparse") else getattr(b, "id", "") + for b in node.bases + ] + methods = [ + n.name for n in ast.iter_child_nodes(node) + if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) + ] + init_params = [] + for n in ast.iter_child_nodes(node): + if isinstance(n, ast.FunctionDef) and n.name == "__init__": + init_params = [ + a.arg for a in n.args.args if a.arg != "self" + ] + break + parts = ["JAX Flax", node.name] + bases + methods + init_params + signatures.append(" ".join(parts)) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + params = [a.arg for a in node.args.args if a.arg != "self"] + parts = ["JAX Flax", node.name] + params + signatures.append(" ".join(parts)) + return signatures + + class RAGAgent(base.Agent): """Tool for performing retrieval augmented generation.""" @@ -116,6 +164,63 @@ def retrieve_context( }) return retrieved_context + def retrieve_per_component_context( + self, + source_code: str, + top_k_per_component: int = 3, + max_total: int = 15, + ) -> List[Dict[str, Any]]: + """Retrieves RAG context using a hybrid full-file + per-component strategy. + + Combines broad domain context from the full source code query with + targeted results from per-component queries. This ensures the LLM gets + both the overall architectural patterns AND component-specific examples. + + Args: + source_code: The full Python source code to retrieve context for. + top_k_per_component: Number of results per component query. + max_total: Maximum total results to return after deduplication. + + Returns: + A deduplicated, distance-sorted list of retrieved documents. 
+ """ + signatures = _extract_component_signatures(source_code) + + # Fall back to single-query if AST parsing yielded nothing + if not signatures: + logger.info("Per-component extraction failed, falling back to single query") + return self.retrieve_context(source_code, top_k=max_total) + + # Start with full-file query for broad domain context + best_by_file: Dict[str, Dict[str, Any]] = {} + full_results = self.retrieve_context(source_code, top_k=max_total) + for doc in full_results: + best_by_file[doc["file"]] = doc + + # If >12 components, batch into groups of 3-4 to cap embedding calls + if len(signatures) > 12: + batched = [] + for i in range(0, len(signatures), 4): + batched.append(" ".join(signatures[i:i + 4])) + queries = batched + else: + queries = signatures + + logger.info("Per-component RAG: %d queries from %d components (+ full-file)", + len(queries), len(signatures)) + + # Add per-component results, keeping best distance per file + for query in queries: + results = self.retrieve_context(query, top_k=top_k_per_component) + for doc in results: + fpath = doc["file"] + if fpath not in best_by_file or doc["distance"] < best_by_file[fpath]["distance"]: + best_by_file[fpath] = doc + + # Sort by distance, truncate to max_total + sorted_docs = sorted(best_by_file.values(), key=lambda d: d["distance"]) + return sorted_docs[:max_total] + def run(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]: """Runs RAG to retrieve context for a query.""" return self.retrieve_context(query, top_k) diff --git a/MaxCode/tools/migration_tool.py b/MaxCode/tools/migration_tool.py index a44f5a1..5864f79 100644 --- a/MaxCode/tools/migration_tool.py +++ b/MaxCode/tools/migration_tool.py @@ -31,6 +31,7 @@ def convert_code( destination: str, api_key: str, model_name: str | None = None, + validate: bool = True, ) -> str: """Converts PyTorch code to JAX and saves it to the destination. 
@@ -39,6 +40,7 @@ def convert_code( destination: The directory where the migrated files should be saved. api_key: The Google AI API key to use for migration. model_name: The Gemini model to use for migration. + validate: Whether to run validation and repair after conversion. Returns: A JSON string containing the destination paths for subsequent steps. @@ -67,7 +69,7 @@ def convert_code( if model_name: model_kwargs["model_name"] = model_name model = models.GeminiTool(**model_kwargs) - agent = primary_agent.PrimaryAgent(model, api_key=api_key) + agent = primary_agent.PrimaryAgent(model, api_key=api_key, validate=validate) results = agent.run(abs_path) logging.info("Writing converted files to: %s", destination) @@ -91,21 +93,49 @@ def convert_code( "error": f"Failed to copy source files to destination: {e}", }) + # Handle two result formats: + # - Merge path (directory): keys are "model" and optionally "utils" + # - Single-file / legacy path: keys are file paths + is_merge_result = "model" in results written_files = [] mapping_log = [] - for file_path, code in results.items(): - if is_dir: - relative_path = pathlib.Path(file_path).relative_to(p) - else: - relative_path = pathlib.Path(file_path).name - output_path = dest_path / relative_path - _write_artifact(output_path, code) - written_files.append(output_path) + + if is_merge_result: + # Write model output + model_output = dest_path / "model_jax.py" + _write_artifact(model_output, results["model"]) + written_files.append(model_output) mapping_log.append({ - "source_file": file_path, - "generated_file": str(output_path), + "source_file": abs_path, + "generated_file": str(model_output), + "component": "model", "status": "success", }) + # Write utils output (if present) + if "utils" in results: + utils_output = dest_path / "utils_jax.py" + _write_artifact(utils_output, results["utils"]) + written_files.append(utils_output) + mapping_log.append({ + "source_file": abs_path, + "generated_file": str(utils_output), + 
"component": "utils", + "status": "success", + }) + else: + for file_path, code in results.items(): + if is_dir: + relative_path = pathlib.Path(file_path).relative_to(p) + else: + relative_path = pathlib.Path(file_path).name + output_path = dest_path / relative_path + _write_artifact(output_path, code) + written_files.append(output_path) + mapping_log.append({ + "source_file": file_path, + "generated_file": str(output_path), + "status": "success", + }) # Create __init__.py files for all directories containing migrated files. dirs_in_results = set(f.parent for f in written_files) @@ -140,6 +170,63 @@ def convert_code( "mapping_path": str(mapping_path), "original_source_dir": str(source_copy_dir), } + + # Write validation results if validation was enabled and produced results + validation_results = agent.get_validation_results() + if validate and validation_results: + validation_path = dest_path / "validation_results.json" + with validation_path.open("w", encoding="utf-8") as f: + json.dump(validation_results, f, indent=2) + response["validation_path"] = str(validation_path) + + # Auto-verify converted files + try: + from agents.migration.verification_agent import VerificationAgent + verifier = VerificationAgent() + scorecard = {} + + if is_merge_result: + # Use cached merge result from PrimaryAgent to avoid re-running merge + cached_merge = agent.get_merge_result() + if cached_merge: + source_code_map = {"model": cached_merge.model_code} + if cached_merge.utility_code: + source_code_map["utils"] = cached_merge.utility_code + else: + with open(abs_path, "r", encoding="utf-8", errors="replace") as f: + source_code_map = {"model": f.read()} + + for component, jax_code in results.items(): + if component in source_code_map: + vr = verifier.verify(source_code_map[component], jax_code) + scorecard[component] = { + "completeness": vr.completeness, + "overall": vr.overall, + } + else: + for file_path, jax_code in results.items(): + try: + with open(file_path, "r", 
encoding="utf-8", errors="replace") as f: + src = f.read() + vr = verifier.verify(src, jax_code) + scorecard[file_path] = { + "completeness": vr.completeness, + "overall": vr.overall, + } + except OSError: + pass + + if scorecard: + scorecard_path = dest_path / "verification_scorecard.json" + with scorecard_path.open("w", encoding="utf-8") as f: + json.dump(scorecard, f, indent=2) + response["verification_scorecard_path"] = str(scorecard_path) + response["verification_summary"] = { + k: v["overall"] for k, v in scorecard.items() + } + except Exception as e: + logging.warning("Auto-verification failed: %s", e) + return json.dumps(response) diff --git a/MaxCode/tools/verification_tool.py b/MaxCode/tools/verification_tool.py new file mode 100644 index 0000000..1ad29f1 --- /dev/null +++ b/MaxCode/tools/verification_tool.py @@ -0,0 +1,69 @@ +"""Verification tool for ADK — scores PyTorch-to-JAX conversion quality.""" + +import json +import logging + +from agents.migration.verification_agent import VerificationAgent +from google.adk.tools.function_tool import FunctionTool + + +def verify_conversion( + source_path: str, + output_path: str, + api_key: str = "", +) -> str: + """Verify quality of a PyTorch-to-JAX conversion. + + Computes a completeness score (AST-based) and optionally a correctness + score (LLM-based, requires api_key). Returns JSON with both scores and + an overall score. + + Args: + source_path: Path to the original PyTorch source file. + output_path: Path to the converted JAX output file. + api_key: Optional Google AI API key for LLM-based correctness check. + + Returns: + A JSON string with completeness, correctness, and overall scores. 
+ """ + logging.info( + "verify_conversion called with source_path=%s, output_path=%s", + source_path, output_path, + ) + + try: + with open(source_path, "r", encoding="utf-8") as f: + source_code = f.read() + except OSError as e: + return json.dumps({"error": f"Cannot read source file: {e}"}) + + try: + with open(output_path, "r", encoding="utf-8") as f: + output_code = f.read() + except OSError as e: + return json.dumps({"error": f"Cannot read output file: {e}"}) + + verifier = VerificationAgent() + result = verifier.verify( + source_code, output_code, + api_key=api_key if api_key else None, + ) + + response = { + "source_path": source_path, + "output_path": output_path, + "completeness": result.completeness, + "overall": result.overall, + } + if result.correctness is not None: + response["correctness"] = { + "score": result.correctness["score"], + "deviation_count": result.correctness["deviation_count"], + "by_category": result.correctness["by_category"], + "by_severity": result.correctness["by_severity"], + } + + return json.dumps(response) + + +verify_conversion_tool = FunctionTool(verify_conversion)