Skip to content

Commit bed10bc

Browse files
committed
Remove Nx.Container and Block.name from block.ex, update Lu.lu from /2 to /1
1 parent 69dff5c commit bed10bc

8 files changed

Lines changed: 31 additions & 77 deletions

File tree

exla/lib/exla/defn.ex

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -763,20 +763,18 @@ defmodule EXLA.Defn do
763763

764764
defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do
765765
[struct, in_args, expr, _callback] = args
766-
op = Nx.Block.name(struct)
767766

768-
{call_prefix, _opts_suffix} = Enum.split_while(in_args, &(not is_list(&1)))
769-
770-
{call_args, cache} = Enum.map_reduce(call_prefix, cache, &recur_operator(&1, state, &2))
771-
key = computation_key(op, [struct | call_args])
767+
{call_args, cache} = Enum.map_reduce(in_args, cache, &recur_operator(&1, state, &2))
768+
key = computation_key(struct.__struct__, [struct | call_args])
772769

773770
{call_body, cache} =
774771
case cache do
775772
%{^key => computation} ->
776773
{computation, cache}
777774

778775
%{} ->
779-
{computation, cache} = block_computation("block", call_args, expr, state, cache)
776+
{computation, cache} =
777+
block_computation(block_subfunction_description(struct), call_args, expr, state, cache)
780778
{computation, Map.put(cache, key, computation)}
781779
end
782780

@@ -1831,8 +1829,15 @@ defmodule EXLA.Defn do
18311829
{region, merge_outfeed(cache, comp_cache)}
18321830
end
18331831

1834-
defp block_computation(name, args, expr, %{builder: %Function{}} = state, cache) do
1835-
%Function{module: module, name: name} = subbuilder(state.builder, name)
1832+
defp block_subfunction_description(%_{} = struct) do
1833+
struct.__struct__
1834+
|> Module.split()
1835+
|> List.last()
1836+
|> Macro.underscore()
1837+
end
1838+
1839+
defp block_computation(description, args, expr, %{builder: %Function{}} = state, cache) do
1840+
%Function{module: module, name: name} = subbuilder(state.builder, description)
18361841

18371842
arg_typespecs = Enum.map(args, &Value.get_typespec/1)
18381843
out_typespecs = container_to_typespecs(expr)

nx/lib/nx/block.ex

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,103 +15,61 @@ defmodule Nx.Block.Solve do
1515
end
1616

1717
defmodule Nx.Block.QR do
18-
@derive {Nx.Container, containers: [], keep: [:eps, :mode]}
1918
defstruct eps: 1.0e-10, mode: :reduced
2019
end
2120

2221
defmodule Nx.Block.Eigh do
23-
@derive {Nx.Container, containers: [], keep: [:max_iter, :eps]}
2422
defstruct max_iter: 1000, eps: 1.0e-4
2523
end
2624

2725
defmodule Nx.Block.SVD do
28-
@derive {Nx.Container, containers: [], keep: [:max_iter, :full_matrices?]}
2926
defstruct max_iter: 100, full_matrices?: true
3027
end
3128

3229
defmodule Nx.Block.LU do
33-
@derive {Nx.Container, containers: [], keep: [:eps]}
34-
defstruct eps: 1.0e-10
30+
defstruct []
3531
end
3632

3733
defmodule Nx.Block.Determinant do
3834
defstruct []
3935
end
4036

4137
defmodule Nx.Block.AllClose do
42-
@derive {Nx.Container, containers: [], keep: [:equal_nan, :rtol, :atol]}
4338
defstruct equal_nan: false, rtol: 1.0e-5, atol: 1.0e-8
4439
end
4540

4641
defmodule Nx.Block.CumulativeSum do
47-
@derive {Nx.Container, containers: [], keep: [:axis, :reverse]}
4842
defstruct axis: 0, reverse: false
4943
end
5044

5145
defmodule Nx.Block.CumulativeProduct do
52-
@derive {Nx.Container, containers: [], keep: [:axis, :reverse]}
5346
defstruct axis: 0, reverse: false
5447
end
5548

5649
defmodule Nx.Block.CumulativeMin do
57-
@derive {Nx.Container, containers: [], keep: [:axis, :reverse]}
5850
defstruct axis: 0, reverse: false
5951
end
6052

6153
defmodule Nx.Block.CumulativeMax do
62-
@derive {Nx.Container, containers: [], keep: [:axis, :reverse]}
6354
defstruct axis: 0, reverse: false
6455
end
6556

