@@ -134,25 +134,13 @@ defmodule Nx.Defn.Evaluator do
134134 end
135135
136136 defp compute_cache ( :block , % { data: % Expr { args: args } } , state , cache ) do
137- [ struct , in_args , expr , _callback ] = args
138- % module { } = struct
137+ [ struct , in_args , expr , callback ] = args
139138
140139 { call_prefix , call_suffix } = Enum . split_while ( in_args , & ( not is_list ( & 1 ) ) )
141140 { call_prefix , cache } = Enum . map_reduce ( call_prefix , cache , & compute_cache ( & 1 , state , & 2 ) )
142141 in_args = call_prefix ++ call_suffix
143- key = computation_key ( module , call_prefix )
144142
145- { { expr , expr_cache } , cache } =
146- case cache do
147- % { ^ key => optional_expr_cache } ->
148- { optional_expr_cache , cache }
149-
150- % { } ->
151- optional_expr_cache = init_compute_cache ( expr , state )
152- { optional_expr_cache , Map . put ( cache , key , optional_expr_cache ) }
153- end
154-
155- { [ struct , in_args , expr , expr_cache ] , cache }
143+ { [ struct , in_args , expr , callback ] , cache }
156144 end
157145
158146 defp compute_cache ( :cond , % { data: % Expr { args: [ clauses , last ] } } , state , cache ) do
@@ -229,16 +217,6 @@ defmodule Nx.Defn.Evaluator do
229217 Tree . apply_args ( tensor , cache , & compute_cache ( & 1 , state , & 2 ) )
230218 end
231219
232- defp computation_key ( op , args ) do
233- keys =
234- Enum . map ( args , fn
235- % Nx.Tensor { shape: shape , names: names , type: type } -> { type , shape , names }
236- opts -> opts
237- end )
238-
239- { op , keys }
240- end
241-
242220 ## Evaluation
243221
244222 defp eval ( % Nx.Tensor { data: % Expr { op: :tensor , args: [ t ] } } , _state , caches ) do
@@ -365,7 +343,7 @@ defmodule Nx.Defn.Evaluator do
365343 { { } , caches }
366344 end
367345
368- defp eval_apply ( :block , [ struct , in_args , expr , expr_cache ] , ans , state , caches ) do
346+ defp eval_apply ( :block , [ struct , in_args , expr , callback ] , ans , state , caches ) do
369347 { in_args , caches } = Enum . map_reduce ( in_args , caches , & eval ( & 1 , state , & 2 ) )
370348 { param_prefix , _ } = Enum . split_while ( in_args , & ( not is_list ( & 1 ) ) )
371349 backend = Nx.Shared . list_impl! ( param_prefix )
@@ -376,16 +354,7 @@ defmodule Nx.Defn.Evaluator do
376354 _ -> ans
377355 end
378356
379- fun =
380- Nx.Defn.Compiler . fun (
381- length ( in_args ) + 1 ,
382- fn args ->
383- [ struct | tensors ] = args
384- block_apply_default ( expr , state , expr_cache , struct , tensors )
385- end
386- )
387-
388- { backend . block ( struct , out , in_args , fun ) , caches }
357+ { backend . block ( struct , out , in_args , callback ) , caches }
389358 end
390359
391360 defp eval_apply ( :runtime_call , [ expr , fun , out_template , opts ] , _ans , state , caches ) do
@@ -429,11 +398,6 @@ defmodule Nx.Defn.Evaluator do
429398 { apply ( mod , op , args ) , caches }
430399 end
431400
432- defp block_apply_default ( expr , state , expr_cache , _struct , args ) when is_list ( args ) do
433- params = Enum . map ( args , & fn -> & 1 end )
434- elem ( composite_eval ( expr , % { state | params: params } , [ expr_cache ] ) , 0 )
435- end
436-
437401 ## Control flow helpers
438402
439403 defp while ( acc , condition , block , state , caches ) do
0 commit comments