@@ -444,26 +444,50 @@ def _make_io_slots(self): # noqa: C901
444444 else :
445445 raise NotImplementedError (f"Support for input { arg } is not implemented" )
446446
447+ placeholder_nodes = {
448+ node .name : node for node in self .ep .graph .nodes if node .op == "placeholder"
449+ }
450+
451+ # Allocate placeholder-backed slots in graph-signature order instead of
452+ # raw FX node traversal order. This keeps lifted constant tids stable
453+ # across equivalent exports, which matters for models like Gemma 4 that
454+ # carry multiple rotary constant placeholders with similar structure.
455+ for name in constant_tensors :
456+ node = placeholder_nodes .get (name )
457+ if node is None or node .users == {}:
458+ continue
459+ self .make_or_get_slot (node , id_space = IdSpace .Constant )
460+
461+ for name in user_inputs :
462+ node = placeholder_nodes .get (name )
463+ if node is None or node .users == {}:
464+ continue
465+ val = node .meta .get ("val" , None )
466+ if isinstance (val , torch .Tensor ) and not val .is_contiguous ():
467+ raise ValueError (
468+ f"MLX backend requires contiguous input tensors, "
469+ f"but input '{ node .name } ' has non-contiguous strides. "
470+ f"shape={ list (val .shape )} , stride={ list (val .stride ())} . "
471+ f"Ensure example inputs passed to torch.export.export() "
472+ f"are contiguous (call .contiguous() on them)."
473+ )
474+ self .make_or_get_slot (node , id_space = IdSpace .Input )
475+
476+ for name in mutable_buffers :
477+ node = placeholder_nodes .get (name )
478+ if node is None or node .users == {}:
479+ continue
480+ self .make_or_get_slot (node , id_space = IdSpace .MutableBuffer )
481+
482+ classified_placeholders = (
483+ set (constant_tensors ) | set (user_inputs ) | set (mutable_buffers )
484+ )
485+
447486 for node in self .ep .graph .nodes :
448487 if node .op == "placeholder" :
449488 if node .users == {}:
450489 continue
451- if node .name in constant_tensors :
452- self .make_or_get_slot (node , id_space = IdSpace .Constant )
453- elif node .name in user_inputs :
454- val = node .meta .get ("val" , None )
455- if isinstance (val , torch .Tensor ) and not val .is_contiguous ():
456- raise ValueError (
457- f"MLX backend requires contiguous input tensors, "
458- f"but input '{ node .name } ' has non-contiguous strides. "
459- f"shape={ list (val .shape )} , stride={ list (val .stride ())} . "
460- f"Ensure example inputs passed to torch.export.export() "
461- f"are contiguous (call .contiguous() on them)."
462- )
463- self .make_or_get_slot (node , id_space = IdSpace .Input )
464- elif node .name in mutable_buffers :
465- self .make_or_get_slot (node , id_space = IdSpace .MutableBuffer )
466- else :
490+ if node .name not in classified_placeholders :
467491 raise NotImplementedError (
468492 f"Support for placeholder { node .name } is not implemented"
469493 )
0 commit comments