2222from typing import TYPE_CHECKING , Any , NamedTuple
2323
2424import nvidia .dali .backend_impl as _b
25+ import nvidia .dali .types as dali_types
26+ from nvidia .dali import fn
2527from nvidia .dali .external_source import ExternalSource
2628from nvidia .dali .pipeline import Pipeline
2729
@@ -74,6 +76,7 @@ class CompileNode:
7476 backend : str
7577 inputs : Sequence [CompileRef | Any ]
7678 kwargs : Mapping [str , CompileRef | Any ]
79+ kwarg_casts : dict [str , dali_types .DALIDataType ]
7780 num_outputs : int
7881 device : Device | None = None
7982 pipeline_output_offset : int | None = dataclasses .field (default = None , repr = False )
@@ -201,6 +204,25 @@ def make_source_batches(self, tensor_lists: Sequence[Any]) -> tuple[CompiledBatc
201204 for i , tl in enumerate (tensor_lists )
202205 )
203206
207+ @staticmethod
208+ def _compute_kwarg_casts (op : type ["Operator" ], raw_kwargs : Mapping [str , CompiledBatch | Any ]):
209+ casts : dict [str , dali_types .DALIDataType ] = {}
210+ schema = op ._schema
211+ assert schema is not None
212+
213+ for name , data in raw_kwargs .items ():
214+ if not isinstance (data , CompiledBatch ):
215+ continue
216+
217+ expected_type = schema .GetArgumentType (name )
218+ expected_type = dali_types ._vector_types .get (expected_type , expected_type )
219+ if expected_type == data .dtype .type_id :
220+ continue
221+
222+ casts [name ] = expected_type
223+
224+ return casts
225+
204226 @_nvtx_range ("Recording operator" )
205227 def record (
206228 self ,
@@ -209,18 +231,21 @@ def record(
209231 backend : str ,
210232 inputs : Sequence [CompileRef | Any ],
211233 kwargs : Mapping [str , CompileRef | Any ],
234+ raw_kwargs : Mapping [str , CompiledBatch | Any ],
212235 num_outputs : int ,
213236 device : Device | None = None ,
214237 ) -> CompileNode | None :
215238 if existing := self ._call_trie .find (call_chain ):
216239 if existing .inputs == inputs and existing .kwargs == kwargs :
217240 return existing
218241 return None
242+
219243 node = CompileNode (
220244 op_class = op_class ,
221245 backend = backend ,
222246 inputs = inputs ,
223247 kwargs = kwargs ,
248+ kwarg_casts = self ._compute_kwarg_casts (op_class , raw_kwargs ),
224249 num_outputs = num_outputs ,
225250 device = device ,
226251 )
@@ -466,6 +491,14 @@ def _wire_compile_graph(
466491 kw_scalars = {
467492 k : _scalar_decay (v ) for k , v in node .kwargs .items () if not isinstance (v , CompileRef )
468493 }
494+
495+ # Cast kwargs when necessary
496+ for name , dtype in node .kwarg_casts .items ():
497+ kw_nodes [name ] = fn .cast (kw_nodes [name ], dtype = dtype )
498+ # All kwargs need to be on the CPU
499+ for name , kw_node in kw_nodes .items ():
500+ kw_nodes [name ] = kw_node .cpu ()
501+
469502 op = node .op_class ._legacy_op (device = node .backend , ** kw_scalars )
470503 out = op (* positional , ** kw_nodes )
471504
@@ -535,6 +568,7 @@ def _call():
535568 backend = backend ,
536569 inputs = classified_inputs ,
537570 kwargs = classified_kwargs ,
571+ raw_kwargs = raw_kwargs ,
538572 num_outputs = len (results ),
539573 device = device ,
540574 )
0 commit comments