"""Agent for validating faithfulness of PyTorch-to-target conversions."""

import json
import re
from typing import Any

from agents import base
from agents import utils
from agents.migration.prompts import prompts


VALIDATION_PROMPT = """You are an expert code reviewer specializing in PyTorch-to-JAX
conversions. Your task is to compare the ORIGINAL PyTorch source code with the
CONVERTED JAX/Flax output and identify every FAITHFULNESS DEVIATION.

A faithfulness deviation is any place where the JAX output CHANGES the behavior,
defaults, structure, or semantics of the original PyTorch code. You should NOT
flag intentional JAX idiom changes (e.g., torch.Tensor -> jnp.ndarray,
nn.Module -> nn.Module with @nn.compact, self.training -> deterministic flag).

## Original PyTorch Source:
```python
{pytorch_code}
```

## Converted JAX Output:
```python
{target_code}
```

## Check each of the following categories:

### 1. Default Values
Compare every constructor parameter default in the source vs the output.
Flag any changed numeric value (e.g., capacity_factor=1.0 changed to 1.25).

### 2. Weight Initialization
For each nn.Linear/nn.Dense in the source:
- If the source uses bare `nn.Linear(...)` with NO explicit init call
  (no nn.init.zeros_, nn.init.normal_, etc.), the JAX output should use
  the Flax default initializer (no kernel_init argument).
- If the source EXPLICITLY calls an init (e.g., nn.init.zeros_), the JAX
  output should use the matching Flax initializer.
Flag any case where an initializer was added or changed.

### 3. Missing Components
List every class, function, method, or constant in the source that has
NO equivalent in the JAX output. Include:
- Base classes that were merged into subclasses
- get_config() or serialization methods
- Utility functions (metrics, logging helpers, etc.)
- Utility classes (e.g., metrics aggregation classes)
- Lambda attributes or property methods

### 4. Reduction Operations
Flag any place where .mean() was changed to .sum() or vice versa,
or where a reduction axis was changed.

### 5. Method Placement
Flag any method/attribute that was moved from one class to another,
or converted from an instance method to a standalone function when
the source has it as a method.

### 6. Dropped Features
Flag any feature present in the source that was removed in the output
(e.g., TensorBoard logging, checkpoint saving, progress bars, etc.)

## IMPORTANT: Use Exact Code Snippets
When reporting deviations, copy the relevant lines VERBATIM from the code
above. Do NOT paraphrase or describe the code in English. Use the actual
source and output lines so that a repair tool can find-and-replace them.

## Output Format

Return a JSON array of deviations. Each deviation must have:
- "category": one of "default_value", "initialization", "missing_component",
  "reduction_op", "method_placement", "dropped_feature"
- "severity": "high" (changes model output), "medium" (changes training behavior),
  or "low" (cosmetic or minor)
- "source_snippet": copy the exact line(s) verbatim from the PyTorch source
  (max 3 lines). For missing components, show the class/function signature.
- "output_snippet": copy the exact line(s) verbatim from the JAX output
  (max 3 lines). Use "MISSING" if the component does not exist.
- "corrected_snippet": the exact replacement code that should replace
  output_snippet to fix the deviation. Use "ADD" for missing components
  (and put the new code in the fix field).
- "fix": specific instruction for how to fix the deviation

If there are NO deviations, return an empty array: []

Return ONLY the JSON array, no markdown formatting, no explanation.
"""


REPAIR_PROMPT = """You are an expert JAX/Flax developer. You have been given a
JAX/Flax code file that was converted from PyTorch, along with a list of
faithfulness deviations that need to be fixed.

## Original PyTorch Source (for reference):
```python
{pytorch_code}
```

## Current JAX Code:
```python
{target_code}
```
{rag_section}
## Deviations to Fix:
{deviations_text}

## CRITICAL RULES:
1. For each deviation, find the EXACT output_snippet in the JAX code and
   replace it with the corrected_snippet. If the snippets are not exact
   matches (whitespace differences, etc.), locate the closest match and
   apply the fix described in the instruction.
2. NEVER remove an existing class, function, method, or import -- even if it
   seems unused or redundant. If the current JAX code has a class (e.g.,
   MoETrainer, MoEMetrics), it MUST remain in the output.
3. NEVER convert a class into standalone functions or vice versa.
4. NEVER remove a training loop, epoch loop, or any training utility code.
5. If a deviation's instruction says the current behavior is acceptable,
   desirable, or "not recommended" to change, SKIP that deviation entirely.
6. Preserve ALL existing code structure -- only change what the deviation
   specifically asks you to change.
7. The output must have the SAME number of classes and functions (or more)
   as the input JAX code.

Return ONLY the complete fixed Python code. No markdown formatting, no
explanation, no ```python blocks.
"""


