Skip to content

Commit a7f7b41

Browse files
authored
prevent get_device_properties LRU cache from extending torch.tensor lifetimes (Dao-AILab#102)
1 parent 53d9af8 commit a7f7b41

1 file changed

Lines changed: 18 additions & 1 deletion

File tree

quack/cute_dsl_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _parse_arch_str(arch_str: str) -> Tuple[int, int]:
7777

7878

7979
@lru_cache
80-
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
80+
def _get_device_capacity_cached(device: torch.device = None) -> Tuple[int, int]:
8181
"""Return (major, minor) device capability.
8282
8383
Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
@@ -89,6 +89,23 @@ def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
8989
return torch.cuda.get_device_capability(device)
9090

9191

92+
def get_device_capacity(
93+
device: torch.device | torch.Tensor | None = None,
94+
) -> Tuple[int, int]:
95+
"""Return (major, minor) device capability.
96+
97+
Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
98+
without a GPU present.
99+
100+
Accepts either a ``torch.device`` or a tensor and canonicalizes to the
101+
underlying device before consulting the cached helper. This avoids leaking
102+
tensors through the LRU cache key.
103+
"""
104+
if isinstance(device, torch.Tensor):
105+
device = device.device
106+
return _get_device_capacity_cached(device)
107+
108+
92109
def _partition_fields(obj):
93110
"""Split dataclass fields into (constexpr_dict, non_constexpr_dict) by type."""
94111
all_fields = {field.name: getattr(obj, field.name) for field in fields(obj)}

0 commit comments

Comments
 (0)