Skip to content

Commit 3af166b

Browse files
committed
Cast and copy kwargs automatically
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent 27e834e commit 3af166b

2 files changed

Lines changed: 57 additions & 1 deletion

File tree

dali/python/nvidia/dali/experimental/dynamic/_compile.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from typing import TYPE_CHECKING, Any, NamedTuple
2323

2424
import nvidia.dali.backend_impl as _b
25+
import nvidia.dali.types as dali_types
26+
from nvidia.dali import fn
2527
from nvidia.dali.external_source import ExternalSource
2628
from nvidia.dali.pipeline import Pipeline
2729

@@ -74,6 +76,7 @@ class CompileNode:
7476
backend: str
7577
inputs: Sequence[CompileRef | Any]
7678
kwargs: Mapping[str, CompileRef | Any]
79+
kwarg_casts: dict[str, dali_types.DALIDataType]
7780
num_outputs: int
7881
device: Device | None = None
7982
pipeline_output_offset: int | None = dataclasses.field(default=None, repr=False)
@@ -201,6 +204,25 @@ def make_source_batches(self, tensor_lists: Sequence[Any]) -> tuple[CompiledBatc
201204
for i, tl in enumerate(tensor_lists)
202205
)
203206

207+
@staticmethod
def _compute_kwarg_casts(op: type["Operator"], raw_kwargs: Mapping[str, CompiledBatch | Any]):
    """Determine which batch-valued keyword arguments need a dtype cast.

    Only values that are ``CompiledBatch`` instances are inspected; every
    other keyword argument is ignored.

    Args:
        op: Operator class whose ``_schema`` declares the argument types.
        raw_kwargs: Keyword arguments as passed to the operator call.

    Returns:
        A dict mapping argument name to the target ``DALIDataType`` for each
        ``CompiledBatch`` kwarg whose ``dtype.type_id`` differs from the type
        resolved from the schema. Empty when nothing has to be cast.
    """
    schema = op._schema
    assert schema is not None

    casts: dict[str, dali_types.DALIDataType] = {}
    for name, data in raw_kwargs.items():
        if isinstance(data, CompiledBatch):
            declared = schema.GetArgumentType(name)
            # Translate through ``_vector_types`` when the declared type has
            # an entry there; otherwise keep the declared type unchanged.
            # NOTE(review): presumably maps list-like argument types to
            # their element type - confirm against nvidia.dali.types.
            target = dali_types._vector_types.get(declared, declared)
            if target == data.dtype.type_id:
                # The batch already has the resolved type; no cast needed.
                continue
            casts[name] = target

    return casts
225+
204226
@_nvtx_range("Recording operator")
205227
def record(
206228
self,
@@ -209,18 +231,21 @@ def record(
209231
backend: str,
210232
inputs: Sequence[CompileRef | Any],
211233
kwargs: Mapping[str, CompileRef | Any],
234+
raw_kwargs: Mapping[str, CompiledBatch | Any],
212235
num_outputs: int,
213236
device: Device | None = None,
214237
) -> CompileNode | None:
215238
if existing := self._call_trie.find(call_chain):
216239
if existing.inputs == inputs and existing.kwargs == kwargs:
217240
return existing
218241
return None
242+
219243
node = CompileNode(
220244
op_class=op_class,
221245
backend=backend,
222246
inputs=inputs,
223247
kwargs=kwargs,
248+
kwarg_casts=self._compute_kwarg_casts(op_class, raw_kwargs),
224249
num_outputs=num_outputs,
225250
device=device,
226251
)
@@ -466,6 +491,14 @@ def _wire_compile_graph(
466491
kw_scalars = {
467492
k: _scalar_decay(v) for k, v in node.kwargs.items() if not isinstance(v, CompileRef)
468493
}
494+
495+
# Cast kwargs when necessary
496+
for name, dtype in node.kwarg_casts.items():
497+
kw_nodes[name] = fn.cast(kw_nodes[name], dtype=dtype)
498+
# All kwargs need to be on the CPU
499+
for name, kw_node in kw_nodes.items():
500+
kw_nodes[name] = kw_node.cpu()
501+
469502
op = node.op_class._legacy_op(device=node.backend, **kw_scalars)
470503
out = op(*positional, **kw_nodes)
471504

@@ -535,6 +568,7 @@ def _call():
535568
backend=backend,
536569
inputs=classified_inputs,
537570
kwargs=classified_kwargs,
571+
raw_kwargs=raw_kwargs,
538572
num_outputs=len(results),
539573
device=device,
540574
)

dali/test/python/experimental_mode/test_compile.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_compile_basic_pipeline():
8585

8686
assert len(dynamic_results) == len(compiled_results)
8787
for dyn, comp in zip(dynamic_results, compiled_results):
88-
np.testing.assert_array_almost_equal(dyn, comp)
88+
np.testing.assert_array_equal(dyn, comp)
8989

9090

9191
@eval_modes()
@@ -333,3 +333,25 @@ def make_reader():
333333
assert len(dynamic_results) == len(compiled_results)
334334
for dyn, comp in zip(dynamic_results, compiled_results):
335335
np.testing.assert_array_equal(dyn, comp)
336+
337+
338+
def test_compile_incompatible_kwarg_dtype():
    """Compare dynamic and compiled runs when a batch is passed as ``sizes``.

    Each mode decodes images on the GPU and resizes every image to its own
    shape, feeding ``ndd._shape(img)`` as the ``sizes`` keyword argument.
    NOTE(review): presumably the shape batch's dtype is not the one
    ``sizes`` expects, which is what exercises the automatic cast - confirm.
    """
    # Two independent readers over the same root, one per execution mode.
    reader_dyn = ndd.readers.File(file_root=images_root)
    reader_comp = ndd.readers.File(file_root=images_root)

    # Reference pass: compilation disabled.
    dynamic_results = []
    for encoded, _ in reader_dyn.next_epoch(batch_size=4, compile=False):
        decoded = ndd.decoders.image(encoded, device="gpu")
        same_size = ndd.tensor_resize(decoded, sizes=ndd._shape(decoded))
        dynamic_results.append(ndd.as_tensor(same_size, pad=True).cpu())

    # Pass under test: compilation enabled; the result must be compiled.
    compiled_results = []
    for encoded, _ in reader_comp.next_epoch(batch_size=4, compile=True):
        decoded = ndd.decoders.image(encoded, device="gpu")
        same_size = ndd.tensor_resize(decoded, sizes=ndd._shape(decoded))
        assert _is_compiled(same_size), same_size
        compiled_results.append(ndd.as_tensor(same_size, pad=True).cpu())

    # Both modes must yield the same number of batches with equal contents.
    assert len(dynamic_results) == len(compiled_results)
    for expected, actual in zip(dynamic_results, compiled_results):
        np.testing.assert_array_equal(expected, actual)

0 commit comments

Comments
 (0)