Skip to content

Commit e1b8f7b

Browse files
authored
Fix numba-cuda installation for Python 3.9 (#6097)
- The latest numba-cuda release (0.21) is not compatible with Python 3.9. - This PR restricts the numba-cuda version to 0.20.x for Python 3.9, while allowing newer versions for other Python versions. Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent d3d98d2 commit e1b8f7b

2 files changed

Lines changed: 20 additions & 3 deletions

File tree

qa/TL0_python-self-test-operators_1/test_numba_cuda.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash -e
22
# used pip packages
3-
pip_packages='${python_test_runner_package} dataclasses numpy opencv-python pillow librosa scipy nvidia-ml-py==11.450.51 numba lz4 numba_cuda[cu${DALI_CUDA_MAJOR_VERSION}]>0.19.0'
3+
pip_packages='${python_test_runner_package} dataclasses numpy opencv-python pillow librosa scipy nvidia-ml-py==11.450.51 numba lz4 numba-cuda'
44

55
target_dir=./dali/test/python
66

qa/setup_packages.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def get_name(self, cuda_version=None, idx=None):
132132
cuda_version : str, optional, default = None
133133
Cuda version used for this query
134134
idx : int
135-
Index of name to retrive in case of specific version has different alias
135+
Index of name to retrieve in case of specific version has different alias
136136
"""
137137
name = BasePackage.get_alias(self.get_version(idx, cuda_version))
138138
if name is None:
@@ -181,7 +181,11 @@ def get_version(self, idx, cuda_version=None):
181181
if idx is None:
182182
idx = 0
183183
idx = self.clamp_index(idx, cuda_version)
184-
return self.get_all_versions(cuda_version)[idx]
184+
versions = self.get_all_versions(cuda_version)
185+
if len(versions):
186+
return versions[idx]
187+
else:
188+
return None
185189

186190
def get_all_versions(self, cuda_version=None):
187191
"""Get all versions compatible with provided cuda_version
@@ -579,6 +583,19 @@ def get_pyvers_name(self, url, cuda_version):
579583
]
580584
},
581585
),
586+
CudaPackage(
587+
"numba-cuda",
588+
{
589+
"120": [
590+
PckgVer(
591+
"0.20.1", python_min_ver="3.9", python_max_ver="3.9", dependencies=["numpy<2"]
592+
),
593+
PckgVer("0.21.1", python_min_ver="3.10", dependencies=["numpy<2"]),
594+
]
595+
},
596+
# name used during installation
597+
name="numba-cuda[cu{cuda_v[0]}{cuda_v[1]}]",
598+
),
582599
]
583600

584601
all_packages_keys = [pckg.key for pckg in all_packages]

0 commit comments

Comments
 (0)