6657
defmodule Nx.Block.Take do
67-
@derive {Nx.Container, containers: [], keep: [:axis]}
6858
defstruct axis: 0
6959
end
7060

7161
defmodule Nx.Block.TakeAlongAxis do
72-
@derive {Nx.Container, containers: [], keep: [:axis]}
7362
defstruct axis: 0
7463
end
7564

7665
defmodule Nx.Block.TopK do
77-
@derive {Nx.Container, containers: [], keep: [:k]}
7866
defstruct k: 1
7967
end
8068

8169
defmodule Nx.Block.FFT2 do
82-
@derive {Nx.Container, containers: [], keep: [:eps, :lengths, :axes]}
8370
defstruct eps: nil, lengths: nil, axes: nil
8471
end
8572

8673
defmodule Nx.Block.IFFT2 do
87-
@derive {Nx.Container, containers: [], keep: [:eps, :lengths, :axes]}
8874
defstruct eps: nil, lengths: nil, axes: nil
8975
end
90-
91-
defmodule Nx.Block do
92-
@moduledoc false
93-
94-
def name(%{__struct__: module}) do
95-
case module do
96-
Nx.Block.LogicalNot -> :logical_not
97-
Nx.Block.Phase -> :phase
98-
Nx.Block.AllClose -> :all_close
99-
Nx.Block.CumulativeSum -> :cumulative_sum
100-
Nx.Block.CumulativeProduct -> :cumulative_product
101-
Nx.Block.CumulativeMin -> :cumulative_min
102-
Nx.Block.CumulativeMax -> :cumulative_max
103-
Nx.Block.Cholesky -> :cholesky
104-
Nx.Block.Solve -> :solve
105-
Nx.Block.QR -> :qr
106-
Nx.Block.Eigh -> :eigh
107-
Nx.Block.SVD -> :svd
108-
Nx.Block.LU -> :lu
109-
Nx.Block.Determinant -> :determinant
110-
Nx.Block.Take -> :take
111-
Nx.Block.TakeAlongAxis -> :take_along_axis
112-
Nx.Block.TopK -> :top_k
113-
Nx.Block.FFT2 -> :fft2
114-
Nx.Block.IFFT2 -> :ifft2
115-
end
116-
end
117-
end

nx/lib/nx/defn/composite.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ defmodule Nx.Defn.Composite do
33
Functions to deal with composite data types.
44
55
Composite data-types are traversed according to `Nx.Container`.
6-
If a regular tensor is given, it is individually traversed.
6+
If a regular tensor is given, it is individually traversed.
77
Numerical values, such as integers, floats, and complex numbers
88
are not normalized before hand. Use `Nx.to_tensor/1` to do so.
99

nx/lib/nx/defn/evaluator.ex

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,8 @@ defmodule Nx.Defn.Evaluator do
136136
defp compute_cache(:block, %{data: %Expr{args: args}}, state, cache) do
137137
[struct, in_args, expr, _callback] = args
138138

139-
{call_prefix, call_suffix} = Enum.split_while(in_args, &(not is_list(&1)))
140-
{call_prefix, cache} = Enum.map_reduce(call_prefix, cache, &compute_cache(&1, state, &2))
141-
in_args = call_prefix ++ call_suffix
142-
key = computation_key(Nx.Block.name(struct), call_prefix)
139+
{in_args, cache} = Enum.map_reduce(in_args, cache, &compute_cache(&1, state, &2))
140+
key = computation_key(struct.__struct__, in_args)
143141

144142
{{expr, expr_cache}, cache} =
145143
case cache do
@@ -366,17 +364,16 @@ defmodule Nx.Defn.Evaluator do
366364

367365
defp eval_apply(:block, [struct, in_args, expr, expr_cache], ans, state, caches) do
368366
{in_args, caches} = Tree.map_block_args(in_args, caches, &eval(&1, state, &2))
369-
{param_prefix, _} = Enum.split_while(in_args, &(not is_list(&1)))
370-
backend = Nx.Shared.list_impl!(param_prefix)
367+
backend = Nx.Shared.list_impl!(in_args)
371368

372369
out =
373370
case ans do
374371
%{type: {:tuple, _}} -> expr
375372
_ -> ans
376373
end
377374

378-
fun = block_default_fun(expr, state, expr_cache, length(param_prefix))
379-
{backend.block(struct, out, param_prefix, fun), caches}
375+
fun = block_default_fun(expr, state, expr_cache, length(in_args))
376+
{backend.block(struct, out, in_args, fun), caches}
380377
end
381378

