-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmaxtext_conversion_agent.py
More file actions
366 lines (315 loc) · 12.3 KB
/
maxtext_conversion_agent.py
File metadata and controls
366 lines (315 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
"""Agent for converting a PyTorch model into a MaxText artifact set.
The MaxText conversion is staged across several LLM calls so each one stays
focused: classify the source, emit a YAML config overlay, optionally emit a
custom layers `.py` file, and best-effort emit a checkpoint converter. The
canonical MaxText deliverable is a YAML overlay; the layers and converter
artifacts are produced only when the source warrants them.
"""
from __future__ import annotations
import ast
import json
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from agents import base
from agents import utils
from agents.migration.model_conversion_agent import _strip_markdown_formatting
from agents.migration.prompts import prompts
from rag import rag_agent
logger = logging.getLogger(__name__)
# Canonical MaxText decoder-block families. The classifier prompt asks the LLM
# to choose one of these (or "custom" when none fits), and
# `_normalize_decoder_block` snaps free-form LLM output back onto this tuple.
# Order is preserved deliberately: the normalizer scans it front to back.
_KNOWN_DECODER_BLOCKS = (
    # Llama
    "llama2",
    "llama3",
    "llama4",
    # Gemma
    "gemma",
    "gemma2",
    "gemma3",
    # Mistral
    "mistral",
    "mixtral",
    # Qwen
    "qwen3",
    "qwen3_next",
    # DeepSeek
    "deepseek2",
    "deepseek3",
    # Others
    "gpt_oss",
    "kimi",
    # Fallbacks
    "default",
    "custom",
)

# Families that the classifier may pick but that have no built-in MaxText JAX
# implementation; `run` always emits a layers file for these.
_FORCE_LAYERS_BLOCKS = frozenset(("qwen3_next",))
@dataclass
class MaxTextArtifacts:
    """Where the MaxText artifacts ended up on disk, plus the decoder block.

    This agent never fills in the path fields itself. It returns artifact
    bodies inside a `MaxTextRunResult`. The persistence layer in
    `interface/api.py` writes those bodies out and records the resulting
    paths here.
    """

    # Location of the YAML config overlay (the one non-optional artifact path).
    config_yaml_path: str
    # Location of the generated custom layers module; None if none was emitted.
    layers_py_path: Optional[str] = None
    # Location of the generated checkpoint converter; None if none was emitted.
    ckpt_converter_path: Optional[str] = None
    # Canonical MaxText decoder block the source was classified as.
    decoder_block: str = "default"
@dataclass
class MaxTextRunResult:
    """Everything a `MaxTextConversionAgent.run` call produced, held in memory.

    Carries the raw text of each artifact together with the classifier's
    verdict. The persistence layer converts this into a `MaxTextArtifacts`
    record and a flat str -> str `results` mapping for the standard write path.
    """

    # Canonical decoder block chosen by the classifier stage.
    decoder_block: str
    # Classifier's free-text rationale for that choice.
    justification: str
    # Body of the YAML config overlay.
    config_yaml: str
    # Body of the custom layers module, when that stage ran and succeeded.
    layers_py: Optional[str] = None
    # Body of the checkpoint converter, when the best-effort stage succeeded.
    ckpt_converter_py: Optional[str] = None
    # Stem used when naming the output files.
    model_name: str = "model"
def _strip_yaml_formatting(text: str) -> str:
    """Strips markdown code fences from a YAML response.

    Returns the body of the first fenced block in `text`. The fence's info
    string is dropped, whatever it is (``yaml``, ``YAML``, ``yml``, or any
    other language tag), instead of leaking into the YAML as a first line.
    If the response opens a fence that is never closed (e.g. a truncated
    generation), the opening fence line and any trailing ``` are removed.
    Unfenced text is returned with surrounding whitespace stripped.

    Args:
      text: Raw LLM response.

    Returns:
      The YAML body with surrounding whitespace removed.
    """
    # After the opening fence, optionally consume either an info string that
    # is followed by a newline, or just the newline. Requiring the newline
    # after a tag keeps single-line bodies like "```key: v```" intact; `\r?`
    # tolerates CRLF line endings.
    match = re.search(
        r"```(?:[ \t]*[A-Za-z][\w+.-]*[ \t]*\r?\n|\r?\n)?(.*?)\n?```",
        text,
        re.DOTALL,
    )
    if match:
        return match.group(1).strip()
    stripped = text.strip()
    if stripped.startswith("```"):
        # Unterminated fence: drop the opening line (with its info string)
        # and a dangling closing fence, if present.
        first_nl = stripped.find("\n")
        if first_nl != -1:
            stripped = stripped[first_nl + 1:]
        if stripped.endswith("```"):
            stripped = stripped[:-3]
        return stripped.strip()
    return stripped
def _extract_dim_hints(pytorch_code: str) -> Dict[str, Any]:
    """Best-effort scan for common model-dimension attributes in PyTorch source.

    Recognised forms, anywhere in the parsed module:
      * `self.<name> = <literal>`, where the literal is a constant or a
        negated numeric constant (e.g. `-1`);
      * `self.<name> = <param>`, where `<param>` is a parameter of the
        innermost enclosing function that has a literal default. This is the
        Hugging Face config idiom
        `def __init__(self, hidden_size=4096, ...): self.hidden_size = hidden_size`;
      * dataclass-style annotated assignments such as `hidden_size: int = 2048`.

    Only names in a fixed allow-list are collected. When a name is seen more
    than once, the first sighting in `ast.walk` (breadth-first) order wins.
    The result is only a hint passed to the YAML prompt; the LLM is still
    expected to verify values against the source.

    Args:
      pytorch_code: Python source to scan.

    Returns:
      Mapping of attribute name to literal value. Empty if the source does not
      parse.
    """
    hints: Dict[str, Any] = {}
    try:
        tree = ast.parse(pytorch_code)
    except (SyntaxError, ValueError):
        # Python < 3.12 raises ValueError (not SyntaxError) for null bytes.
        return hints
    interesting = frozenset({
        "hidden_size", "num_attention_heads", "num_key_value_heads",
        "num_hidden_layers", "vocab_size", "intermediate_size",
        "head_dim", "max_position_embeddings", "rms_norm_eps",
        "rope_theta", "tie_word_embeddings", "num_experts",
        "num_experts_per_tok", "moe_intermediate_size",
        "router_aux_loss_coef", "n_routed_experts", "n_shared_experts",
        "first_k_dense_replace", "moe_layer_freq",
    })

    def literal(node):
        """Returns (True, value) for a scalar literal node, else (False, None)."""
        if isinstance(node, ast.Constant):
            return True, node.value
        if (
            isinstance(node, ast.UnaryOp)
            and isinstance(node.op, ast.USub)
            and isinstance(node.operand, ast.Constant)
            and isinstance(node.operand.value, (int, float))
            and not isinstance(node.operand.value, bool)
        ):
            return True, -node.operand.value
        return False, None

    def param_default(func, name):
        """Returns the default-value node of parameter `name` in `func`, or None."""
        args = func.args
        positional = list(getattr(args, "posonlyargs", [])) + list(args.args)
        # Positional defaults are aligned to the *last* positional parameters.
        first_defaulted = len(positional) - len(args.defaults)
        for idx, arg in enumerate(positional):
            if arg.arg == name:
                if idx >= first_defaulted:
                    return args.defaults[idx - first_defaulted]
                return None
        for arg, default in zip(args.kwonlyargs, args.kw_defaults):
            if arg.arg == name:
                return default  # None when the keyword-only arg has no default.
        return None

    # Map every node to its innermost enclosing function. ast.walk is
    # breadth-first, so an inner function is visited after its outer one and
    # overwrites the outer mapping for its own descendants.
    enclosing_func = {}
    for func in ast.walk(tree):
        if isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
            for child in ast.walk(func):
                enclosing_func[child] = func

    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            if len(node.targets) != 1:
                continue
            target = node.targets[0]
            if not (
                isinstance(target, ast.Attribute)
                and isinstance(target.value, ast.Name)
                and target.value.id == "self"
                and target.attr in interesting
            ):
                continue
            found, value = literal(node.value)
            if not found and isinstance(node.value, ast.Name):
                # NOTE: this reports the parameter's *default*; if the body
                # rebinds the parameter before this assignment, the hint may
                # be stale. It is only a hint.
                func = enclosing_func.get(node)
                default = (
                    param_default(func, node.value.id) if func is not None else None
                )
                if default is not None:
                    found, value = literal(default)
            if found:
                hints.setdefault(target.attr, value)
        elif isinstance(node, ast.AnnAssign):
            # Dataclass-style class-level annotations: `hidden_size: int = 2048`.
            if (
                isinstance(node.target, ast.Name)
                and node.target.id in interesting
                and node.value is not None
            ):
                found, value = literal(node.value)
                if found:
                    hints.setdefault(node.target.id, value)
    return hints
def _format_dim_hints(hints: Dict[str, Any]) -> str:
    """Renders extracted dimension hints as a key-sorted bullet list.

    When `hints` is empty, returns a placeholder sentence instead, so the
    prompt still tells the model to derive every value from the source.
    """
    if hints:
        lines = []
        for key in sorted(hints):
            lines.append(f"- {key}: {hints[key]}")
        return "\n".join(lines)
    return "(no hints extracted; derive everything from the source)"
def _normalize_decoder_block(value: str) -> str:
    """Snaps an LLM-emitted decoder block name onto `_KNOWN_DECODER_BLOCKS`.

    Matching is case-insensitive. Spaces, hyphens and underscores are treated
    as interchangeable, so "Llama-2", "llama_2", "Llama 2" and "Llama2" all
    map to the canonical "llama2", and "Qwen3 Next" maps to "qwen3_next".

    Args:
      value: Raw decoder-block string from the classifier.

    Returns:
      The canonical name. Returns "default" for an empty or blank input, and
      "custom" when nothing in the known set matches.
    """
    if not value or not value.strip():
        return "default"
    # Collapse every run of whitespace and/or hyphens into one underscore so
    # the direct lookup below also succeeds for spaced names like "qwen3 next".
    v = re.sub(r"[\s\-]+", "_", value.strip().lower())
    if v in _KNOWN_DECODER_BLOCKS:
        return v
    # Compare with all separators stripped so "llama_2" matches "llama2".
    v_compact = v.replace("_", "")
    for known in _KNOWN_DECODER_BLOCKS:
        if v_compact == known.replace("_", ""):
            return known
    return "custom"
def _parse_classification(text: str) -> Dict[str, str]:
    """Parses the classifier's JSON response into `decoder_block`/`justification`.

    Accepts any of:
      * a bare JSON object;
      * a fenced ```json block;
      * prose containing an embedded `{...}` object.

    Falls back to `{"decoder_block": "custom", "justification": ""}` when no
    JSON object can be recovered.

    Args:
      text: Raw LLM response.

    Returns:
      Dict with a normalized `decoder_block` (via `_normalize_decoder_block`)
      and a string `justification`. The justification is empty when the key
      is absent or JSON null.
    """
    raw = text.strip()
    json_match = re.search(r"```(?:json)?\n?(.*?)\n?```", raw, re.DOTALL)
    if json_match:
        raw = json_match.group(1).strip()
    try:
        obj = json.loads(raw)
    except json.JSONDecodeError:
        # Second chance: grab the outermost-looking {...} span from prose.
        obj_match = re.search(r"\{.*\}", raw, re.DOTALL)
        if not obj_match:
            logger.warning("Classifier returned unparseable response; defaulting to 'custom'")
            return {"decoder_block": "custom", "justification": ""}
        try:
            obj = json.loads(obj_match.group(0))
        except json.JSONDecodeError:
            logger.warning("Classifier JSON sub-extract failed; defaulting to 'custom'")
            return {"decoder_block": "custom", "justification": ""}
    if not isinstance(obj, dict):
        return {"decoder_block": "custom", "justification": ""}
    justification = obj.get("justification")
    return {
        "decoder_block": _normalize_decoder_block(str(obj.get("decoder_block", ""))),
        # A JSON null must not reach downstream prompts as the string "None".
        "justification": "" if justification is None else str(justification),
    }
def _format_rag_context(docs: List[Dict[str, Any]]) -> str:
    """Renders retrieved RAG documents as Markdown sections for a prompt body.

    Each document becomes a `### <name>` heading followed by its text inside a
    python code fence. A missing `name` or `text` falls back to "unknown" or
    "" respectively. Sections are separated by a blank line. An empty `docs`
    yields a placeholder sentence.
    """
    if not docs:
        return "(no reference snippets available)"
    return "\n\n".join(
        "### {}\n```python\n{}\n```".format(doc.get("name", "unknown"), doc.get("text", ""))
        for doc in docs
    )
class MaxTextConversionAgent(base.Agent):
    """Stages classify -> YAML -> (layers) -> (ckpt converter) for MaxText output.

    1. Classify: ask the LLM for the closest MaxText `decoder_block`.
    2. YAML: always emit a config overlay for that block.
    3. Layers: emit a custom layers `.py` file only when the block is
       "custom" or listed in `_FORCE_LAYERS_BLOCKS`.
    4. Converter: best-effort checkpoint converter. Any failure is logged and
       yields `None` instead of aborting the run.
    """

    # Minimum stripped length for a generated file to be kept. Shorter
    # responses are treated as suspiciously short and dropped.
    _MIN_LAYERS_CHARS = 40
    _MIN_CKPT_CHARS = 60

    def __init__(
        self,
        model: Any,
        rag_agent_instance: rag_agent.RAGAgent,
    ):
        """Initializes the agent.

        Args:
          model: The LLM model, forwarded to `base.Agent`.
          rag_agent_instance: RAG agent used for reference retrieval (expected
            to have `target='maxtext'`).
        """
        super().__init__(
            model=model,
            agent_domain=utils.AgentDomain.MIGRATION,
            agent_type=utils.AgentType.MODEL_CONVERSION,
        )
        self._rag_agent = rag_agent_instance

    # ---- Stage 1: classify -------------------------------------------------

    def _classify(self, pytorch_code: str) -> Dict[str, str]:
        """Picks the closest existing MaxText `decoder_block` for the source.

        Returns:
          Dict with `decoder_block` and `justification` keys, as produced by
          `_parse_classification`.
        """
        docs = self._rag_agent.retrieve_per_component_context(pytorch_code)
        rag_context = _format_rag_context(docs)
        prompt = prompts.get_prompt("MAXTEXT_CLASSIFY_PROMPT", "maxtext")
        response = self.generate(
            prompt,
            {"pytorch_code": pytorch_code, "rag_context": rag_context},
        )
        return _parse_classification(response)

    # ---- Stage 2: YAML overlay --------------------------------------------

    def _emit_yaml(
        self,
        pytorch_code: str,
        decoder_block: str,
        justification: str,
    ) -> str:
        """Emits the YAML config overlay for `MaxText/configs/models/`.

        The retrieval query includes the decoder block name. Statically
        extracted dimension hints are passed to the prompt alongside the
        source.

        Returns:
          The YAML body with markdown fences removed. May be empty if the LLM
          returned nothing usable; a warning is logged in that case.
        """
        docs = self._rag_agent.retrieve_context(
            f"MaxText config overlay {decoder_block}", top_k=10
        )
        rag_context = _format_rag_context(docs)
        dim_hints = _format_dim_hints(_extract_dim_hints(pytorch_code))
        prompt = prompts.get_prompt("MAXTEXT_YAML_PROMPT", "maxtext")
        response = self.generate(
            prompt,
            {
                "pytorch_code": pytorch_code,
                "rag_context": rag_context,
                "decoder_block": decoder_block,
                "justification": justification,
                "dim_hints": dim_hints,
            },
        )
        yaml_config = _strip_yaml_formatting(response)
        if not yaml_config:
            # The overlay is the one mandatory artifact; make an empty one
            # visible instead of letting it be persisted silently.
            logger.warning("MaxText YAML stage returned empty output")
        return yaml_config

    # ---- Stage 3 (conditional): custom layers file ------------------------

    def _emit_layers(
        self,
        pytorch_code: str,
        justification: str,
    ) -> Optional[str]:
        """Emits a small layers `.py` file for architectures needing custom code.

        `run` calls this when the decoder block is "custom" or listed in
        `_FORCE_LAYERS_BLOCKS`.

        Returns:
          The generated source. Returns None when the response is shorter than
          `_MIN_LAYERS_CHARS` after stripping.
        """
        docs = self._rag_agent.retrieve_per_component_context(pytorch_code)
        rag_context = _format_rag_context(docs)
        prompt = prompts.get_prompt("MAXTEXT_LAYERS_PROMPT", "maxtext")
        response = self.generate(
            prompt,
            {
                "pytorch_code": pytorch_code,
                "rag_context": rag_context,
                "justification": justification,
                "maxtext_best_practices": prompts.MAXTEXT_BEST_PRACTICES,
            },
        )
        code = _strip_markdown_formatting(response)
        if not code or len(code.strip()) < self._MIN_LAYERS_CHARS:
            logger.warning("MaxText layers stage returned suspiciously short output; skipping")
            return None
        return code

    # ---- Stage 4 (best-effort): checkpoint converter -----------------------

    def _emit_ckpt_converter(
        self,
        pytorch_code: str,
        decoder_block: str,
        yaml_config: str,
    ) -> Optional[str]:
        """Best-effort: emit a HF/PyTorch -> Orbax converter.

        Any exception raised while retrieving, prompting or post-processing is
        caught and logged. This stage never aborts the run.

        Returns:
          The generated converter source. Returns None on failure, or when the
          response is shorter than `_MIN_CKPT_CHARS` after stripping.
        """
        try:
            docs = self._rag_agent.retrieve_context(
                f"MaxText checkpoint converter {decoder_block} state dict orbax",
                top_k=8,
            )
            rag_context = _format_rag_context(docs)
            prompt = prompts.get_prompt("MAXTEXT_CKPT_CONVERTER_PROMPT", "maxtext")
            response = self.generate(
                prompt,
                {
                    "pytorch_code": pytorch_code,
                    "rag_context": rag_context,
                    "decoder_block": decoder_block,
                    "yaml_config": yaml_config,
                },
            )
            code = _strip_markdown_formatting(response)
            if not code or len(code.strip()) < self._MIN_CKPT_CHARS:
                logger.info("Checkpoint converter stage returned trivial output; skipping")
                return None
            return code
        except Exception as e:
            # Deliberately broad: this stage is optional. Keep the warning
            # terse, but preserve the traceback at debug level for diagnosis.
            logger.warning("Checkpoint converter stage failed (best-effort): %s", e)
            logger.debug("Checkpoint converter stage traceback", exc_info=True)
            return None

    # ---- Orchestration -----------------------------------------------------

    def run(
        self,
        pytorch_code: str,
        model_name: str = "model",
    ) -> MaxTextRunResult:
        """Runs all stages and returns the populated `MaxTextRunResult`.

        Args:
          pytorch_code: The merged or single-file PyTorch source.
          model_name: Stem used for output filenames (e.g. "qwen3_next").

        Returns:
          A `MaxTextRunResult` ready for persistence.
        """
        classification = self._classify(pytorch_code)
        decoder_block = classification["decoder_block"]
        justification = classification["justification"]
        logger.info("MaxText classification: decoder_block=%s, justification=%s",
                    decoder_block, justification)

        yaml_config = self._emit_yaml(pytorch_code, decoder_block, justification)

        layers_py: Optional[str] = None
        if decoder_block == "custom" or decoder_block in _FORCE_LAYERS_BLOCKS:
            layers_py = self._emit_layers(pytorch_code, justification)

        ckpt_converter_py = self._emit_ckpt_converter(
            pytorch_code, decoder_block, yaml_config
        )

        return MaxTextRunResult(
            decoder_block=decoder_block,
            justification=justification,
            config_yaml=yaml_config,
            layers_py=layers_py,
            ckpt_converter_py=ckpt_converter_py,
            model_name=model_name,
        )