5353ConstraintArg : TypeAlias = int | str | tuple [str , int ] | None
5454
5555_CTK_VERSION_RE = re .compile (r"^(?P<major>\d+)\.(?P<minor>\d+)" )
56- _REQUIRES_DIST_RE = re .compile (
57- r"^\s*(?P<name>[A-Za-z0-9_.-]+)\s*==\s*(?P<version>[0-9][A-Za-z0-9.+-]*?)(?:\.\*)?(?:\s*;|$)"
58- )
56+ _REQUIRES_DIST_RE = re .compile (r"^\s*(?P<name>[A-Za-z0-9_.-]+)\s*(?P<specifier_text>[^;]*)(?:\s*;|$)" )
57+ _VERSION_SPECIFIER_RE = re .compile (r"^\s*(?P<operator>==|<=|>=|<|>)\s*(?P<version>[0-9][A-Za-z0-9.+-]*?(?:\.\*)?)\s*$" )
5958
6059_STATIC_LIBS_PACKAGED_WITH : dict [str , PackagedWith ] = {
6160 "cudadevrt" : "ctk" ,
@@ -113,6 +112,12 @@ def __str__(self) -> str:
113112 return f"{ self .operator } { self .value } "
114113
115114
115+ @dataclass (frozen = True , slots = True )
116+ class VersionSpecifier :
117+ operator : ConstraintOperator
118+ version : str
119+
120+
116121@dataclass (frozen = True , slots = True )
117122class ResolvedItem :
118123 name : str
@@ -185,6 +190,63 @@ def _distribution_name(dist: importlib.metadata.Distribution) -> str | None:
185190 return metadata .get ("Name" )
186191
187192
193+ def _release_version_parts (version : str ) -> tuple [int , ...] | None :
194+ match = re .match (r"^\d+(?:\.\d+)*" , version )
195+ if match is None :
196+ return None
197+ return tuple (int (part ) for part in match .group (0 ).split ("." ))
198+
199+
200+ def _compare_release_versions (lhs : tuple [int , ...], rhs : tuple [int , ...]) -> int :
201+ max_len = max (len (lhs ), len (rhs ))
202+ lhs_padded = lhs + (0 ,) * (max_len - len (lhs ))
203+ rhs_padded = rhs + (0 ,) * (max_len - len (rhs ))
204+ if lhs_padded < rhs_padded :
205+ return - 1
206+ if lhs_padded > rhs_padded :
207+ return 1
208+ return 0
209+
210+
211+ def _parse_version_specifiers (specifier_text : str ) -> tuple [VersionSpecifier , ...]:
212+ stripped = specifier_text .strip ()
213+ if not stripped :
214+ return ()
215+ parsed : list [VersionSpecifier ] = []
216+ for raw_clause in stripped .split ("," ):
217+ match = _VERSION_SPECIFIER_RE .match (raw_clause )
218+ if match is None :
219+ return ()
220+ parsed .append (VersionSpecifier (operator = match .group ("operator" ), version = match .group ("version" )))
221+ return tuple (parsed )
222+
223+
224+ def _version_satisfies_specifiers (version : str , specifiers : tuple [VersionSpecifier , ...]) -> bool :
225+ if not specifiers :
226+ return False
227+ for specifier in specifiers :
228+ if specifier .operator == "==" :
229+ prefix = specifier .version .removesuffix (".*" )
230+ if version == prefix or version .startswith (prefix + "." ):
231+ continue
232+ return False
233+ candidate_parts = _release_version_parts (version )
234+ required_parts = _release_version_parts (specifier .version )
235+ if candidate_parts is None or required_parts is None :
236+ return False
237+ comparison = _compare_release_versions (candidate_parts , required_parts )
238+ if specifier .operator == "<" and comparison < 0 :
239+ continue
240+ if specifier .operator == "<=" and comparison <= 0 :
241+ continue
242+ if specifier .operator == ">" and comparison > 0 :
243+ continue
244+ if specifier .operator == ">=" and comparison >= 0 :
245+ continue
246+ return False
247+ return True
248+
249+
188250@functools .cache
189251def _owned_distribution_candidates (abs_path : str ) -> tuple [tuple [str , str ], ...]:
190252 normalized_abs_path = os .path .normpath (os .path .abspath (abs_path ))
@@ -201,27 +263,42 @@ def _owned_distribution_candidates(abs_path: str) -> tuple[tuple[str, str], ...]
201263
202264
203265@functools .cache
204- def _cuda_toolkit_requirement_maps () -> tuple [tuple [str , CtkVersion , dict [str , tuple [str , ...]]], ...]:
205- results : list [tuple [str , CtkVersion , dict [str , tuple [str , ...]]]] = []
266+ def _cuda_toolkit_requirement_maps () -> tuple [
267+ tuple [str , CtkVersion , dict [str , tuple [tuple [VersionSpecifier , ...], ...]]], ...
268+ ]:
269+ results : list [tuple [str , CtkVersion , dict [str , tuple [tuple [VersionSpecifier , ...], ...]]]] = []
206270 for dist in importlib .metadata .distributions ():
207271 dist_name = _distribution_name (dist )
208272 if _normalize_distribution_name (dist_name or "" ) != "cuda-toolkit" :
209273 continue
210274 ctk_version = _parse_ctk_version (dist .version )
211275 if ctk_version is None :
212276 continue
213- requirement_map : dict [str , set [str ]] = {}
277+ requirement_map : dict [str , set [tuple [ VersionSpecifier , ...] ]] = {}
214278 for requirement in dist .requires or ():
215279 match = _REQUIRES_DIST_RE .match (requirement )
216280 if match is None :
217281 continue
218282 req_name = _normalize_distribution_name (match .group ("name" ))
219- requirement_map .setdefault (req_name , set ()).add (match .group ("version" ))
283+ parsed_specifiers = _parse_version_specifiers (match .group ("specifier_text" ))
284+ if not parsed_specifiers :
285+ continue
286+ requirement_map .setdefault (req_name , set ()).add (parsed_specifiers )
220287 results .append (
221288 (
222289 dist .version ,
223290 ctk_version ,
224- {name : tuple (sorted (prefixes )) for name , prefixes in requirement_map .items ()},
291+ {
292+ name : tuple (
293+ sorted (
294+ specifier_sets ,
295+ key = lambda specifiers : tuple (
296+ (specifier .operator , specifier .version ) for specifier in specifiers
297+ ),
298+ )
299+ )
300+ for name , specifier_sets in requirement_map .items ()
301+ },
225302 )
226303 )
227304 return tuple (results )
@@ -232,9 +309,9 @@ def _wheel_metadata_for_abs_path(abs_path: str) -> CtkMetadata | None:
232309 for owner_name , owner_version in _owned_distribution_candidates (abs_path ):
233310 normalized_owner_name = _normalize_distribution_name (owner_name )
234311 for toolkit_dist_version , ctk_version , requirement_map in _cuda_toolkit_requirement_maps ():
235- requirement_prefixes = requirement_map .get (normalized_owner_name , ())
312+ requirement_specifier_sets = requirement_map .get (normalized_owner_name , ())
236313 if not any (
237- owner_version == prefix or owner_version . startswith ( prefix + "." ) for prefix in requirement_prefixes
314+ _version_satisfies_specifiers ( owner_version , specifiers ) for specifiers in requirement_specifier_sets
238315 ):
239316 continue
240317 matched_versions [ctk_version ] = (
0 commit comments