Skip to content

Commit 1c3d86e

Browse files
authored
fix: Nx.block with EXLA JIT (#1750)
1 parent fb89cf4 commit 1c3d86e

2 files changed

Lines changed: 40 additions & 40 deletions

File tree

exla/test/exla/defn/api_test.exs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,42 @@ defmodule EXLA.Defn.APITest do
320320
end
321321
end
322322

323+
describe "block ops via Evaluator with EXLA.Backend" do
324+
# Regression test: before the fix, eval_apply(:block, ...) wrapped block_apply_default
325+
# in an Nx.Defn.Compiler.fun and passed it to backend.block. EXLA.Backend.block JIT-traces
326+
# that fun through Nx.Defn.Expr.block, calling it with Expr params. block_apply_default
327+
# then hit composite_eval, which raised on Expr tensors at the :parameter guard.
328+
#
329+
# The fix: pass the original callback directly so JIT-tracing works as intended.
330+
setup do
331+
Nx.default_backend({EXLA.Backend, client: :host})
332+
# Override the global EXLA compiler set in test_helper so Evaluator handles defn dispatch
333+
Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
334+
:ok
335+
end
336+
337+
defn block_top_k(t), do: Nx.top_k(t, k: 2)
338+
defn block_cumulative_sum(t), do: Nx.cumulative_sum(t)
339+
defn block_logical_not(t), do: Nx.logical_not(t)
340+
341+
test "top_k" do
342+
t = Nx.tensor([3.0, 1.0, 4.0, 1.0, 5.0], backend: {EXLA.Backend, client: :host})
343+
{values, indices} = block_top_k(t)
344+
assert_equal(values, Nx.tensor([5.0, 4.0]))
345+
assert_equal(indices, Nx.tensor([4, 2]))
346+
end
347+
348+
test "cumulative_sum" do
349+
t = Nx.tensor([1.0, 2.0, 3.0, 4.0], backend: {EXLA.Backend, client: :host})
350+
assert_equal(block_cumulative_sum(t), Nx.tensor([1.0, 3.0, 6.0, 10.0]))
351+
end
352+
353+
test "logical_not" do
354+
t = Nx.tensor([1, 0, 1, 0], backend: {EXLA.Backend, client: :host})
355+
assert_equal(block_logical_not(t), Nx.tensor([0, 1, 0, 1], type: :u8))
356+
end
357+
end
358+
323359
describe "telemetry" do
324360
defn telemetry_add_two(a, b), do: a + b
325361

nx/lib/nx/defn/evaluator.ex

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -134,25 +134,13 @@ defmodule Nx.Defn.Evaluator do
134134
end
135135

136136
defp compute_cache(:block, %{data: %Expr{args: args}}, state, cache) do
137-
[struct, in_args, expr, _callback] = args
138-
%module{} = struct
137+
[struct, in_args, expr, callback] = args
139138

140139
{call_prefix, call_suffix} = Enum.split_while(in_args, &(not is_list(&1)))
141140
{call_prefix, cache} = Enum.map_reduce(call_prefix, cache, &compute_cache(&1, state, &2))
142141
in_args = call_prefix ++ call_suffix
143-
key = computation_key(module, call_prefix)
144142

145-
{{expr, expr_cache}, cache} =
146-
case cache do
147-
%{^key => optional_expr_cache} ->
148-
{optional_expr_cache, cache}
149-
150-
%{} ->
151-
optional_expr_cache = init_compute_cache(expr, state)
152-
{optional_expr_cache, Map.put(cache, key, optional_expr_cache)}
153-
end
154-
155-
{[struct, in_args, expr, expr_cache], cache}
143+
{[struct, in_args, expr, callback], cache}
156144
end
157145

158146
defp compute_cache(:cond, %{data: %Expr{args: [clauses, last]}}, state, cache) do
@@ -229,16 +217,6 @@ defmodule Nx.Defn.Evaluator do
229217
Tree.apply_args(tensor, cache, &compute_cache(&1, state, &2))
230218
end
231219

232-
defp computation_key(op, args) do
233-
keys =
234-
Enum.map(args, fn
235-
%Nx.Tensor{shape: shape, names: names, type: type} -> {type, shape, names}
236-
opts -> opts
237-
end)
238-
239-
{op, keys}
240-
end
241-
242220
## Evaluation
243221

244222
defp eval(%Nx.Tensor{data: %Expr{op: :tensor, args: [t]}}, _state, caches) do
@@ -365,7 +343,7 @@ defmodule Nx.Defn.Evaluator do
365343
{{}, caches}
366344
end
367345

368-
defp eval_apply(:block, [struct, in_args, expr, expr_cache], ans, state, caches) do
346+
defp eval_apply(:block, [struct, in_args, expr, callback], ans, state, caches) do
369347
{in_args, caches} = Enum.map_reduce(in_args, caches, &eval(&1, state, &2))
370348
{param_prefix, _} = Enum.split_while(in_args, &(not is_list(&1)))
371349
backend = Nx.Shared.list_impl!(param_prefix)
@@ -376,16 +354,7 @@ defmodule Nx.Defn.Evaluator do
376354
_ -> ans
377355
end
378356

379-
fun =
380-
Nx.Defn.Compiler.fun(
381-
length(in_args) + 1,
382-
fn args ->
383-
[struct | tensors] = args
384-
block_apply_default(expr, state, expr_cache, struct, tensors)
385-
end
386-
)
387-
388-
{backend.block(struct, out, in_args, fun), caches}
357+
{backend.block(struct, out, in_args, callback), caches}
389358
end
390359

391360
defp eval_apply(:runtime_call, [expr, fun, out_template, opts], _ans, state, caches) do
@@ -429,11 +398,6 @@ defmodule Nx.Defn.Evaluator do
429398
{apply(mod, op, args), caches}
430399
end
431400

432-
defp block_apply_default(expr, state, expr_cache, _struct, args) when is_list(args) do
433-
params = Enum.map(args, &fn -> &1 end)
434-
elem(composite_eval(expr, %{state | params: params}, [expr_cache]), 0)
435-
end
436-
437401
## Control flow helpers
438402

439403
defp while(acc, condition, block, state, caches) do

0 commit comments

Comments
 (0)