
Commit 240b241 ("up")
1 parent ffeeafc

4 files changed: 7 additions & 197 deletions

backends/mlx/llm/cache.py
Lines changed: 0 additions & 1 deletion

@@ -227,7 +227,6 @@ def update(
             torch._check(seq_len <= self.window_size)
         else:
             start_pos = input_pos
-            seq_len = k_val.size(2)

         torch.ops.mlx.kv_cache_update(
             self.k_cache, k_val, start_pos, ring_size=self.buffer_size
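
In the fragment shown, the non-windowed path's seq_len is never read again (the kernel call below takes k_val directly), which is presumably why the assignment could be deleted as dead code. A minimal runnable sketch of that control flow; the branch condition and function name here are hypothetical stand-ins, since the diff shows only a fragment of update():

import torch

def choose_start_pos(input_pos: int, k_val: torch.Tensor,
                     window_size: int | None) -> int:
    # Hypothetical branch condition; the real cache branches on its own
    # windowed/ring configuration, which this diff does not show.
    if window_size is not None:
        seq_len = k_val.size(2)          # needed only for the bounds check
        torch._check(seq_len <= window_size)
        return input_pos % window_size   # illustrative ring-buffer arithmetic
    # Non-windowed path: start_pos is just input_pos, and nothing downstream
    # in the shown fragment reads seq_len, so computing k_val.size(2) here
    # was a no-op.
    return input_pos

print(choose_start_pos(5, torch.zeros(1, 2, 3, 8), None))  # 5
print(choose_start_pos(5, torch.zeros(1, 2, 3, 8), 16))    # 5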

backends/mlx/llm/et_attention.py
Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 mlx::custom_sdpa for efficient execution on Apple Silicon.

 Usage:
-    import executorch.backends.mlx.examples.et_attention  # noqa: F401
+    import executorch.backends.mlx.llm.et_attention  # noqa: F401

     model_args = ModelArgs(attention_type="mlx", ...)
     transformer = construct_transformer(model_args)
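
The corrected path matters because the import is load-bearing: the module evidently registers its attention implementation as an import side effect (that is what the # noqa: F401 idiom signals), so a reader following the stale examples path would never trigger registration. A generic sketch of that pattern; none of these registry names are claimed to exist in ExecuTorch:

ATTENTION_REGISTRY: dict[str, type] = {}

def register_attention(name: str):
    # Class decorator: adds the class to the registry at import time.
    def wrap(cls: type) -> type:
        ATTENTION_REGISTRY[name] = cls
        return cls
    return wrap

@register_attention("mlx")
class MLXAttention:
    """Stand-in for an attention module dispatching to mlx::custom_sdpa."""

# Importing the defining module runs the decorator, which is why the
# docstring's otherwise-unused import must not be dropped by linters.
print("mlx" in ATTENTION_REGISTRY)  # True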

backends/mlx/ops.py
Lines changed: 6 additions & 59 deletions

@@ -2093,26 +2093,13 @@ def _relu_handler(P: MLXProgramBuilder, n: Node) -> Slot:
     require_kwargs(P.kwargs(n), set(), "aten.relu")
     (x,) = args  # x is already a Slot

-    # Get input dtype
     x_meta = n.args[0].meta.get("val")
     if x_meta is None:
         raise ValueError("Input tensor metadata not found for relu")
     dtype = x_meta.dtype

-    _, zero_slot = P.make_tmp_slot()
+    zero_slot = emit_lifted_constant(P, 0.0, dtype)

-    # Emit FullNode to create a scalar zero (shape=[])
-    # Maximum will broadcast this scalar to match input shape
-    P.emit(
-        FullNode(
-            shape=[],  # Scalar (will be broadcast in maximum)
-            v=FloatOrVid.from_literal(0.0),
-            scalar_type=torch_dtype_to_scalar_type(dtype),
-            out=P.slot_to_tid(zero_slot),
-        )
-    )
-
-    # Emit MaximumNode(x, scalar_zero)
     out = P.make_or_get_slot(n)
     P.emit(
         MaximumNode(
@@ -2264,26 +2251,12 @@ def _clamp_handler(P: MLXProgramBuilder, n: Node) -> Slot:
         )
         return out

-    # Helper to create a scalar constant slot
-    def make_scalar_slot(val):
-        _, slot = P.make_tmp_slot()
-        P.emit(
-            FullNode(
-                shape=[],  # Scalar
-                v=FloatOrVid.from_literal(float(val)),
-                scalar_type=torch_dtype_to_scalar_type(dtype),
-                out=P.slot_to_tid(slot),
-            )
-        )
-        return slot
-
     current = x

     # Apply max constraint first: min(x, max_val)
     if max_val is not None:
-        max_slot = make_scalar_slot(max_val)
+        max_slot = emit_lifted_constant(P, float(max_val), dtype)
         if min_val is not None:
-            # Need a temp slot since we have both constraints
             _, tmp = P.make_tmp_slot()
             P.emit(
                 MinimumNode(
@@ -2294,7 +2267,6 @@ def make_scalar_slot(val):
             )
             current = tmp
         else:
-            # Only max constraint, output directly
             P.emit(
                 MinimumNode(
                     a=P.slot_to_tid(current),
@@ -2306,7 +2278,7 @@ def make_scalar_slot(val):

     # Apply min constraint: max(current, min_val)
     if min_val is not None:
-        min_slot = make_scalar_slot(min_val)
+        min_slot = emit_lifted_constant(P, float(min_val), dtype)
         P.emit(
             MaximumNode(
                 a=P.slot_to_tid(current),
@@ -2416,16 +2388,7 @@ def reshape_for_broadcast(slot, name_suffix):
     )

     # Step 2: var_eps = var + eps
-    # Create eps as a scalar using FullNode (broadcasts correctly with var)
-    _, eps_slot = P.make_tmp_slot()
-    P.emit(
-        FullNode(
-            out=P.slot_to_tid(eps_slot),
-            shape=[],  # 0-D scalar
-            v=FloatOrVid.from_literal(float(eps)),
-            scalar_type=torch_dtype_to_scalar_type(torch.float32),
-        )
-    )
+    eps_slot = emit_lifted_constant(P, float(eps), torch.float32)
     _, tmp_var_eps = P.make_tmp_slot()
     P.emit(
         AddNode(
@@ -2944,12 +2907,7 @@ def _logical_or_handler(P: MLXProgramBuilder, n: Node) -> Slot:

 @REGISTRY.register(target=[torch.ops.aten.scalar_tensor.default])
 def _scalar_tensor_handler(P: MLXProgramBuilder, n: Node) -> Slot:
-    """Handle aten.scalar_tensor - create a 0-D tensor from a scalar value.
-
-    scalar_tensor(scalar, *, dtype=None, layout=None, device=None, pin_memory=None) -> Tensor
-
-    This is equivalent to torch.full([], scalar, dtype=dtype).
-    """
+    """This is equivalent to torch.full([], scalar, dtype=dtype)."""
     args = P.args(n)
     kwargs = P.kwargs(n)
     require_args(args, 1, 1, "aten.scalar_tensor")
@@ -3438,20 +3396,9 @@ def _pow_handler(P: MLXProgramBuilder, n: Node) -> Slot:

     # Handle scalar exponent by creating a scalar full tensor that will broadcast
     if not isinstance(b, Slot):
-        # Get dtype from input tensor's meta
         input_meta = n.args[0].meta.get("val")
         dtype = input_meta.dtype if input_meta is not None else torch.float32
-
-        _, b_slot = P.make_tmp_slot()
-        P.emit(
-            FullNode(
-                out=P.slot_to_tid(b_slot),
-                shape=[],  # 0-D scalar - broadcasts correctly
-                v=FloatOrVid.from_literal(float(b)),
-                scalar_type=torch_dtype_to_scalar_type(dtype),
-            )
-        )
-        b = b_slot
+        b = emit_lifted_constant(P, float(b), dtype)

     out = P.make_or_get_slot(n)
     P.emit(PowerNode(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out)))
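
Every hunk above converges on emit_lifted_constant, whose definition is not part of this commit's visible diff. Judging from the call sites and the five hand-rolled blocks it replaces (relu's zero, clamp's min and max bounds, the layer-norm eps, pow's scalar exponent), it centralizes the 0-D scalar-constant pattern. A plausible reconstruction, assuming it still emits a FullNode; the name suggests it may instead lift the value into the program's constant table, which would change the body but not the call sites:

def emit_lifted_constant(P: MLXProgramBuilder, value: float,
                         dtype: torch.dtype) -> Slot:
    """Sketch inferred from this diff's call sites, not the real definition."""
    _, slot = P.make_tmp_slot()
    P.emit(
        FullNode(
            shape=[],  # 0-D scalar, so it broadcasts against any operand of
                       # MaximumNode / MinimumNode / AddNode / PowerNode
            v=FloatOrVid.from_literal(value),
            scalar_type=torch_dtype_to_scalar_type(dtype),
            out=P.slot_to_tid(slot),
        )
    )
    return slot

Deduplicating this pattern, together with deleting clamp's local make_scalar_slot helper, accounts for most of this file's 59 removed lines.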
