Commit 74798e7

Committed by: leofang, emcastillo, claude
Cache type check in _is_torch_tensor for ~20% speedup
Cache the result of the torch tensor type check (module + hasattr + version) keyed by type(obj). Subsequent calls for the same type are a single dict lookup (~76 ns) instead of the full check (~186 ns). Non-torch objects also benefit, as the cache returns False immediately after the first miss.

Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0f57646 commit 74798e7

1 file changed: cuda_core/cuda/core/_memoryview.pyx (11 additions, 2 deletions)
@@ -36,6 +36,9 @@ from cuda.core._memory import Buffer
 # ---------------------------------------------------------------------------

 cdef object _tensor_bridge = None
+# Cache: type(obj) -> True/False for the torch tensor check.
+# Once a type is seen, we never re-check.
+cdef dict _torch_type_cache = {}
 # Tri-state: None = not checked, True/False = result of version check
 cdef object _torch_version_ok = None

@@ -58,9 +61,15 @@ cdef inline bint _torch_version_check():


 cdef inline bint _is_torch_tensor(object obj):
-    cdef str mod = type(obj).__module__ or ""
-    return mod.startswith("torch") and hasattr(obj, "data_ptr") \
+    cdef type tp = type(obj)
+    cdef object cached = _torch_type_cache.get(tp)
+    if cached is not None:
+        return <bint>cached
+    cdef str mod = tp.__module__ or ""
+    cdef bint result = mod.startswith("torch") and hasattr(obj, "data_ptr") \
         and _torch_version_check()
+    _torch_type_cache[tp] = result
+    return result


 cdef object _get_tensor_bridge():
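The speedup claim can be probed with a micro-benchmark along these lines (a sketch using hypothetical pure-Python stand-ins; absolute numbers depend on hardware and will not match the ~76 ns / ~186 ns figures quoted for the Cython code):

```python
import timeit

_cache = {}

def check_uncached(obj):
    # The full check: module prefix + attribute lookup on every call.
    mod = type(obj).__module__ or ""
    return mod.startswith("torch") and hasattr(obj, "data_ptr")

def check_cached(obj):
    # The memoized variant: one dict lookup after the first call per type.
    tp = type(obj)
    cached = _cache.get(tp)
    if cached is not None:
        return cached
    result = check_uncached(obj)
    _cache[tp] = result
    return result

obj = object()
check_cached(obj)  # warm the cache so timing measures only the fast path

t_full = min(timeit.repeat(lambda: check_uncached(obj), number=100_000, repeat=5))
t_cached = min(timeit.repeat(lambda: check_cached(obj), number=100_000, repeat=5))
print(f"uncached: {t_full:.4f}s  cached: {t_cached:.4f}s per 100k calls")
```

Taking the minimum over several repeats reduces timer noise; both paths must of course agree on the result for every object.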
