Skip to content

Commit 6e520dd

Browse files
committed
Stabilize Gemma 4 MLX constant slot ordering
1 parent 818a51d commit 6e520dd

1 file changed

Lines changed: 40 additions & 16 deletions

File tree

backends/mlx/builder/program_builder.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -444,26 +444,50 @@ def _make_io_slots(self): # noqa: C901
444444
else:
445445
raise NotImplementedError(f"Support for input {arg} is not implemented")
446446

447+
placeholder_nodes = {
448+
node.name: node for node in self.ep.graph.nodes if node.op == "placeholder"
449+
}
450+
451+
# Allocate placeholder-backed slots in graph-signature order instead of
452+
# raw FX node traversal order. This keeps lifted constant tids stable
453+
# across equivalent exports, which matters for models like Gemma 4 that
454+
# carry multiple rotary constant placeholders with similar structure.
455+
for name in constant_tensors:
456+
node = placeholder_nodes.get(name)
457+
if node is None or node.users == {}:
458+
continue
459+
self.make_or_get_slot(node, id_space=IdSpace.Constant)
460+
461+
for name in user_inputs:
462+
node = placeholder_nodes.get(name)
463+
if node is None or node.users == {}:
464+
continue
465+
val = node.meta.get("val", None)
466+
if isinstance(val, torch.Tensor) and not val.is_contiguous():
467+
raise ValueError(
468+
f"MLX backend requires contiguous input tensors, "
469+
f"but input '{node.name}' has non-contiguous strides. "
470+
f"shape={list(val.shape)}, stride={list(val.stride())}. "
471+
f"Ensure example inputs passed to torch.export.export() "
472+
f"are contiguous (call .contiguous() on them)."
473+
)
474+
self.make_or_get_slot(node, id_space=IdSpace.Input)
475+
476+
for name in mutable_buffers:
477+
node = placeholder_nodes.get(name)
478+
if node is None or node.users == {}:
479+
continue
480+
self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer)
481+
482+
classified_placeholders = (
483+
set(constant_tensors) | set(user_inputs) | set(mutable_buffers)
484+
)
485+
447486
for node in self.ep.graph.nodes:
448487
if node.op == "placeholder":
449488
if node.users == {}:
450489
continue
451-
if node.name in constant_tensors:
452-
self.make_or_get_slot(node, id_space=IdSpace.Constant)
453-
elif node.name in user_inputs:
454-
val = node.meta.get("val", None)
455-
if isinstance(val, torch.Tensor) and not val.is_contiguous():
456-
raise ValueError(
457-
f"MLX backend requires contiguous input tensors, "
458-
f"but input '{node.name}' has non-contiguous strides. "
459-
f"shape={list(val.shape)}, stride={list(val.stride())}. "
460-
f"Ensure example inputs passed to torch.export.export() "
461-
f"are contiguous (call .contiguous() on them)."
462-
)
463-
self.make_or_get_slot(node, id_space=IdSpace.Input)
464-
elif node.name in mutable_buffers:
465-
self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer)
466-
else:
490+
if node.name not in classified_placeholders:
467491
raise NotImplementedError(
468492
f"Support for placeholder {node.name} is not implemented"
469493
)

0 commit comments

Comments
 (0)