1818from __future__ import annotations
1919
2020import abc
21+ import contextlib
2122import hashlib
2223import os
2324import pickle
2425import tempfile
2526import time
2627from pathlib import Path
27- from typing import Hashable , Iterable , Optional , Sequence
28+ from typing import Hashable , Iterable , Sequence
2829
2930from cuda .core ._module import ObjectCode
3031from cuda .core ._program import ProgramOptions
3132from cuda .core ._utils .cuda_utils import (
3233 driver as _driver ,
34+ )
35+ from cuda .core ._utils .cuda_utils import (
3336 handle_return as _handle_return ,
37+ )
38+ from cuda .core ._utils .cuda_utils import (
3439 nvrtc as _nvrtc ,
3540)
3641
4752
4853def _require_object_code (value : object ) -> ObjectCode :
4954 if not isinstance (value , ObjectCode ):
50- raise TypeError (
51- f"cache values must be ObjectCode instances, got { type (value ).__name__ } "
52- )
55+ raise TypeError (f"cache values must be ObjectCode instances, got { type (value ).__name__ } " )
5356 return value
5457
5558
@@ -131,10 +134,10 @@ def get(self, key: Hashable, default: ObjectCode | None = None) -> ObjectCode |
131134 except KeyError :
132135 return default
133136
134- def close (self ) -> None :
137+ def close (self ) -> None : # noqa: B027
135138 """Release backend resources. No-op by default."""
136139
137- def __enter__ (self ) -> " ProgramCacheResource" :
140+ def __enter__ (self ) -> ProgramCacheResource :
138141 return self
139142
140143 def __exit__ (self , exc_type , exc_value , traceback ) -> None :
@@ -214,24 +217,16 @@ def make_program_cache_key(
214217 A 32-byte blake2b digest suitable for use as a cache key.
215218 """
216219 if code_type not in _VALID_CODE_TYPES :
217- raise ValueError (
218- f"code_type={ code_type !r} is not supported "
219- f"(must be one of { sorted (_VALID_CODE_TYPES )} )"
220- )
220+ raise ValueError (f"code_type={ code_type !r} is not supported (must be one of { sorted (_VALID_CODE_TYPES )} )" )
221221 if target_type not in _VALID_TARGET_TYPES :
222- raise ValueError (
223- f"target_type={ target_type !r} is not supported "
224- f"(must be one of { sorted (_VALID_TARGET_TYPES )} )"
225- )
222+ raise ValueError (f"target_type={ target_type !r} is not supported (must be one of { sorted (_VALID_TARGET_TYPES )} )" )
226223
227224 if isinstance (code , str ):
228225 code_bytes = code .encode ("utf-8" )
229226 elif isinstance (code , (bytes , bytearray )):
230227 code_bytes = bytes (code )
231228 else :
232- raise TypeError (
233- f"code must be str or bytes, got { type (code ).__name__ } "
234- )
229+ raise TypeError (f"code must be str or bytes, got { type (code ).__name__ } " )
235230
236231 backend = _backend_for_code_type (code_type )
237232 # ProgramOptions.as_bytes may or may not accept target_type depending on
@@ -328,7 +323,7 @@ def __init__(
328323 self ,
329324 path : str | os .PathLike ,
330325 * ,
331- max_size_bytes : Optional [ int ] = None ,
326+ max_size_bytes : int | None = None ,
332327 ) -> None :
333328 if max_size_bytes is not None and max_size_bytes < 0 :
334329 raise ValueError ("max_size_bytes must be non-negative or None" )
@@ -338,7 +333,7 @@ def __init__(
338333 import sqlite3
339334
340335 self ._sqlite3 = sqlite3
341- self ._conn : Optional [ sqlite3 .Connection ] = None
336+ self ._conn : sqlite3 .Connection | None = None
342337 self ._open ()
343338
344339 # -- lifecycle -----------------------------------------------------------
@@ -396,29 +391,25 @@ def _require_open(self):
396391
397392 def __contains__ (self , key : object ) -> bool :
398393 k = _as_key_bytes (key )
399- row = self ._require_open ().execute (
400- "SELECT 1 FROM entries WHERE key = ?" , (k ,)
401- ).fetchone ()
394+ row = self ._require_open ().execute ("SELECT 1 FROM entries WHERE key = ?" , (k ,)).fetchone ()
402395 return row is not None
403396
404397 def __getitem__ (self , key : object ) -> ObjectCode :
405398 k = _as_key_bytes (key )
406399 conn = self ._require_open ()
407- row = conn .execute (
408- "SELECT payload FROM entries WHERE key = ?" , (k ,)
409- ).fetchone ()
400+ row = conn .execute ("SELECT payload FROM entries WHERE key = ?" , (k ,)).fetchone ()
410401 if row is None :
411402 raise KeyError (key )
412403 payload = row [0 ]
413404 try :
414- value = pickle .loads (payload )
405+ value = pickle .loads (payload ) # noqa: S301
415406 except Exception :
416407 # Corrupt entry -- delete and treat as a miss.
417408 conn .execute ("DELETE FROM entries WHERE key = ?" , (k ,))
418- raise KeyError (key )
409+ raise KeyError (key ) from None
419410 if not isinstance (value , ObjectCode ):
420411 conn .execute ("DELETE FROM entries WHERE key = ?" , (k ,))
421- raise KeyError (key )
412+ raise KeyError (key ) from None
422413 conn .execute (
423414 "UPDATE entries SET accessed_at = ? WHERE key = ?" ,
424415 (time .time (), k ),
@@ -452,9 +443,7 @@ def __delitem__(self, key: object) -> None:
452443 raise KeyError (key )
453444
454445 def __len__ (self ) -> int :
455- (n ,) = self ._require_open ().execute (
456- "SELECT COUNT(*) FROM entries"
457- ).fetchone ()
446+ (n ,) = self ._require_open ().execute ("SELECT COUNT(*) FROM entries" ).fetchone ()
458447 return int (n )
459448
460449 def clear (self ) -> None :
@@ -466,9 +455,7 @@ def _enforce_size_cap(self) -> None:
466455 if self ._max_size_bytes is None :
467456 return
468457 conn = self ._require_open ()
469- (total ,) = conn .execute (
470- "SELECT COALESCE(SUM(size_bytes), 0) FROM entries"
471- ).fetchone ()
458+ (total ,) = conn .execute ("SELECT COALESCE(SUM(size_bytes), 0) FROM entries" ).fetchone ()
472459 if total <= self ._max_size_bytes :
473460 return
474461 # Delete oldest (least-recently-used) until at or under the cap.
@@ -514,7 +501,7 @@ def __init__(
514501 self ,
515502 path : str | os .PathLike ,
516503 * ,
517- max_size_bytes : Optional [ int ] = None ,
504+ max_size_bytes : int | None = None ,
518505 ) -> None :
519506 if max_size_bytes is not None and max_size_bytes < 0 :
520507 raise ValueError ("max_size_bytes must be non-negative or None" )
@@ -548,29 +535,25 @@ def __getitem__(self, key: object) -> ObjectCode:
548535 try :
549536 data = path .read_bytes ()
550537 except FileNotFoundError :
551- raise KeyError (key )
538+ raise KeyError (key ) from None
552539 k = _as_key_bytes (key )
553540 try :
554- record = pickle .loads (data )
541+ record = pickle .loads (data ) # noqa: S301
555542 schema , stored_key , payload , _created_at = record
556543 if schema != _FILESTREAM_SCHEMA_VERSION :
557544 raise ValueError (f"unknown schema { schema } " )
558545 if stored_key != k :
559546 raise ValueError ("key mismatch" )
560- value = pickle .loads (payload )
547+ value = pickle .loads (payload ) # noqa: S301
561548 except Exception :
562549 # Corrupt entry -- delete and treat as a miss.
563- try :
550+ with contextlib . suppress ( FileNotFoundError ) :
564551 path .unlink ()
565- except FileNotFoundError :
566- pass
567- raise KeyError (key )
552+ raise KeyError (key ) from None
568553 if not isinstance (value , ObjectCode ):
569- try :
554+ with contextlib . suppress ( FileNotFoundError ) :
570555 path .unlink ()
571- except FileNotFoundError :
572- pass
573- raise KeyError (key )
556+ raise KeyError (key ) from None
574557 return value
575558
576559 def __setitem__ (self , key : object , value : object ) -> None :
@@ -594,10 +577,8 @@ def __setitem__(self, key: object, value: object) -> None:
594577 os .fsync (fh .fileno ())
595578 os .replace (tmp_path , target )
596579 except BaseException :
597- try :
580+ with contextlib . suppress ( FileNotFoundError ) :
598581 tmp_path .unlink ()
599- except FileNotFoundError :
600- pass
601582 raise
602583 self ._enforce_size_cap ()
603584
@@ -606,7 +587,7 @@ def __delitem__(self, key: object) -> None:
606587 try :
607588 path .unlink ()
608589 except FileNotFoundError :
609- raise KeyError (key )
590+ raise KeyError (key ) from None
610591
611592 def __len__ (self ) -> int :
612593 count = 0
@@ -616,18 +597,14 @@ def __len__(self) -> int:
616597
617598 def clear (self ) -> None :
618599 for path in list (self ._iter_entry_paths ()):
619- try :
600+ with contextlib . suppress ( FileNotFoundError ) :
620601 path .unlink ()
621- except FileNotFoundError :
622- pass
623602 # Remove empty subdirs (best-effort; concurrent writers may re-create).
624603 if self ._entries .exists ():
625604 for sub in sorted (self ._entries .iterdir (), reverse = True ):
626605 if sub .is_dir ():
627- try :
606+ with contextlib . suppress ( OSError ) :
628607 sub .rmdir ()
629- except OSError :
630- pass
631608
632609 # -- internals -----------------------------------------------------------
633610
@@ -659,8 +636,6 @@ def _enforce_size_cap(self) -> None:
659636 for _mtime , size , path in entries :
660637 if total <= self ._max_size_bytes :
661638 return
662- try :
639+ with contextlib . suppress ( FileNotFoundError ) :
663640 path .unlink ()
664641 total -= size
665- except FileNotFoundError :
666- pass
0 commit comments