@@ -77,7 +77,7 @@ def _parse_arch_str(arch_str: str) -> Tuple[int, int]:
7777
7878
7979@lru_cache
80- def get_device_capacity (device : torch .device = None ) -> Tuple [int , int ]:
80+ def _get_device_capacity_cached (device : torch .device = None ) -> Tuple [int , int ]:
8181 """Return (major, minor) device capability.
8282
8383 Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
@@ -89,6 +89,23 @@ def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
8989 return torch .cuda .get_device_capability (device )
9090
9191
92+ def get_device_capacity (
93+ device : torch .device | torch .Tensor | None = None ,
94+ ) -> Tuple [int , int ]:
95+ """Return (major, minor) device capability.
96+
97+ Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
98+ without a GPU present.
99+
100+ Accepts either a ``torch.device`` or a tensor and canonicalizes to the
101+ underlying device before consulting the cached helper. This avoids leaking
102+ tensors through the LRU cache key.
103+ """
104+ if isinstance (device , torch .Tensor ):
105+ device = device .device
106+ return _get_device_capacity_cached (device )
107+
108+
92109def _partition_fields (obj ):
93110 """Split dataclass fields into (constexpr_dict, non_constexpr_dict) by type."""
94111 all_fields = {field .name : getattr (obj , field .name ) for field in fields (obj )}
0 commit comments