@@ -2093,26 +2093,13 @@ def _relu_handler(P: MLXProgramBuilder, n: Node) -> Slot:
20932093 require_kwargs (P .kwargs (n ), set (), "aten.relu" )
20942094 (x ,) = args # x is already a Slot
20952095
2096- # Get input dtype
20972096 x_meta = n .args [0 ].meta .get ("val" )
20982097 if x_meta is None :
20992098 raise ValueError ("Input tensor metadata not found for relu" )
21002099 dtype = x_meta .dtype
21012100
2102- _ , zero_slot = P . make_tmp_slot ( )
2101+ zero_slot = emit_lifted_constant ( P , 0.0 , dtype )
21032102
2104- # Emit FullNode to create a scalar zero (shape=[])
2105- # Maximum will broadcast this scalar to match input shape
2106- P .emit (
2107- FullNode (
2108- shape = [], # Scalar (will be broadcast in maximum)
2109- v = FloatOrVid .from_literal (0.0 ),
2110- scalar_type = torch_dtype_to_scalar_type (dtype ),
2111- out = P .slot_to_tid (zero_slot ),
2112- )
2113- )
2114-
2115- # Emit MaximumNode(x, scalar_zero)
21162103 out = P .make_or_get_slot (n )
21172104 P .emit (
21182105 MaximumNode (
@@ -2264,26 +2251,12 @@ def _clamp_handler(P: MLXProgramBuilder, n: Node) -> Slot:
22642251 )
22652252 return out
22662253
2267- # Helper to create a scalar constant slot
2268- def make_scalar_slot (val ):
2269- _ , slot = P .make_tmp_slot ()
2270- P .emit (
2271- FullNode (
2272- shape = [], # Scalar
2273- v = FloatOrVid .from_literal (float (val )),
2274- scalar_type = torch_dtype_to_scalar_type (dtype ),
2275- out = P .slot_to_tid (slot ),
2276- )
2277- )
2278- return slot
2279-
22802254 current = x
22812255
22822256 # Apply max constraint first: min(x, max_val)
22832257 if max_val is not None :
2284- max_slot = make_scalar_slot ( max_val )
2258+ max_slot = emit_lifted_constant ( P , float ( max_val ), dtype )
22852259 if min_val is not None :
2286- # Need a temp slot since we have both constraints
22872260 _ , tmp = P .make_tmp_slot ()
22882261 P .emit (
22892262 MinimumNode (
@@ -2294,7 +2267,6 @@ def make_scalar_slot(val):
22942267 )
22952268 current = tmp
22962269 else :
2297- # Only max constraint, output directly
22982270 P .emit (
22992271 MinimumNode (
23002272 a = P .slot_to_tid (current ),
@@ -2306,7 +2278,7 @@ def make_scalar_slot(val):
23062278
23072279 # Apply min constraint: max(current, min_val)
23082280 if min_val is not None :
2309- min_slot = make_scalar_slot ( min_val )
2281+ min_slot = emit_lifted_constant ( P , float ( min_val ), dtype )
23102282 P .emit (
23112283 MaximumNode (
23122284 a = P .slot_to_tid (current ),
@@ -2416,16 +2388,7 @@ def reshape_for_broadcast(slot, name_suffix):
24162388 )
24172389
24182390 # Step 2: var_eps = var + eps
2419- # Create eps as a scalar using FullNode (broadcasts correctly with var)
2420- _ , eps_slot = P .make_tmp_slot ()
2421- P .emit (
2422- FullNode (
2423- out = P .slot_to_tid (eps_slot ),
2424- shape = [], # 0-D scalar
2425- v = FloatOrVid .from_literal (float (eps )),
2426- scalar_type = torch_dtype_to_scalar_type (torch .float32 ),
2427- )
2428- )
2391+ eps_slot = emit_lifted_constant (P , float (eps ), torch .float32 )
24292392 _ , tmp_var_eps = P .make_tmp_slot ()
24302393 P .emit (
24312394 AddNode (
@@ -2944,12 +2907,7 @@ def _logical_or_handler(P: MLXProgramBuilder, n: Node) -> Slot:
29442907
29452908@REGISTRY .register (target = [torch .ops .aten .scalar_tensor .default ])
29462909def _scalar_tensor_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
2947- """Handle aten.scalar_tensor - create a 0-D tensor from a scalar value.
2948-
2949- scalar_tensor(scalar, *, dtype=None, layout=None, device=None, pin_memory=None) -> Tensor
2950-
2951- This is equivalent to torch.full([], scalar, dtype=dtype).
2952- """
2910+ """This is equivalent to torch.full([], scalar, dtype=dtype)."""
29532911 args = P .args (n )
29542912 kwargs = P .kwargs (n )
29552913 require_args (args , 1 , 1 , "aten.scalar_tensor" )
@@ -3438,20 +3396,9 @@ def _pow_handler(P: MLXProgramBuilder, n: Node) -> Slot:
34383396
34393397 # Handle scalar exponent by creating a scalar full tensor that will broadcast
34403398 if not isinstance (b , Slot ):
3441- # Get dtype from input tensor's meta
34423399 input_meta = n .args [0 ].meta .get ("val" )
34433400 dtype = input_meta .dtype if input_meta is not None else torch .float32
3444-
3445- _ , b_slot = P .make_tmp_slot ()
3446- P .emit (
3447- FullNode (
3448- out = P .slot_to_tid (b_slot ),
3449- shape = [], # 0-D scalar - broadcasts correctly
3450- v = FloatOrVid .from_literal (float (b )),
3451- scalar_type = torch_dtype_to_scalar_type (dtype ),
3452- )
3453- )
3454- b = b_slot
3401+ b = emit_lifted_constant (P , float (b ), dtype )
34553402
34563403 out = P .make_or_get_slot (n )
34573404 P .emit (PowerNode (a = P .slot_to_tid (a ), b = P .slot_to_tid (b ), out = P .slot_to_tid (out )))
0 commit comments