Skip to content

Commit 50ed678

Browse files
aryanputtaleofang
andauthored
Validate checkpoint GPU UUID inputs early (#2086)
* Validate checkpoint GPU UUID inputs Signed-off-by: Aryan <aryansputta@gmail.com> * Narrow checkpoint GPU UUID restore inputs * Restore checkpoint CUuuid compatibility * Retry CI after infra failures Signed-off-by: Aryan Putta <aryansputta@gmail.com> --------- Signed-off-by: Aryan <aryansputta@gmail.com> Signed-off-by: Aryan Putta <aryansputta@gmail.com> Co-authored-by: Leo Fang <leof@nvidia.com>
1 parent c270969 commit 50ed678

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

cuda_core/cuda/core/checkpoint.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,14 +234,21 @@ def _as_cuuuid(driver: Any, value: Any, buffers: list[Any]) -> Any:
234234
the ``"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"`` format returned by
235235
:attr:`Device.uuid`.
236236
"""
237+
if isinstance(value, driver.CUuuid):
238+
return value
237239
if isinstance(value, str):
238-
raw = bytes.fromhex(value.replace("-", ""))
240+
try:
241+
raw = bytes.fromhex(value.replace("-", ""))
242+
except ValueError:
243+
raise ValueError(
244+
f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}"
245+
) from None
239246
if len(raw) != 16:
240247
raise ValueError(f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}")
241248
buf = _ctypes.create_string_buffer(raw, 16)
242249
buffers.append(buf)
243250
return driver.CUuuid(_ctypes.addressof(buf))
244-
return value
251+
raise TypeError("GPU UUID values must be CUDA UUID objects or UUID strings")
245252

246253

247254
__all__ = [

0 commit comments

Comments
 (0)