Skip to content

Commit 92c541a

Browse files
committed
up
1 parent 2ca57a7 commit 92c541a

2 files changed

Lines changed: 23 additions & 64 deletions

File tree

backends/mlx/builder/slot_manager.py

Lines changed: 19 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -30,18 +30,25 @@ class IdSpace(Enum):
3030
Temp = auto()
3131

3232

33-
@dataclass(frozen=True)
33+
@dataclass(eq=False, frozen=True)
3434
class Slot:
35+
"""Represents an allocated tensor or symbolic int slot.
36+
37+
Uses identity-based equality and hashing (not field-based) so that
38+
two Slots with the same (id_type, id_space, idx) — which can happen
39+
when the delete-as-you-go allocator recycles an idx — remain distinct
40+
in sets and dicts during build().
41+
"""
42+
3543
id_type: IdType
3644
id_space: IdSpace
3745
idx: Optional[int] = None
38-
# Unique allocation ID — ensures Slots with the same (id_type, id_space, idx)
39-
# remain distinct in sets/dicts after an idx is freed and reused.
40-
# Without this, the delete-as-you-go allocator can free idx=5, then
41-
# make_tmp_slot reuses idx=5, and the new Slot equals the old one.
42-
# In build()'s _collect_used_slots (a set) and _create_slot_mappings (a dict),
43-
# they merge into one entry and get the same global Tid — causing aliasing.
44-
alloc_id: Optional[int] = None
46+
47+
def __eq__(self, other):
48+
return self is other
49+
50+
def __hash__(self):
51+
return id(self)
4552

4653

4754
class IdManager:
@@ -66,13 +73,6 @@ def __init__(self):
6673
self.tid_managers: Dict[IdSpace, IdManager] = defaultdict(IdManager)
6774
self.vid_managers: Dict[IdSpace, IdManager] = defaultdict(IdManager)
6875
self.name_to_slot: Dict[str, Slot] = {}
69-
self._next_alloc_id: int = 0
70-
71-
def _alloc_id(self) -> int:
72-
"""Return a globally unique allocation ID."""
73-
aid = self._next_alloc_id
74-
self._next_alloc_id += 1
75-
return aid
7676

7777
def set_slot(self, node_or_name: Union[Node, str], slot: Slot):
7878
if isinstance(node_or_name, Node):
@@ -124,9 +124,7 @@ def make_constant_slot(self, name: str) -> Slot:
124124
id_space = IdSpace.Constant
125125
manager = self.tid_managers[id_space]
126126
idx = manager.get_id()
127-
slot = Slot(
128-
id_type=IdType.Tensor, id_space=id_space, idx=idx, alloc_id=self._alloc_id()
129-
)
127+
slot = Slot(id_type=IdType.Tensor, id_space=id_space, idx=idx)
130128
self.name_to_slot[name] = slot
131129
return slot
132130

@@ -135,9 +133,7 @@ def make_tmp_slot(self) -> Tuple[str, Slot]:
135133
id_space = IdSpace.Temp
136134
manager = self.tid_managers[id_space]
137135
idx = manager.get_id()
138-
slot = Slot(
139-
id_type=IdType.Tensor, id_space=id_space, idx=idx, alloc_id=self._alloc_id()
140-
)
136+
slot = Slot(id_type=IdType.Tensor, id_space=id_space, idx=idx)
141137
self.name_to_slot[name] = slot
142138
return name, slot
143139

@@ -147,9 +143,7 @@ def make_tmp_value_slot(self) -> Tuple[str, Slot]:
147143
id_space = IdSpace.Temp
148144
manager = self.vid_managers[id_space]
149145
idx = manager.get_id()
150-
slot = Slot(
151-
id_type=IdType.SymInt, id_space=id_space, idx=idx, alloc_id=self._alloc_id()
152-
)
146+
slot = Slot(id_type=IdType.SymInt, id_space=id_space, idx=idx)
153147
self.name_to_slot[name] = slot
154148
return name, slot
155149

