22# SPDX-License-Identifier: Apache-2.0
33
44import importlib
5- import json
65import os
76from pathlib import Path
87
@@ -36,10 +35,22 @@ def _default_process_wide_guard_rails_mode(monkeypatch):
3635 monkeypatch .delenv (COMPATIBILITY_GUARD_RAILS_ENV_VAR , raising = False )
3736
3837
39- def _write_version_json (ctk_root : Path , toolkit_version : str ) -> None :
40- ctk_root .mkdir (parents = True , exist_ok = True )
41- payload = {"cuda" : {"version" : toolkit_version }}
42- (ctk_root / "version.json" ).write_text (json .dumps (payload ), encoding = "utf-8" )
38+ def _write_cuda_h (
39+ ctk_root : Path ,
40+ toolkit_version : str ,
41+ * ,
42+ include_dir_parts : tuple [str , ...] = ("targets" , "x86_64-linux" , "include" ),
43+ ) -> None :
44+ parts = toolkit_version .split ("." )
45+ if len (parts ) < 2 :
46+ raise AssertionError (f"Expected at least major.minor in toolkit version, got { toolkit_version !r} " )
47+ encoded = int (parts [0 ]) * 1000 + int (parts [1 ]) * 10
48+ cuda_h_path = ctk_root .joinpath (* include_dir_parts , "cuda.h" )
49+ cuda_h_path .parent .mkdir (parents = True , exist_ok = True )
50+ cuda_h_path .write_text (
51+ f"#ifndef CUDA_H\n #define CUDA_H\n #define CUDA_VERSION { encoded } \n #endif\n " ,
52+ encoding = "utf-8" ,
53+ )
4354
4455
4556def _touch (path : Path ) -> str :
@@ -76,11 +87,7 @@ def _located_bitcode_lib(name: str, abs_path: str) -> LocatedBitcodeLib:
7687
7788
7889def _driver_cuda_version (encoded : int ) -> DriverCudaVersion :
79- return DriverCudaVersion (
80- encoded = encoded ,
81- major = encoded // 1000 ,
82- minor = (encoded % 1000 ) // 10 ,
83- )
90+ return DriverCudaVersion .from_encoded (encoded )
8491
8592
8693class _FakeDistribution :
@@ -111,8 +118,9 @@ def _assert_real_ctk_backed_path(path: str) -> None:
111118 if current .is_file ():
112119 current = current .parent
113120 for candidate in (current , * current .parents ):
114- version_json_path = candidate / "version.json"
115- if version_json_path .is_file ():
121+ if (candidate / "include" / "cuda.h" ).is_file ():
122+ return
123+ if any (path .is_file () for path in (candidate / "targets" ).glob ("*/include/cuda.h" )):
116124 return
117125 for env_var in ("CUDA_PATH" , "CUDA_HOME" ):
118126 ctk_root = os .environ .get (env_var )
@@ -122,7 +130,7 @@ def _assert_real_ctk_backed_path(path: str) -> None:
122130 if os .path .commonpath ((norm_path , norm_ctk_root )) == norm_ctk_root :
123131 return
124132 raise AssertionError (
125- "Expected a site-packages path, a path under a CTK root with version.json , "
133+ "Expected a site-packages path, a path under a CTK root with cuda.h , "
126134 f"or a path under CUDA_PATH/CUDA_HOME, got { path !r} "
127135 )
128136
@@ -220,7 +228,7 @@ def fail_raw_fallback(_libname: str) -> LoadedDL:
220228
221229@pytest .mark .parametrize ("env_value" , [None , "" ])
222230def test_public_apis_default_to_strict_when_env_var_is_unset_or_empty (monkeypatch , tmp_path , env_value ):
223- lib_path = _touch (tmp_path / "no-version-json " / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
231+ lib_path = _touch (tmp_path / "no-cuda-h " / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
224232
225233 monkeypatch .setattr (compatibility_module , "_load_nvidia_dynamic_lib" , lambda _libname : _loaded_dl (lib_path ))
226234 monkeypatch .setattr (
@@ -238,12 +246,12 @@ def fail_raw_fallback(_libname: str) -> LoadedDL:
238246 else :
239247 monkeypatch .setenv (COMPATIBILITY_GUARD_RAILS_ENV_VAR , env_value )
240248
241- with pytest .raises (CompatibilityInsufficientMetadataError , match = "version.json " ):
249+ with pytest .raises (CompatibilityInsufficientMetadataError , match = "cuda.h " ):
242250 pathfinder .load_nvidia_dynamic_lib ("nvrtc" )
243251
244252
245253def test_public_apis_best_effort_fall_back_on_insufficient_metadata (monkeypatch , tmp_path ):
246- guarded_lib_path = _touch (tmp_path / "no-version-json " / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
254+ guarded_lib_path = _touch (tmp_path / "no-cuda-h " / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
247255 raw_loaded = _loaded_dl ("/opt/mock/libnvrtc.so.12" , found_via = "system-search" )
248256
249257 monkeypatch .setenv (COMPATIBILITY_GUARD_RAILS_ENV_VAR , "best_effort" )
@@ -292,8 +300,8 @@ def test_public_apis_reject_invalid_guard_rails_mode(monkeypatch):
292300def test_public_apis_share_process_wide_guard_rails_state (monkeypatch , tmp_path ):
293301 lib_root = tmp_path / "cuda-12.8"
294302 hdr_root = tmp_path / "cuda-12.9"
295- _write_version_json (lib_root , "12.8.20250303" )
296- _write_version_json (hdr_root , "12.9.20250531" )
303+ _write_cuda_h (lib_root , "12.8.20250303" )
304+ _write_cuda_h (hdr_root , "12.9.20250531" )
297305
298306 lib_path = _touch (lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
299307 hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include"
@@ -320,7 +328,7 @@ def test_public_apis_share_process_wide_guard_rails_state(monkeypatch, tmp_path)
320328
321329def test_load_dynamic_lib_then_find_headers_same_ctk_version (monkeypatch , tmp_path ):
322330 ctk_root = tmp_path / "cuda-12.9"
323- _write_version_json (ctk_root , "12.9.20250531" )
331+ _write_cuda_h (ctk_root , "12.9.20250531" )
324332 lib_path = _touch (ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
325333 hdr_dir = ctk_root / "targets" / "x86_64-linux" / "include"
326334 _touch (hdr_dir / "nvrtc.h" )
@@ -344,8 +352,8 @@ def test_load_dynamic_lib_then_find_headers_same_ctk_version(monkeypatch, tmp_pa
344352def test_exact_ctk_major_minor_match_is_required (monkeypatch , tmp_path ):
345353 lib_root = tmp_path / "cuda-12.8"
346354 hdr_root = tmp_path / "cuda-12.9"
347- _write_version_json (lib_root , "12.8.20250303" )
348- _write_version_json (hdr_root , "12.9.20250531" )
355+ _write_cuda_h (lib_root , "12.8.20250303" )
356+ _write_cuda_h (hdr_root , "12.9.20250531" )
349357
350358 lib_path = _touch (lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
351359 hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include"
@@ -367,7 +375,7 @@ def test_exact_ctk_major_minor_match_is_required(monkeypatch, tmp_path):
367375
368376def test_driver_major_must_not_be_older_than_ctk_major (monkeypatch , tmp_path ):
369377 ctk_root = tmp_path / "cuda-13.0"
370- _write_version_json (ctk_root , "13.0.20251003" )
378+ _write_cuda_h (ctk_root , "13.0.20251003" )
371379 lib_path = _touch (ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.13" )
372380
373381 monkeypatch .setattr (compatibility_module , "_load_nvidia_dynamic_lib" , lambda _libname : _loaded_dl (lib_path ))
@@ -378,17 +386,31 @@ def test_driver_major_must_not_be_older_than_ctk_major(monkeypatch, tmp_path):
378386 guard_rails .load_nvidia_dynamic_lib ("nvrtc" )
379387
380388
381- def test_missing_version_json_raises_insufficient_metadata (monkeypatch , tmp_path ):
382- lib_path = _touch (tmp_path / "no-version-json " / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
389+ def test_missing_cuda_h_raises_insufficient_metadata (monkeypatch , tmp_path ):
390+ lib_path = _touch (tmp_path / "no-cuda-h " / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
383391
384392 monkeypatch .setattr (compatibility_module , "_load_nvidia_dynamic_lib" , lambda _libname : _loaded_dl (lib_path ))
385393
386394 guard_rails = CompatibilityGuardRails (driver_cuda_version = _driver_cuda_version (13000 ))
387395
388- with pytest .raises (CompatibilityInsufficientMetadataError , match = "version.json " ):
396+ with pytest .raises (CompatibilityInsufficientMetadataError , match = "cuda.h " ):
389397 guard_rails .load_nvidia_dynamic_lib ("nvrtc" )
390398
391399
400+ def test_windows_style_ctk_root_uses_root_include_cuda_h (monkeypatch , tmp_path ):
401+ ctk_root = tmp_path / "cuda-13.2"
402+ _write_cuda_h (ctk_root , "13.2.20251003" , include_dir_parts = ("include" ,))
403+ lib_path = _touch (ctk_root / "bin" / "x64" / "nvrtc64_130_0.dll" )
404+
405+ monkeypatch .setattr (compatibility_module , "_load_nvidia_dynamic_lib" , lambda _libname : _loaded_dl (lib_path ))
406+
407+ guard_rails = CompatibilityGuardRails (driver_cuda_version = _driver_cuda_version (13000 ))
408+
409+ loaded = guard_rails .load_nvidia_dynamic_lib ("nvrtc" )
410+
411+ assert loaded .abs_path == lib_path
412+
413+
392414def test_other_packaging_raises_insufficient_metadata (monkeypatch , tmp_path ):
393415 abs_path = _touch (tmp_path / "site-packages" / "nvidia" / "nvshmem" / "lib" / "libnvshmem_device.bc" )
394416
@@ -407,7 +429,7 @@ def test_other_packaging_raises_insufficient_metadata(monkeypatch, tmp_path):
407429def test_driver_libs_do_not_lock_ctk_anchor (monkeypatch , tmp_path ):
408430 driver_lib_path = _touch (tmp_path / "driver-root" / "libnvidia-ml.so.1" )
409431 ctk_root = tmp_path / "cuda-12.9"
410- _write_version_json (ctk_root , "12.9.20250531" )
432+ _write_cuda_h (ctk_root , "12.9.20250531" )
411433 ctk_lib_path = _touch (ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
412434
413435 def fake_load_nvidia_dynamic_lib (libname : str ) -> LoadedDL :
@@ -432,8 +454,8 @@ def test_driver_libs_do_not_mask_later_ctk_mismatch(monkeypatch, tmp_path):
432454 driver_lib_path = _touch (tmp_path / "driver-root" / "libnvidia-ml.so.1" )
433455 lib_root = tmp_path / "cuda-12.8"
434456 hdr_root = tmp_path / "cuda-12.9"
435- _write_version_json (lib_root , "12.8.20250303" )
436- _write_version_json (hdr_root , "12.9.20250531" )
457+ _write_cuda_h (lib_root , "12.8.20250303" )
458+ _write_cuda_h (hdr_root , "12.9.20250531" )
437459
438460 lib_path = _touch (lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
439461 hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include"
@@ -506,7 +528,7 @@ def test_wheel_metadata_accepts_exact_and_range_requirements(monkeypatch, tmp_pa
506528
507529def test_constraints_accept_string_and_tuple_forms (monkeypatch , tmp_path ):
508530 ctk_root = tmp_path / "cuda-12.9"
509- _write_version_json (ctk_root , "12.9.20250531" )
531+ _write_cuda_h (ctk_root , "12.9.20250531" )
510532 lib_path = _touch (ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
511533
512534 monkeypatch .setattr (compatibility_module , "_load_nvidia_dynamic_lib" , lambda _libname : _loaded_dl (lib_path ))
@@ -524,7 +546,7 @@ def test_constraints_accept_string_and_tuple_forms(monkeypatch, tmp_path):
524546
525547def test_constraint_failure_raises (monkeypatch , tmp_path ):
526548 ctk_root = tmp_path / "cuda-12.9"
527- _write_version_json (ctk_root , "12.9.20250531" )
549+ _write_cuda_h (ctk_root , "12.9.20250531" )
528550 lib_path = _touch (ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
529551
530552 monkeypatch .setattr (compatibility_module , "_load_nvidia_dynamic_lib" , lambda _libname : _loaded_dl (lib_path ))
@@ -541,7 +563,7 @@ def test_constraint_failure_raises(monkeypatch, tmp_path):
541563
542564def test_static_bitcode_and_binary_methods_participate_in_checks (monkeypatch , tmp_path ):
543565 ctk_root = tmp_path / "cuda-12.9"
544- _write_version_json (ctk_root , "12.9.20250531" )
566+ _write_cuda_h (ctk_root , "12.9.20250531" )
545567
546568 lib_path = _touch (ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
547569 static_path = _touch (ctk_root / "targets" / "x86_64-linux" / "lib" / "libcudadevrt.a" )
@@ -575,7 +597,7 @@ def test_static_bitcode_and_binary_methods_participate_in_checks(monkeypatch, tm
575597
576598def test_guard_rails_query_driver_cuda_version_by_default (monkeypatch , tmp_path ):
577599 ctk_root = tmp_path / "cuda-12.9"
578- _write_version_json (ctk_root , "12.9.20250531" )
600+ _write_cuda_h (ctk_root , "12.9.20250531" )
579601 lib_path = _touch (ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
580602
581603 query_calls : list [int ] = []
@@ -598,7 +620,7 @@ def fake_query_driver_cuda_version() -> DriverCudaVersion:
598620
599621def test_guard_rails_wrap_driver_query_failures (monkeypatch , tmp_path ):
600622 ctk_root = tmp_path / "cuda-12.9"
601- _write_version_json (ctk_root , "12.9.20250531" )
623+ _write_cuda_h (ctk_root , "12.9.20250531" )
602624 lib_path = _touch (ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12" )
603625
604626 monkeypatch .setattr (compatibility_module , "_load_nvidia_dynamic_lib" , lambda _libname : _loaded_dl (lib_path ))
0 commit comments