22#
33# SPDX-License-Identifier: Apache-2.0
44
5+ import weakref
56from collections import namedtuple
67from typing import Optional , Union
78from warnings import warn
@@ -60,12 +61,12 @@ class KernelAttributes:
6061 def __new__ (self , * args , ** kwargs ):
6162 raise RuntimeError ("KernelAttributes cannot be instantiated directly. Please use Kernel APIs." )
6263
63- slots = ("_handle " , "_cache" , "_backend_version" , "_loader" )
64+ slots = ("_kernel " , "_cache" , "_backend_version" , "_loader" )
6465
6566 @classmethod
66- def _init (cls , handle ):
67+ def _init (cls , kernel ):
6768 self = super ().__new__ (cls )
68- self ._handle = handle
69+ self ._kernel = weakref . ref ( kernel )
6970 self ._cache = {}
7071
7172 self ._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000 ) else "old"
@@ -74,20 +75,23 @@ def _init(cls, handle):
7475
7576 def _get_cached_attribute (self , device_id : int , attribute : driver .CUfunction_attribute ) -> int :
7677 """Helper function to get a cached attribute or fetch and cache it if not present."""
77- if device_id in self ._cache and attribute in self ._cache [device_id ]:
78- return self ._cache [device_id ][attribute ]
78+ cache_key = device_id , attribute
79+ result = self ._cache .get (cache_key , cache_key )
80+ if result is not cache_key :
81+ return result
82+ kernel = self ._kernel ()
83+ if kernel is None :
84+ raise RuntimeError ("Cannot access kernel attributes for expired Kernel object" )
7985 if self ._backend_version == "new" :
80- result = handle_return (self ._loader ["attribute" ](attribute , self ._handle , device_id ))
86+ result = handle_return (self ._loader ["attribute" ](attribute , kernel ._handle , device_id ))
8187 else : # "old" backend
8288 warn (
8389 "Device ID argument is ignored when getting attribute from kernel when cuda version < 12. " ,
8490 RuntimeWarning ,
8591 stacklevel = 2 ,
8692 )
87- result = handle_return (self ._loader ["attribute" ](attribute , self ._handle ))
88- if device_id not in self ._cache :
89- self ._cache [device_id ] = {}
90- self ._cache [device_id ][attribute ] = result
93+ result = handle_return (self ._loader ["attribute" ](attribute , kernel ._handle ))
94+ self ._cache [cache_key ] = result
9195 return result
9296
9397 def max_threads_per_block (self , device_id : int = None ) -> int :
@@ -365,7 +369,7 @@ class Kernel:
365369
366370 """
367371
368- __slots__ = ("_handle" , "_module" , "_attributes" , "_occupancy" )
372+ __slots__ = ("_handle" , "_module" , "_attributes" , "_occupancy" , "__weakref__" )
369373
370374 def __new__ (self , * args , ** kwargs ):
371375 raise RuntimeError ("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs." )
@@ -385,7 +389,7 @@ def _from_obj(cls, obj, mod):
385389 def attributes (self ) -> KernelAttributes :
386390 """Get the read-only attributes of this kernel."""
387391 if self ._attributes is None :
388- self ._attributes = KernelAttributes ._init (self . _handle )
392+ self ._attributes = KernelAttributes ._init (self )
389393 return self ._attributes
390394
391395 def _get_arguments_info (self , param_info = False ) -> tuple [int , list [ParamInfo ]]:
0 commit comments