@@ -182,14 +176,7 @@ def make_or_get_slots(
182176
else:
183177
manager = self.vid_managers[id_space]
184178
idx = manager.get_id()
185-
slots.append(
186-
Slot(
187-
id_type=id_type,
188-
id_space=id_space,
189-
idx=idx,
190-
alloc_id=self._alloc_id(),
191-
)
192-
)
179+
slots.append(Slot(id_type=id_type, id_space=id_space, idx=idx))
193180
slots = tuple(slots)
194181

195182
# Store in the format that matches the node's output structure

backends/mlx/model_ops/gated_delta_rule.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,6 @@
3434
from torch.fx.node import Node
3535

3636

37-
# ---------------------------------------------------------------------------
38-
# Custom op definition
39-
# ---------------------------------------------------------------------------
40-
41-
4237
@torch.library.custom_op("mlx::gated_delta_rule", mutates_args=("state",))
4338
def gated_delta_rule(
4439
q: Tensor, # [B, T, Hk, Dk]
@@ -96,11 +91,6 @@ def gated_delta_rule_fake(
9691

9792

9893
from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type
99-
100-
# ---------------------------------------------------------------------------
101-
# Pattern handler
102-
# ---------------------------------------------------------------------------
103-
10494
from executorch.backends.mlx.builder.op_registry import PatternHandler, REGISTRY
10595
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
10696
from executorch.backends.mlx.builder.slot_manager import Slot
@@ -311,12 +301,9 @@ def _emit_metal_kernel(self, P: MLXProgramBuilder, n: Node) -> Slot:
311301
b_iov = P.to_int_or_vid(b_val)
312302
t_iov = P.to_int_or_vid(t_val)
313303

314-
# Output slot for y
315-
existing = P.slot_manager.get_slot(self.getitem_0)
316-
if existing is not None:
317-
out = existing if not isinstance(existing, tuple) else existing[0]
318-
else:
319-
_, out = P.make_tmp_slot()
304+
# Output slot for y — use existing IO slot if getitem_0 is a graph output,
305+
# otherwise create a new temp slot.
306+
out = P.make_or_get_slot(self.getitem_0)
320307

321308
# Output slot for state_out (carry)
322309
_, carry = P.make_tmp_slot()
@@ -449,9 +436,6 @@ def _emit_metal_kernel(self, P: MLXProgramBuilder, n: Node) -> Slot:
449436
def _emit_scan(self, P: MLXProgramBuilder, n: Node) -> Slot:
450437
"""Emit ScanNode decomposition of the gated delta recurrence."""
451438

452-
# With alloc_id on Slot, slot_map's _mark_read can safely free
453-
# and reuse idx values — each allocation remains distinct in
454-
# build()'s used_slots set and _create_slot_mappings dict.
455439
q_slot, k_slot, v_slot, g_slot, beta_slot, state_slot = P.slot_map(
456440
[
457441
self.q_node,
@@ -475,15 +459,7 @@ def _emit_scan(self, P: MLXProgramBuilder, n: Node) -> Slot:
475459
_, beta_s = P.make_tmp_slot()
476460

477461
# Output slot for the recurrence output.
478-
# getitem_0 already has an Output slot from _make_io_slots (it's a
479-
# USER_OUTPUT). Use that existing slot so the ScanNode writes directly
480-
# into the output slot. Don't call make_or_get_slots on auto_func_node
481-
# (deferred body node must not have slots per _verify_build).
482-
existing = P.slot_manager.get_slot(self.getitem_0)
483-
if existing is not None:
484-
out = existing if not isinstance(existing, tuple) else existing[0]
485-
else:
486-
_, out = P.make_tmp_slot()
462+
out = P.make_or_get_slot(self.getitem_0)
487463

488464
# Body temp slots
489465
_, t0 = P.make_tmp_slot()
@@ -575,10 +551,6 @@ def _emit_scan(self, P: MLXProgramBuilder, n: Node) -> Slot:
575551
return carry
576552

577553

578-
# ---------------------------------------------------------------------------
579-
# Registration
580-
# ---------------------------------------------------------------------------
581-
582554
_registered = False
583555

584556

0 commit comments

Comments
 (0)