@@ -36,7 +36,32 @@ defmodule Nx.Defn.Composite do
3636 |> Enum . all? ( fn { l , r } -> compatible? ( l , r , fun ) end )
3737 end
3838
39- def compatible? ( % mod { } = left , % mod { } = right , fun ) do
39+ def compatible? ( left , right , fun ) when is_struct ( left ) and is_struct ( right ) do
40+ cond do
41+ Nx.Block . block? ( left ) and Nx.Block . block? ( right ) ->
42+ left == right
43+
44+ left . __struct__ == right . __struct__ ->
45+ compatible_struct ( left , right , fun )
46+
47+ true ->
48+ false
49+ end
50+ end
51+
52+ def compatible? ( left , right , fun ) when map_size ( left ) == map_size ( right ) do
53+ Enum . all? ( left , fn { k , v1 } ->
54+ case right do
55+ % { ^ k => v2 } -> compatible? ( v1 , v2 , fun )
56+ % { } -> false
57+ end
58+ end )
59+ end
60+
61+ def compatible? ( _ , _ , _ ) ,
62+ do: false
63+
64+ defp compatible_struct ( left , right , fun ) do
4065 # LazyContainer is fully recursive but we don't want to go full recursive
4166 # unless we have to, so we can also compare structures along the way.
4267 { left , right } =
@@ -59,21 +84,6 @@ defmodule Nx.Defn.Composite do
5984 Enum . zip ( left , right ) |> Enum . all? ( fn { l , r } -> compatible? ( l , r , fun ) end )
6085 end
6186
62- def compatible? ( % _ { } , % _ { } , _fun ) ,
63- do: false
64-
65- def compatible? ( left , right , fun ) when map_size ( left ) == map_size ( right ) do
66- Enum . all? ( left , fn { k , v1 } ->
67- case right do
68- % { ^ k => v2 } -> compatible? ( v1 , v2 , fun )
69- % { } -> false
70- end
71- end )
72- end
73-
74- def compatible? ( _ , _ , _ ) ,
75- do: false
76-
7787 @ doc """
7888 Counts the number of non-composite types in the composite type.
7989
@@ -89,7 +99,14 @@ defmodule Nx.Defn.Composite do
8999 """
90100 def count ( tree ) , do: count ( tree , 0 )
91101 defp count ( tensor , acc ) when is_tensor ( tensor ) , do: acc + 1
92- defp count ( container , acc ) , do: Nx.Container . reduce ( container , acc , & count / 2 )
102+
103+ defp count ( other , acc ) do
104+ if Nx.Block . block? ( other ) do
105+ acc
106+ else
107+ Nx.Container . reduce ( other , acc , & count / 2 )
108+ end
109+ end
93110
94111 @ doc """
95112 Traverses recursively the given composite types with `fun`.
@@ -117,8 +134,13 @@ defmodule Nx.Defn.Composite do
117134 def traverse ( expr , acc , fun ) when is_tensor ( expr ) and is_function ( fun , 2 ) ,
118135 do: fun . ( expr , acc )
119136
120- def traverse ( container , acc , fun ) ,
121- do: Nx.Container . traverse ( container , acc , & traverse ( & 1 , & 2 , fun ) )
137+ def traverse ( expr , acc , fun ) when is_function ( fun , 2 ) do
138+ if Nx.Block . block? ( expr ) do
139+ { expr , acc }
140+ else
141+ Nx.Container . traverse ( expr , acc , & traverse ( & 1 , & 2 , fun ) )
142+ end
143+ end
122144
123145 @ doc """
124146 Reduces recursively the given composite types with `acc` and `fun`.
@@ -132,8 +154,13 @@ defmodule Nx.Defn.Composite do
132154 def reduce ( expr , acc , fun ) when is_tensor ( expr ) and is_function ( fun , 2 ) ,
133155 do: fun . ( expr , acc )
134156
135- def reduce ( container , acc , fun ) ,
136- do: Nx.Container . reduce ( container , acc , & reduce ( & 1 , & 2 , fun ) )
157+ def reduce ( expr , acc , fun ) when is_function ( fun , 2 ) do
158+ if Nx.Block . block? ( expr ) do
159+ acc
160+ else
161+ Nx.Container . reduce ( expr , acc , & reduce ( & 1 , & 2 , fun ) )
162+ end
163+ end
137164
138165 @ doc """
139166 Flattens recursively the given list of composite types.
@@ -163,6 +190,13 @@ defmodule Nx.Defn.Composite do
163190 when is_number ( number ) or is_struct ( number , Complex ) ,
164191 do: [ number | acc ]
165192
166- defp flatten_each ( container , acc ) ,
167- do: Nx.Container . reduce ( container , acc , & flatten_each / 2 )
193+ defp flatten_each ( other , acc ) do
194+ cond do
195+ Nx.Block . block? ( other ) ->
196+ [ other | acc ]
197+
198+ true ->
199+ Nx.Container . reduce ( other , acc , & flatten_each / 2 )
200+ end
201+ end
168202end
0 commit comments