# Matches a fenced code block, optionally tagged `python`, non-greedy so only
# the first block is captured.
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python)?\n?(.*?)\n?```", re.DOTALL)


def _strip_markdown_formatting(text: str) -> str:
    """Strips markdown and returns only the first Python code block.

    Falls back to stripping a triple-double-quote wrapper (which some LLMs
    emit instead of backticks); otherwise returns the text unchanged.
    """
    code_block_match = _CODE_BLOCK_PATTERN.search(text)
    if code_block_match:
        return code_block_match.group(1).strip()
    # Strip triple-quote wrappers the LLM may use instead of backticks.
    stripped = text.strip()
    if stripped.startswith('"""') and stripped.endswith('"""'):
        return stripped[3:-3].strip()
    return text


def _parse_json_response(text: str) -> list:
    """Parse a JSON array from an LLM response, handling markdown wrapping.

    Always returns a list so callers can iterate over deviation dicts:
    a top-level JSON object is wrapped in a single-element list, and any
    other non-list value (or unparseable text) yields [].
    """
    text = text.strip()
    # Strip markdown code blocks if present.
    json_match = re.search(r"```(?:json)?\n?(.*?)\n?```", text, re.DOTALL)
    if json_match:
        text = json_match.group(1).strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Try to find a JSON array embedded in surrounding prose.
        array_match = re.search(r'\[.*\]', text, re.DOTALL)
        if array_match:
            try:
                parsed = json.loads(array_match.group(0))
            except json.JSONDecodeError:
                return []
        else:
            return []
    # Guarantee the declared list contract: the LLM occasionally returns a
    # single deviation object instead of an array, which would otherwise be
    # iterated as string keys by downstream code.
    if isinstance(parsed, dict):
        return [parsed]
    if not isinstance(parsed, list):
        return []
    return parsed


class ValidationAgent(base.Agent):
    """Agent for validating faithfulness of PyTorch-to-target conversions.

    This agent takes the original PyTorch source and the converted output,
    identifies faithfulness deviations (changed defaults, wrong init, missing
    components, altered semantics), and optionally repairs them. The prompt
    pair is selected based on the configured `target` ("jax" or "maxtext").
    """

    def __init__(self, model: Any, rag_agent_instance=None, target: str = "jax"):
        """Initializes the agent.

        Args:
            model: The LLM model to use for generation.
            rag_agent_instance: Optional RAGAgent for retrieving context
                during repair. If None, repair runs without RAG context.
            target: Conversion target ("jax" or "maxtext"). Selects which
                validation/repair prompt is used.
        """
        super().__init__(
            model=model,
            agent_domain=utils.AgentDomain.MIGRATION,
            agent_type=utils.AgentType.PRIMARY,
        )
        self._rag_agent = rag_agent_instance
        self._target = target

    def _validation_prompt(self) -> str:
        """Returns the validation prompt template for the active target."""
        prompt = prompts.get_prompt("VALIDATION_PROMPT", self._target)
        if prompt is None:
            # Fall back to the module-level default prompt.
            return VALIDATION_PROMPT
        return prompt

    def _repair_prompt(self) -> str:
        """Returns the repair prompt template for the active target."""
        prompt = prompts.get_prompt("REPAIR_PROMPT", self._target)
        if prompt is None:
            return REPAIR_PROMPT
        return prompt

    def validate(self, pytorch_code: str, target_code: str = None,
                 jax_code: str = None) -> list:
        """Validates the converted output against the PyTorch source.

        Args:
            pytorch_code: The original PyTorch source code.
            target_code: The converted code in the target framework. The
                deprecated `jax_code` keyword is accepted as an alias for
                one release.
            jax_code: Deprecated alias for `target_code`.

        Returns:
            A list of deviation dicts, each with category, severity,
            source_snippet, output_snippet, corrected_snippet, and fix fields.

        Raises:
            TypeError: If neither `target_code` nor `jax_code` is provided.
        """
        if target_code is None:
            target_code = jax_code
        if target_code is None:
            raise TypeError(
                "ValidationAgent.validate requires `target_code` (or the"
                " deprecated `jax_code`) argument."
            )
        response = self.generate(
            self._validation_prompt(),
            {"pytorch_code": pytorch_code, "target_code": target_code},
        )
        return _parse_json_response(response)

    @staticmethod
    def _filter_actionable(deviations: list) -> list:
        """Filter out deviations that explicitly say not to fix."""
        skip_phrases = [
            "not recommended",
            "desirable deviation",
            "correct and desirable",
            "overly complex",
            "acceptable deviation",
        ]
        actionable = []
        for d in deviations:
            fix_text = d.get("fix", "").lower()
            if any(phrase in fix_text for phrase in skip_phrases):
                continue
            actionable.append(d)
        return actionable

    @staticmethod
    def _format_deviations_for_repair(deviations: list) -> str:
        """Formats deviations as numbered find/replace blocks for repair.

        Falls back to old source_line/output_line fields if the new
        source_snippet/output_snippet fields are absent.

        Args:
            deviations: List of deviation dicts from validate().

        Returns:
            A formatted string with numbered find/replace blocks.
        """
        blocks = []
        for i, d in enumerate(deviations, 1):
            severity = d.get("severity", "medium")
            category = d.get("category", "unknown")
            source = d.get("source_snippet", d.get("source_line", "N/A"))
            output = d.get("output_snippet", d.get("output_line", "N/A"))
            corrected = d.get("corrected_snippet", "")
            fix = d.get("fix", "")

            block = f"### Deviation {i} [{severity}] - {category}\n"
            block += f"Source (PyTorch): {source}\n"
            block += f"Find in output: {output}\n"
            if output == "MISSING":
                # Nothing to find-and-replace; give the repairer the source
                # that must be converted and added.
                block += f"Source to convert: {source}\n"
            if corrected and corrected not in ("ADD", "MISSING"):
                block += f"Replace with: {corrected}\n"
            block += f"Instruction: {fix}"
            blocks.append(block)
        return "\n\n".join(blocks)

    def _get_repair_rag_context(self, deviations: list) -> str:
        """Retrieves RAG context relevant to the repair deviations.

        Builds a query from deviation categories and fix text, retrieves
        top-k documents, and returns a formatted string for the prompt.

        Args:
            deviations: List of deviation dicts from validate().

        Returns:
            A formatted RAG context string, or "" if no RAG agent.
        """
        if not self._rag_agent:
            return ""

        # Build query from deviation categories and fix descriptions
        query_parts = []
        for d in deviations:
            category = d.get("category", "")
            fix = d.get("fix", "")
            if category:
                query_parts.append(category.replace("_", " "))
            if fix:
                query_parts.append(fix)
        query = " ".join(query_parts)
        if not query.strip():
            return ""

        try:
            docs = self._rag_agent.retrieve_context(query, top_k=15)
        except Exception:
            # RAG is best-effort: repair proceeds without context on failure.
            return ""

        if not docs:
            return ""

        section = "\n## Reference Patterns (from RAG):\n"
        for doc in docs:
            name = doc.get("name", "unknown")
            text = doc.get("text", "")
            section += f"\n### {name}\n{text}\n"
        return section

    def repair(self, target_code: str = None, deviations: list = None,
               pytorch_code: str = "", jax_code: str = None) -> str:
        """Repairs the converted code based on identified deviations.

        Args:
            target_code: The converted code (JAX or MaxText) to repair. The
                deprecated `jax_code` keyword is accepted for one release.
            deviations: List of deviation dicts from validate().
            pytorch_code: The original PyTorch source for reference.
            jax_code: Deprecated alias for `target_code`.

        Returns:
            The repaired code in the target framework.

        Raises:
            TypeError: If neither `target_code` nor `jax_code` is provided.
        """
        if target_code is None:
            target_code = jax_code
        if target_code is None:
            raise TypeError(
                "ValidationAgent.repair requires `target_code` (or the"
                " deprecated `jax_code`) argument."
            )
        if deviations is None:
            deviations = []

        # Filter to only actionable deviations
        actionable = self._filter_actionable(deviations)
        if not actionable:
            return target_code

        deviations_text = self._format_deviations_for_repair(actionable)
        rag_section = self._get_repair_rag_context(actionable)
        response = self.generate(
            self._repair_prompt(),
            {
                "target_code": target_code,
                "deviations_text": deviations_text,
                "rag_section": rag_section,
                "pytorch_code": pytorch_code,
            },
        )
        repaired = _strip_markdown_formatting(response)
        # If the repair returned empty or very short, return original
        if len(repaired) < len(target_code) * 0.5:
            return target_code
        return repaired

    def run(self, pytorch_code: str, target_code: str = None,
            jax_code: str = None) -> tuple:
        """Validates and optionally repairs the conversion.

        Args:
            pytorch_code: The original PyTorch source code.
            target_code: The converted code (JAX or MaxText).
            jax_code: Deprecated alias for `target_code`.

        Returns:
            Tuple of (repaired_code, deviations_list).
        """
        if target_code is None:
            target_code = jax_code
        deviations = self.validate(pytorch_code, target_code=target_code)
        if deviations:
            repaired_code = self.repair(
                target_code=target_code, deviations=deviations,
                pytorch_code=pytorch_code,
            )
            return repaired_code, deviations
        return target_code, []
+""" + +import ast +import re +from dataclasses import dataclass, field +from fnmatch import fnmatchcase + + +@dataclass +class VerificationResult: + """Result of verifying a conversion.""" + completeness: dict = field(default_factory=dict) # score, total, found, classes, methods, functions + correctness: dict | None = None # score, deviations, by_category, by_severity (None if no api_key) + overall: float = 0.0 + + +# Standard PyTorch -> JAX/Flax method renames. +METHOD_RENAMES = { + "__init__": {"setup", "__call__"}, + "forward": {"__call__"}, +} + +# Methods always inlined during conversion. +ALWAYS_INLINED = { + "reset_parameters", +} + +# Severity weights for correctness scoring. +SEVERITY_WEIGHTS = {"high": 5, "medium": 3, "low": 1} + +# Known false-positive (category, severity) pairs. +FALSE_POSITIVE_RULES = { + ("method_placement", "low"), + ("missing_component", "low"), + ("dropped_feature", "low"), +} + +# MaxText submodule → source component patterns delegated to built-ins. +_MAXTEXT_DELEGATION_MAP = { + "attentions": { + "classes": ["*Attention*", "*RotaryEmbedding*"], + "functions": ["rotate_half", "apply_rotary_pos_emb", "repeat_kv", + "*_attention_forward", "l2norm"], + }, + "normalizations": { + "classes": ["*RMSNorm*", "*LayerNorm*", "*GroupNorm*"], + "functions": [], + }, + "linears": { + "classes": ["*MLP*", "*FeedForward*", "*FFN*"], + "functions": [], + }, + "embeddings": { + "classes": ["*Embedding*"], + "functions": [], + }, + "moe": { + "classes": ["*Expert*", "*Router*", "*MoE*"], + "functions": ["load_balancing_loss_func", "*_balancing_loss*"], + }, + "decoders": { + "classes": ["*Model"], + "functions": [], + }, +} + +# Infrastructure classes always excluded for MaxText targets. 
+_INFRASTRUCTURE_PATTERNS = [ + "*PreTrainedModel", "*ForCausalLM", "*ForSequenceClassification", + "*ForTokenClassification", "*ForQuestionAnswering", + "*ForMultipleChoice", "*ForMaskedLM", "*ForConditionalGeneration", + "*Model", +] + +# PyTorch-specific function patterns always excluded for MaxText targets. +_PYTORCH_FUNC_PATTERNS = ["torch_*"] + + +class VerificationAgent: + """Scores the quality of a PyTorch-to-JAX conversion. + + The completeness check is pure AST (no LLM). The correctness check + delegates to ValidationAgent for deviation detection and applies + weighted scoring. + """ + + def __init__(self, model=None, target: str = "jax"): + """Initialize the verification agent. + + Args: + model: Optional LLM model instance for correctness checks. + If None, correctness scoring is skipped. + target: Conversion target ("jax" or "maxtext"). Threaded + through to the inner ValidationAgent so the correctness + check uses the right validation prompt. + """ + self._model = model + self._target = target + + @staticmethod + def filter_maxtext_delegated(src_components, output_code): + """Remove source components delegated to MaxText built-in primitives. + + Parses the output code for ``from maxtext.layers.`` and + ``import maxtext.layers.`` statements, then uses + ``_MAXTEXT_DELEGATION_MAP`` to identify which source classes and + functions are handled by those built-ins. Infrastructure classes + and PyTorch-specific functions are always excluded. + + Args: + src_components: dict with "classes" (name -> [methods]) and + "functions" (list) as returned by ``extract_components``. + output_code: The generated MaxText code string. + + Returns: + (filtered_components, delegated_info) where + *filtered_components* is a copy of *src_components* with delegated + entries removed, and *delegated_info* is a dict with keys + "classes", "functions", and "count". + """ + # 1. Detect which maxtext.layers submodules are imported. 
+ imported_subs = set(re.findall( + r"(?:from\s+maxtext\.layers\s+import\s+|" + r"from\s+maxtext\.layers\.)" + r"(\w+)", + output_code, + )) + # Also catch `import maxtext.layers.` form. + imported_subs |= set(re.findall( + r"import\s+maxtext\.layers\.(\w+)", + output_code, + )) + + # 2. Collect glob patterns for delegated classes/functions. + class_patterns = list(_INFRASTRUCTURE_PATTERNS) + func_patterns = list(_PYTORCH_FUNC_PATTERNS) + for sub in imported_subs: + entry = _MAXTEXT_DELEGATION_MAP.get(sub) + if entry: + class_patterns.extend(entry["classes"]) + func_patterns.extend(entry["functions"]) + + def _matches(name, patterns): + return any(fnmatchcase(name, pat) for pat in patterns) + + # 3. Partition classes. + kept_classes = {} + delegated_classes = [] + delegated_method_count = 0 + for cls_name, methods in src_components["classes"].items(): + if _matches(cls_name, class_patterns): + delegated_classes.append(cls_name) + delegated_method_count += len(methods) + else: + kept_classes[cls_name] = methods + + # 4. Partition functions. + kept_funcs = [] + delegated_funcs = [] + for fn in src_components["functions"]: + if _matches(fn, func_patterns): + delegated_funcs.append(fn) + else: + kept_funcs.append(fn) + + filtered = { + "classes": kept_classes, + "functions": kept_funcs, + } + delegated_info = { + "classes": delegated_classes, + "functions": delegated_funcs, + "count": len(delegated_classes) + delegated_method_count + len(delegated_funcs), + } + return filtered, delegated_info + + @staticmethod + def extract_components(code): + """Parse Python code and return its classes, methods, and functions. + + Args: + code: Python source code string. + + Returns: + dict with keys "classes" (name -> [methods]) and "functions" (list). 
+ """ + tree = ast.parse(code) + classes = {} + functions = [] + + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + methods = [ + n.name + for n in ast.iter_child_nodes(node) + if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) + ] + classes[node.name] = methods + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + functions.append(node.name) + + return {"classes": classes, "functions": functions} + + @staticmethod + def compute_completeness(source_components, output_components): + """Compare source and output components and return a completeness report. + + Returns: + dict with score, total, found, classes, methods, functions breakdown. + """ + src_classes = source_components["classes"] + out_classes = output_components["classes"] + + src_class_names = set(src_classes.keys()) + out_class_names = set(out_classes.keys()) + matched_classes = src_class_names & out_class_names + missing_classes = sorted(src_class_names - out_class_names) + + total_methods = 0 + found_methods = 0 + missing_methods = [] + + for cls in src_classes: + src_methods = set(src_classes[cls]) + total_methods += len(src_methods) + if cls in out_classes: + out_methods = set(out_classes[cls]) + has_call = "__call__" in out_methods + for m in sorted(src_methods): + if m in out_methods: + found_methods += 1 + elif m in METHOD_RENAMES and METHOD_RENAMES[m] & out_methods: + found_methods += 1 + elif m in ALWAYS_INLINED: + found_methods += 1 + elif has_call and m not in ("__init__", "forward"): + found_methods += 1 + else: + missing_methods.append(f"{cls}.{m}") + else: + for m in sorted(src_methods): + missing_methods.append(f"{cls}.{m}") + + src_funcs = set(source_components["functions"]) + out_funcs = set(output_components["functions"]) + matched_funcs = src_funcs & out_funcs + for f in src_funcs - matched_funcs: + if f in out_class_names: + matched_funcs = matched_funcs | {f} + missing_funcs = sorted(src_funcs - matched_funcs) + + total = 
len(src_class_names) + total_methods + len(src_funcs) + found = len(matched_classes) + found_methods + len(matched_funcs) + score = (found / total * 100) if total > 0 else 100.0 + + return { + "score": round(score, 1), + "total": total, + "found": found, + "classes": { + "total": len(src_class_names), + "found": len(matched_classes), + "missing": missing_classes, + }, + "methods": { + "total": total_methods, + "found": found_methods, + "missing": missing_methods, + }, + "functions": { + "total": len(src_funcs), + "found": len(matched_funcs), + "missing": missing_funcs, + }, + } + + @staticmethod + def compute_correctness(source_code, output_code, api_key, + total_components=0, model=None, target: str = "jax"): + """Run ValidationAgent and score the output. + + Args: + source_code: The PyTorch source code. + output_code: The converted output code (JAX or MaxText). + api_key: Google API key for the LLM. + total_components: Number of source components for budget scaling. + model: Optional pre-configured LLM model instance. If None, + creates a new GeminiTool with the given api_key. + target: Conversion target ("jax" or "maxtext"). Selects which + validation prompt the inner ValidationAgent uses. + + Returns: + dict with score, deviation_count, deviations, filtered_deviations, + by_category, by_severity. 
+ """ + import models + from agents.migration.validation_agent import ValidationAgent + + if model is None: + model = models.GeminiTool( + model_name=models.GeminiModel.GEMINI_3_1_PRO_PREVIEW, + api_key=api_key, + ) + validator = ValidationAgent(model=model, target=target) + all_deviations = validator.validate(source_code, output_code) + + if not isinstance(all_deviations, list): + all_deviations = [] + + real = [] + filtered = [] + for d in all_deviations: + sev = d.get("severity", "low").lower() + cat = d.get("category", "unknown") + if (cat, sev) in FALSE_POSITIVE_RULES: + filtered.append(d) + else: + real.append(d) + + by_severity = {} + by_category = {} + penalty = 0 + + for d in real: + sev = d.get("severity", "low").lower() + cat = d.get("category", "unknown") + by_severity[sev] = by_severity.get(sev, 0) + 1 + by_category[cat] = by_category.get(cat, 0) + 1 + penalty += SEVERITY_WEIGHTS.get(sev, 1) + + if total_components <= 0: + try: + tree = ast.parse(source_code) + total_components = sum( + 1 for n in ast.iter_child_nodes(tree) + if isinstance(n, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)) + ) + except SyntaxError: + total_components = 0 + + budget = total_components * SEVERITY_WEIGHTS["medium"] + if budget > 0: + score = max(0.0, (1.0 - penalty / budget) * 100.0) + else: + score = 100.0 if penalty == 0 else 0.0 + + return { + "score": round(score, 1), + "deviation_count": len(real), + "deviations": real, + "filtered_deviations": filtered, + "by_category": by_category, + "by_severity": by_severity, + } + + def verify(self, source_code, output_code, api_key=None): + """Run full verification (completeness + optional correctness). + + Args: + source_code: The PyTorch source code string. + output_code: The converted JAX output code string. + api_key: Optional Google API key. If provided (or if self._model + is set), runs correctness check. + + Returns: + VerificationResult with completeness, correctness, and overall score. 
+ """ + src_components = self.extract_components(source_code) + out_components = self.extract_components(output_code) + + delegated = None + if self._target == "maxtext": + src_components, delegated = self.filter_maxtext_delegated( + src_components, output_code, + ) + + completeness = self.compute_completeness(src_components, out_components) + + if delegated: + completeness["delegated"] = delegated + + correctness = None + if api_key or self._model: + correctness = self.compute_correctness( + source_code, output_code, + api_key=api_key, + total_components=completeness["total"], + model=self._model, + target=self._target, + ) + + if correctness is not None: + overall = round((completeness["score"] + correctness["score"]) / 2, 1) + else: + overall = completeness["score"] + + return VerificationResult( + completeness=completeness, + correctness=correctness, + overall=overall, + ) diff --git a/MaxCode/examples/demo/step5_verify.py b/MaxCode/examples/demo/step5_verify.py new file mode 100644 index 0000000..35174c6 --- /dev/null +++ b/MaxCode/examples/demo/step5_verify.py @@ -0,0 +1,231 @@ +""" +Step 5: Verify the quality of a PyTorch-to-JAX conversion. + +This script produces a scorecard with two metrics: + + Completeness (AST-based, no LLM) + Parses both files and compares classes, methods, and standalone + functions by name. Score = matched / total source components. + + Correctness (LLM-based, requires GOOGLE_API_KEY) + Runs the ValidationAgent to detect deviations between the PyTorch + source and JAX output. Score = 100 minus weighted penalties + (high=5, medium=3, low=1 per deviation). 
"""
Step 5: Verify the quality of a PyTorch-to-JAX conversion.

This script produces a scorecard with two metrics:

  Completeness (AST-based, no LLM)
    Parses both files and compares classes, methods, and standalone
    functions by name. Score = matched / total source components.

  Correctness (LLM-based, requires GOOGLE_API_KEY)
    Runs the ValidationAgent to detect deviations between the PyTorch
    source and JAX output. Score = 100 minus weighted penalties
    (high=5, medium=3, low=1 per deviation).

Requires:
  - Step 3 completed (merged model file created)
  - Step 4 completed (JAX output file created)
  - Optionally GOOGLE_API_KEY for the correctness check

Usage:
    python step5_verify.py
"""

import json
import os
import sys

from config import MERGED_FILE, MERGED_UTILS_FILE, OUTPUT_DIR, REPO_URL, setup

# Add MaxCode to sys.path so agent imports work
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from agents.migration.verification_agent import VerificationAgent


# ------------------------------------------------------------------
# Scorecard display
# ------------------------------------------------------------------

def print_scorecard(completeness, correctness=None):
    """Print a formatted verification scorecard.

    Args:
        completeness: Completeness report dict from VerificationAgent.
        correctness: Optional correctness report dict; None when the
            LLM-based check was skipped.

    Returns:
        The overall score (mean of both scores, or completeness alone).
    """
    print()
    print("=" * 50)
    print("  Conversion Verification Scorecard")
    print("=" * 50)

    c = completeness
    print()
    print(f"  Completeness: {c['score']:.1f}% "
          f"({c['found']}/{c['total']} components)")
    print(f"    Classes:   {c['classes']['found']}/{c['classes']['total']}", end="")
    if c["classes"]["missing"]:
        print(f" (missing: {', '.join(c['classes']['missing'])})", end="")
    print()

    print(f"    Methods:   {c['methods']['found']}/{c['methods']['total']}", end="")
    if c["methods"]["missing"]:
        # Cap the listing at five entries to keep the scorecard compact.
        shown = c["methods"]["missing"][:5]
        extra = len(c["methods"]["missing"]) - len(shown)
        print(f" (missing: {', '.join(shown)}", end="")
        if extra > 0:
            print(f" +{extra} more", end="")
        print(")", end="")
    print()

    print(f"    Functions: {c['functions']['found']}/{c['functions']['total']}", end="")
    if c["functions"]["missing"]:
        print(f" (missing: {', '.join(c['functions']['missing'])})", end="")
    print()

    if c.get("delegated"):
        d = c["delegated"]
        print(f"    Delegated: {d['count']} components handled by MaxText built-ins")

    if correctness is not None:
        cr = correctness
        n_dev = cr["deviation_count"]
        n_filt = len(cr.get("filtered_deviations", []))
        print()
        print(f"  Correctness: {cr['score']:.1f}% "
              f"({n_dev} deviation{'s' if n_dev != 1 else ''} found"
              f"{f', {n_filt} filtered' if n_filt else ''})")
        for sev in ("high", "medium", "low"):
            count = cr["by_severity"].get(sev, 0)
            if count:
                cats = [
                    d.get("category", "unknown")
                    for d in cr["deviations"]
                    if d.get("severity", "").lower() == sev
                ]
                cat_str = ", ".join(sorted(set(cats)))
                print(f"    {sev:8s} {count} ({cat_str})")
    else:
        print()
        print("  Correctness: skipped (GOOGLE_API_KEY not set)")

    if correctness is not None:
        overall = round((completeness["score"] + correctness["score"]) / 2, 1)
    else:
        overall = completeness["score"]
    print()
    print(f"  Overall: {overall:.1f}%")
    print()
    print("=" * 50)

    return overall


# ------------------------------------------------------------------
# Main
# ------------------------------------------------------------------

def _find_jax_output():
    """Return the path to the JAX output file inside OUTPUT_DIR.

    Prefers the conventional `<repo>_jax.py` name; falls back to the first
    `*_jax.py` file found. Returns None when nothing matches.
    """
    if not os.path.isdir(OUTPUT_DIR):
        return None
    repo_name = REPO_URL.rstrip("/").rsplit("/", 1)[-1].replace("-", "_")
    expected = f"{repo_name}_jax.py"
    expected_path = os.path.join(OUTPUT_DIR, expected)
    if os.path.isfile(expected_path):
        return expected_path
    for name in os.listdir(OUTPUT_DIR):
        if name.endswith("_jax.py"):
            return os.path.join(OUTPUT_DIR, name)
    return None


def main():
    setup()

    if not os.path.isfile(MERGED_FILE):
        print("ERROR: Merged model file not found. Run step3_merge.py first.")
        sys.exit(1)

    jax_path = _find_jax_output()
    if jax_path is None:
        print("ERROR: No JAX output file found in output/. Run step4_convert.py first.")
        sys.exit(1)

    print("=" * 50)
    print("  Step 5: Verify Conversion Quality")
    print("=" * 50)
    print(f"  Source: {MERGED_FILE}")
    print(f"  Output: {jax_path}")

    # Read source and output
    with open(MERGED_FILE, "r", encoding="utf-8") as f:
        source_code = f.read()
    with open(jax_path, "r", encoding="utf-8") as f:
        output_code = f.read()

    # Run verification
    api_key = os.environ.get("GOOGLE_API_KEY")
    verifier = VerificationAgent()

    if api_key:
        print("\n  Running verification (completeness + correctness)...")
    else:
        print("\n  GOOGLE_API_KEY not set -- running completeness check only.")

    result = verifier.verify(source_code, output_code, api_key=api_key)
    overall = print_scorecard(result.completeness, result.correctness)

    # -- Utility file verification --
    utils_completeness = None
    repo_name = REPO_URL.rstrip("/").rsplit("/", 1)[-1].replace("-", "_")
    utils_jax_path = os.path.join(OUTPUT_DIR, f"{repo_name}_utils_jax.py")

    if os.path.isfile(MERGED_UTILS_FILE) and os.path.isfile(utils_jax_path):
        print()
        print("-" * 50)
        print("  Utility File Verification")
        print("-" * 50)
        print(f"  Source: {MERGED_UTILS_FILE}")
        print(f"  Output: {utils_jax_path}")

        with open(MERGED_UTILS_FILE, "r", encoding="utf-8") as f:
            utils_source = f.read()
        with open(utils_jax_path, "r", encoding="utf-8") as f:
            utils_output = f.read()

        # Completeness only: no api_key is passed for the utility file.
        utils_result = verifier.verify(utils_source, utils_output)
        utils_completeness = utils_result.completeness

        u = utils_completeness
        print(f"\n  Utility Completeness: {u['score']:.1f}% "
              f"({u['found']}/{u['total']} components)")
        print(f"    Classes:   {u['classes']['found']}/{u['classes']['total']}", end="")
        if u["classes"]["missing"]:
            print(f" (missing: {', '.join(u['classes']['missing'])})", end="")
        print()
        print(f"    Functions: {u['functions']['found']}/{u['functions']['total']}", end="")
        if u["functions"]["missing"]:
            shown = u["functions"]["missing"][:5]
            extra = len(u["functions"]["missing"]) - len(shown)
            print(f" (missing: {', '.join(shown)}", end="")
            if extra > 0:
                print(f" +{extra} more", end="")
            print(")", end="")
        print()
    elif os.path.isfile(MERGED_UTILS_FILE):
        print("\n  Utility JAX output not found -- skipping utility verification.")

    # -- Save JSON --
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    json_result = {
        "source_file": MERGED_FILE,
        "output_file": jax_path,
        "completeness": result.completeness,
        "overall": overall,
    }
    if result.correctness is not None:
        json_result["correctness"] = {
            "score": result.correctness["score"],
            "deviation_count": result.correctness["deviation_count"],
            "by_category": result.correctness["by_category"],
            "by_severity": result.correctness["by_severity"],
            "deviations": result.correctness["deviations"],
            "filtered_deviations": result.correctness.get("filtered_deviations", []),
        }
    if utils_completeness is not None:
        json_result["utils_completeness"] = utils_completeness

    json_path = os.path.join(OUTPUT_DIR, "verification_scorecard.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(json_result, f, indent=2)
    print(f"  Results saved to {json_path}")


if __name__ == "__main__":
    main()
"""Verification tool for ADK -- scores PyTorch-to-JAX conversion quality."""

import json
import logging

from agents.migration.verification_agent import VerificationAgent
from google.adk.tools.function_tool import FunctionTool


def verify_conversion(
    source_path: str,
    output_path: str,
    api_key: str = "",
) -> str:
    """Verify quality of a PyTorch-to-JAX conversion.

    Computes a completeness score (AST-based) and optionally a correctness
    score (LLM-based, requires api_key). Returns JSON with both scores and
    an overall score.

    Args:
        source_path: Path to the original PyTorch source file.
        output_path: Path to the converted JAX output file.
        api_key: Optional Google AI API key for LLM-based correctness check.

    Returns:
        A JSON string with completeness, correctness, and overall scores,
        or a JSON object with an "error" key when a file cannot be read.
    """
    logging.info(
        "verify_conversion called with source_path=%s, output_path=%s",
        source_path, output_path,
    )

    try:
        with open(source_path, "r", encoding="utf-8") as f:
            source_code = f.read()
    except OSError as e:
        return json.dumps({"error": f"Cannot read source file: {e}"})

    try:
        with open(output_path, "r", encoding="utf-8") as f:
            output_code = f.read()
    except OSError as e:
        return json.dumps({"error": f"Cannot read output file: {e}"})

    verifier = VerificationAgent()
    result = verifier.verify(
        source_code, output_code,
        api_key=api_key if api_key else None,
    )

    response = {
        "source_path": source_path,
        "output_path": output_path,
        "completeness": result.completeness,
        "overall": result.overall,
    }
    if result.correctness is not None:
        response["correctness"] = {
            "score": result.correctness["score"],
            "deviation_count": result.correctness["deviation_count"],
            "by_category": result.correctness["by_category"],
            "by_severity": result.correctness["by_severity"],
        }

    return json.dumps(response)


verify_conversion_tool = FunctionTool(verify_conversion)