382379
defp eval_apply(:runtime_call, [expr, fun, out_template, opts], _ans, state, caches) do

nx/lib/nx/defn/expr.ex

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ defmodule Nx.Defn.Expr do
4444
* `runtime_call(out, tensor_or_container, opts, fun)`
4545
4646
* `block(struct, block_args, default_expr, fun)` - `struct` is an `Nx.Block.*`
47-
value, `block_args` are the tensors and keyword options passed to `Nx.block/4`,
48-
`default_expr` is the traced default implementation, and `fun` is the block
49-
callback
47+
value, `block_args` are tensor expressions passed to `Nx.block/4` (options
48+
live on `struct`), `default_expr` is the traced default implementation, and
49+
`fun` is the block callback
5050
5151
`defn` compilers must handle said nodes accordingly.
5252
"""
@@ -426,10 +426,9 @@ defmodule Nx.Defn.Expr do
426426
end
427427

428428
defp expr_block(struct, in_args, fun) do
429-
{args, opts} = Enum.split_while(in_args, &(not is_list(&1)))
430-
params = Enum.with_index(args, &parameter/2)
429+
params = Enum.with_index(in_args, &parameter/2)
431430

432-
case apply(fun, [struct | params ++ opts]) do
431+
case apply(fun, [struct | params]) do
433432
%{data: %{context: context}} = res ->
434433
expr(res, context, :block, [struct, in_args, res, fun])
435434

nx/lib/nx/lin_alg.ex

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1527,10 +1527,6 @@ defmodule Nx.LinAlg do
15271527
@doc """
15281528
Calculates the A = PLU decomposition of batched square 2-D matrices A.
15291529
1530-
## Options
1531-
1532-
* `:eps` - Rounding error threshold that can be applied during the factorization
1533-
15341530
## Examples
15351531
15361532
iex> {p, l, u} = Nx.LinAlg.lu(Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
@@ -1734,8 +1730,7 @@ defmodule Nx.LinAlg do
17341730
iex> Nx.LinAlg.lu(Nx.tensor([[1, 1, 1, 1], [-1, 4, 4, -1], [4, -2, 2, 0]]))
17351731
** (ArgumentError) tensor must be a square matrix or a batch of square matrices, got shape: {3, 4}
17361732
"""
1737-
def lu(tensor, opts \\ []) do
1738-
opts = keyword!(opts, eps: 1.0e-10)
1733+
def lu(tensor) do
17391734
%T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor)
17401735
%T{type: type, shape: shape} = tensor = Nx.devectorize(tensor)
17411736

@@ -1748,8 +1743,8 @@ defmodule Nx.LinAlg do
17481743
%{tensor | type: output_type, shape: l_shape, names: names},
17491744
%{tensor | type: output_type, shape: u_shape, names: names}}
17501745

1751-
Nx.block(struct(Nx.Block.LU, opts), [tensor], output, fn %Nx.Block.LU{} = opts, t ->
1752-
Nx.LinAlg.LU.lu(opts, t)
1746+
Nx.block(%Nx.Block.LU{}, [tensor], output, fn %Nx.Block.LU{}, t ->
1747+
Nx.LinAlg.LU.lu(t)
17531748
end)
17541749
|> Nx.vectorize(vectorized_axes)
17551750
end

nx/lib/nx/lin_alg/lu.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ defmodule Nx.LinAlg.LU do
22
@moduledoc false
33
import Nx.Defn
44

5-
defn lu(%Nx.Block.LU{} = _opts, a) do
5+
defn lu(a) do
66
vectorized_axes = a.vectorized_axes
77

88
result =

torchx/lib/torchx/backend.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ defmodule Torchx.Backend do
7474
qr_impl(hd(args), mode: struct.mode, eps: struct.eps)
7575

7676
Nx.Block.LU ->
77-
lu_impl(hd(args), eps: struct.eps)
77+
lu_impl(hd(args))
7878

7979
Nx.Block.Eigh ->
8080
eigh_impl(hd(args), max_iter: struct.max_iter, eps: struct.eps)
@@ -1224,7 +1224,7 @@ defmodule Torchx.Backend do
12241224
end
12251225
end
12261226

1227-
defp lu_impl(tensor, _opts) do
1227+
defp lu_impl(tensor) do
12281228
tensor =
12291229
if Nx.Type.integer?(tensor.type) do
12301230
Nx.as_type(tensor, {:f, 32})

0 commit comments

Comments
 (0)