Skip to content

Commit 7ae8b15

Browse files
authored
Fix constructing GPU Tensor from DLPack capsule (#3711)
- When a DLPack capsule is provided and a GPU Tensor is created with the provided device_id, DALI attempts to treat it as an object implementing __cuda_array_interface__. This PR makes sure the right GPU Tensor constructor is called depending on whether the input is a DLPack capsule or a __cuda_array_interface__ object.
- Fixes a regression from #3710.
Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent 8d451d8 commit 7ae8b15

1 file changed

Lines changed: 20 additions & 7 deletions

File tree

dali/python/nvidia/dali/external_source.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,18 @@ def to_numpy(x):
7474
inputs = []
7575
checked = False
7676
for datum in data:
77-
info = _b.CheckDLPackCapsule(datum)
78-
if not info[0] and not checked:
77+
(is_dlpack, is_gpu_data) = _b.CheckDLPackCapsule(datum)
78+
if not is_dlpack and not checked:
7979
_check_data_batch(data, batch_size, layout)
8080
checked = True
8181
if isinstance(datum, (_tensors.TensorCPU, _tensors.TensorGPU)):
8282
inp = type(datum)(datum, layout=layout) if layout is not None else datum
83-
elif hasattr(datum, "__cuda_array_interface__") or (info[0] and info[1]):
83+
elif is_dlpack:
84+
if is_gpu_data:
85+
inp = _tensors.TensorGPU(datum, layout or "")
86+
else:
87+
inp = _tensors.TensorCPU(datum, layout or "")
88+
elif hasattr(datum, "__cuda_array_interface__"):
8489
array_device_id = _types._get_device_id_for_array(datum)
8590
if array_device_id is None:
8691
array_device_id = device_id
@@ -93,11 +98,19 @@ def to_numpy(x):
9398
"Mixed input types are not support, all need to reside on the CPU or GPU"
9499
data = inputs
95100
else:
96-
info = _b.CheckDLPackCapsule(data)
97-
if not info[0]:
101+
(is_dlpack, is_gpu_data) = _b.CheckDLPackCapsule(data)
102+
if not is_dlpack:
98103
_check_data_batch(data, batch_size, layout)
99-
if hasattr(data, "__cuda_array_interface__") or (info[0] and info[1]):
100-
data = _tensors.TensorListGPU(data, layout or "")
104+
if hasattr(data, "__cuda_array_interface__"):
105+
array_device_id = _types._get_device_id_for_array(data)
106+
if array_device_id is None:
107+
array_device_id = device_id
108+
data = _tensors.TensorListGPU(data, layout or "", array_device_id)
109+
elif is_dlpack:
110+
if is_gpu_data:
111+
data = _tensors.TensorListGPU(data, layout or "")
112+
else:
113+
data = _tensors.TensorListCPU(data, layout or "")
101114
else:
102115
data = to_numpy(data)
103116
data = _tensors.TensorListCPU(data, layout or "")

0 commit comments

Comments
 (0)