Skip to content
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions cuda_core/cuda/core/experimental/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,34 +957,42 @@ def __new__(cls, device_id=None):

# important: creating a Device instance does not initialize the GPU!
if device_id is None:
device_id = handle_return(runtime.cudaGetDevice())
assert_type(device_id, int)
err, dev = driver.cuCtxGetDevice()
if err == 0:
Comment thread
leofang marked this conversation as resolved.
Outdated
device_id = int(dev)
else:
ctx = handle_return(driver.cuCtxGetCurrent())
Comment thread
leofang marked this conversation as resolved.
Comment thread
leofang marked this conversation as resolved.
Outdated
assert int(ctx) == 0
device_id = 0 # cudart behavior
assert isinstance(device_id, int), f"{device_id=}"
Comment thread
leofang marked this conversation as resolved.
Outdated
else:
total = handle_return(runtime.cudaGetDeviceCount())
assert_type(device_id, int)
if not (0 <= device_id < total):
total = handle_return(driver.cuDeviceGetCount())
Comment thread
leofang marked this conversation as resolved.
Outdated
if not isinstance(device_id, int) or not (0 <= device_id < total):
raise ValueError(f"device_id must be within [0, {total}), got {device_id}")

# ensure Device is singleton
if not hasattr(_tls, "devices"):
total = handle_return(runtime.cudaGetDeviceCount())
total = handle_return(driver.cuDeviceGetCount())
Comment thread
leofang marked this conversation as resolved.
_tls.devices = []
for dev_id in range(total):
Comment thread
leofang marked this conversation as resolved.
dev = super().__new__(cls)

dev._id = dev_id
# If the device is in TCC mode, or does not support memory pools for some other reason,
# use the SynchronousMemoryResource which does not use memory pools.
if (
handle_return(
runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
)
)
) == 1:
dev._mr = _DefaultAsyncMempool(dev_id)
else:
dev._mr = _SynchronousMemoryResource(dev_id)

dev._has_inited = False
dev._properties = None

_tls.devices.append(dev)

return _tls.devices[device_id]
Expand Down
Loading