Skip to content

Commit 613a0a3

Browse files
committed
Handle Nx.Block structs in LazyContainer/Composite for defn/JIT; trim LU block + EXLA subfn names
1 parent ceec8ad commit 613a0a3

4 files changed

Lines changed: 104 additions & 34 deletions

File tree

exla/test/exla/nx_linalg_doctest_test.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ defmodule EXLA.NxLinAlgDoctestTest do
2525
least_squares: 3,
2626
determinant: 1,
2727
matrix_power: 2,
28-
lu: 2,
28+
lu: 1,
2929
qr: 2
3030
]
3131

nx/lib/nx/block.ex

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,35 @@ end
7373
defmodule Nx.Block.IFFT2 do
7474
defstruct eps: nil, lengths: nil, axes: nil
7575
end
76+
77+
defmodule Nx.Block do
78+
@moduledoc false
79+
80+
@structs [
81+
Nx.Block.LogicalNot,
82+
Nx.Block.Phase,
83+
Nx.Block.Cholesky,
84+
Nx.Block.Solve,
85+
Nx.Block.QR,
86+
Nx.Block.Eigh,
87+
Nx.Block.SVD,
88+
Nx.Block.LU,
89+
Nx.Block.Determinant,
90+
Nx.Block.AllClose,
91+
Nx.Block.CumulativeSum,
92+
Nx.Block.CumulativeProduct,
93+
Nx.Block.CumulativeMin,
94+
Nx.Block.CumulativeMax,
95+
Nx.Block.Take,
96+
Nx.Block.TakeAlongAxis,
97+
Nx.Block.TopK,
98+
Nx.Block.FFT2,
99+
Nx.Block.IFFT2
100+
]
101+
102+
for mod <- @structs do
103+
def block?(%unquote(mod){}), do: true
104+
end
105+
106+
def block?(_), do: false
107+
end

nx/lib/nx/defn/composite.ex

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,32 @@ defmodule Nx.Defn.Composite do
3636
|> Enum.all?(fn {l, r} -> compatible?(l, r, fun) end)
3737
end
3838

39-
def compatible?(%mod{} = left, %mod{} = right, fun) do
39+
def compatible?(left, right, fun) when is_struct(left) and is_struct(right) do
40+
cond do
41+
Nx.Block.block?(left) and Nx.Block.block?(right) ->
42+
left == right
43+
44+
left.__struct__ == right.__struct__ ->
45+
compatible_struct(left, right, fun)
46+
47+
true ->
48+
false
49+
end
50+
end
51+
52+
def compatible?(left, right, fun) when map_size(left) == map_size(right) do
53+
Enum.all?(left, fn {k, v1} ->
54+
case right do
55+
%{^k => v2} -> compatible?(v1, v2, fun)
56+
%{} -> false
57+
end
58+
end)
59+
end
60+
61+
def compatible?(_, _, _),
62+
do: false
63+
64+
defp compatible_struct(left, right, fun) do
4065
# LazyContainer is fully recursive but we don't want to go full recursive
4166
# unless we have to, so we can also compare structures along the way.
4267
{left, right} =
@@ -59,21 +84,6 @@ defmodule Nx.Defn.Composite do
5984
Enum.zip(left, right) |> Enum.all?(fn {l, r} -> compatible?(l, r, fun) end)
6085
end
6186

62-
def compatible?(%_{}, %_{}, _fun),
63-
do: false
64-
65-
def compatible?(left, right, fun) when map_size(left) == map_size(right) do
66-
Enum.all?(left, fn {k, v1} ->
67-
case right do
68-
%{^k => v2} -> compatible?(v1, v2, fun)
69-
%{} -> false
70-
end
71-
end)
72-
end
73-
74-
def compatible?(_, _, _),
75-
do: false
76-
7787
@doc """
7888
Counts the number of non-composite types in the composite type.
7989
@@ -89,7 +99,14 @@ defmodule Nx.Defn.Composite do
8999
"""
90100
def count(tree), do: count(tree, 0)
91101
defp count(tensor, acc) when is_tensor(tensor), do: acc + 1
92-
defp count(container, acc), do: Nx.Container.reduce(container, acc, &count/2)
102+
103+
defp count(other, acc) do
104+
if Nx.Block.block?(other) do
105+
acc
106+
else
107+
Nx.Container.reduce(other, acc, &count/2)
108+
end
109+
end
93110

94111
@doc """
95112
Traverses recursively the given composite types with `fun`.
@@ -117,8 +134,13 @@ defmodule Nx.Defn.Composite do
117134
def traverse(expr, acc, fun) when is_tensor(expr) and is_function(fun, 2),
118135
do: fun.(expr, acc)
119136

120-
def traverse(container, acc, fun),
121-
do: Nx.Container.traverse(container, acc, &traverse(&1, &2, fun))
137+
def traverse(expr, acc, fun) when is_function(fun, 2) do
138+
if Nx.Block.block?(expr) do
139+
{expr, acc}
140+
else
141+
Nx.Container.traverse(expr, acc, &traverse(&1, &2, fun))
142+
end
143+
end
122144

123145
@doc """
124146
Reduces recursively the given composite types with `acc` and `fun`.
@@ -132,8 +154,13 @@ defmodule Nx.Defn.Composite do
132154
def reduce(expr, acc, fun) when is_tensor(expr) and is_function(fun, 2),
133155
do: fun.(expr, acc)
134156

135-
def reduce(container, acc, fun),
136-
do: Nx.Container.reduce(container, acc, &reduce(&1, &2, fun))
157+
def reduce(expr, acc, fun) when is_function(fun, 2) do
158+
if Nx.Block.block?(expr) do
159+
acc
160+
else
161+
Nx.Container.reduce(expr, acc, &reduce(&1, &2, fun))
162+
end
163+
end
137164

138165
@doc """
139166
Flattens recursively the given list of composite types.
@@ -163,6 +190,13 @@ defmodule Nx.Defn.Composite do
163190
when is_number(number) or is_struct(number, Complex),
164191
do: [number | acc]
165192

166-
defp flatten_each(container, acc),
167-
do: Nx.Container.reduce(container, acc, &flatten_each/2)
193+
defp flatten_each(other, acc) do
194+
cond do
195+
Nx.Block.block?(other) ->
196+
[other | acc]
197+
198+
true ->
199+
Nx.Container.reduce(other, acc, &flatten_each/2)
200+
end
201+
end
168202
end

nx/lib/nx/lazy_container.ex

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,20 @@ end
107107

108108
defimpl Nx.LazyContainer, for: Any do
109109
def traverse(data, acc, fun) do
110-
case Nx.Container.impl_for(data) do
111-
nil ->
112-
raise Protocol.UndefinedError,
113-
protocol: @protocol,
114-
value: data,
115-
description:
116-
"data-structures given to defn/Nx must implement either Nx.LazyContainer or Nx.Container"
117-
118-
impl ->
119-
impl.traverse(data, acc, &Nx.LazyContainer.traverse(&1, &2, fun))
110+
if Nx.Block.block?(data) do
111+
{data, acc}
112+
else
113+
case Nx.Container.impl_for(data) do
114+
nil ->
115+
raise Protocol.UndefinedError,
116+
protocol: @protocol,
117+
value: data,
118+
description:
119+
"data-structures given to defn/Nx must implement either Nx.LazyContainer or Nx.Container"
120+
121+
impl ->
122+
impl.traverse(data, acc, &Nx.LazyContainer.traverse(&1, &2, fun))
123+
end
120124
end
121125
end
122126
end

0 commit comments

Comments
 (0)