@@ -311,7 +311,7 @@ def import_var(
311311
312312 if isinstance (var .type , NullType ):
313313 raise TypeError (
314- f"Computation graph contains a NaN. { var .type .why_null } "
314+ f"Computation graph contains a null type: { var } { var .type .why_null } "
315315 )
316316 if import_missing :
317317 self .add_input (var )
@@ -327,7 +327,7 @@ def import_node(
327327 reason : Optional [str ] = None ,
328328 import_missing : bool = False ,
329329 ) -> None :
330- """Recursively import everything between an `` Apply`` node and the `` FunctionGraph` `'s outputs.
330+ """Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
331331
332332 Parameters
333333 ----------
@@ -347,42 +347,62 @@ def import_node(
347347 # to know where to stop going down.)
348348 new_nodes = io_toposort (self .variables , apply_node .outputs )
349349
350- if check :
351- for node in new_nodes :
352- for var in node .inputs :
353- if (
354- var .owner is None
355- and not isinstance (var , AtomicVariable )
356- and var not in self .inputs
357- ):
358- if import_missing :
359- self .add_input (var )
360- else :
361- error_msg = (
362- f"Input { node .inputs .index (var )} ({ var } )"
363- " of the graph (indices start "
364- f"from 0), used to compute { node } , was not "
365- "provided and not given a value. Use the "
366- "Aesara flag exception_verbosity='high', "
367- "for more information on this error."
368- )
369- raise MissingInputError (error_msg , variable = var )
370-
371350 for node in new_nodes :
372- assert node not in self .apply_nodes
373- self .apply_nodes .add (node )
374- if not hasattr (node .tag , "imported_by" ):
375- node .tag .imported_by = []
376- node .tag .imported_by .append (str (reason ))
377- for output in node .outputs :
378- self .setup_var (output )
379- self .variables .add (output )
380- for i , input in enumerate (node .inputs ):
381- if input not in self .variables :
382- self .setup_var (input )
383- self .variables .add (input )
384- self .add_client (input , (node , i ))
385- self .execute_callbacks ("on_import" , node , reason )
351+ self ._import_node (
352+ node , check = check , reason = reason , import_missing = import_missing
353+ )
354+
355+ def _import_node (
356+ self ,
357+ apply_node : Apply ,
358+ check : bool = True ,
359+ reason : Optional [str ] = None ,
360+ import_missing : bool = False ,
361+ ) -> None :
362+ """Import a single node.
363+
364+ See `FunctionGraph.import_node`.
365+ """
366+ assert apply_node not in self .apply_nodes
367+
368+ for i , inp in enumerate (apply_node .inputs ):
369+ if (
370+ check
371+ and inp .owner is None
372+ and not isinstance (inp , AtomicVariable )
373+ and inp not in self .inputs
374+ ):
375+ if import_missing :
376+ self .add_input (inp )
377+ else :
378+ error_msg = (
379+ f"Input { apply_node .inputs .index (inp )} ({ inp } )"
380+ " of the graph (indices start "
381+ f"from 0), used to compute { apply_node } , was not "
382+ "provided and not given a value. Use the "
383+ "Aesara flag exception_verbosity='high', "
384+ "for more information on this error."
385+ )
386+ raise MissingInputError (error_msg , variable = inp )
387+
388+ if inp not in self .variables :
389+ self .setup_var (inp )
390+ self .variables .add (inp )
391+
392+ self .add_client (inp , (apply_node , i ))
393+
394+ for output in apply_node .outputs :
395+ self .setup_var (output )
396+ self .variables .add (output )
397+
398+ self .apply_nodes .add (apply_node )
399+
400+ if not hasattr (apply_node .tag , "imported_by" ):
401+ apply_node .tag .imported_by = []
402+
403+ apply_node .tag .imported_by .append (str (reason ))
404+
405+ self .execute_callbacks ("on_import" , apply_node , reason )
386406
387407 def change_node_input (
388408 self ,
0 commit comments