2121"""
2222
2323import pytest
24+ import warnings
25+
26+ warnings .filterwarnings (
27+ "ignore" , message = "builtin type swigvarlink has no __module__ attribute" , category = DeprecationWarning
28+ )
29+ warnings .filterwarnings (
30+ "ignore" , message = "builtin type SwigPyPacked has no __module__ attribute" , category = DeprecationWarning
31+ )
32+ warnings .filterwarnings (
33+ "ignore" , message = "builtin type SwigPyObject has no __module__ attribute" , category = DeprecationWarning
34+ )
2435import jax
36+ import os
2537import importlib .util
2638
39+ # Force early JAX initialization on GPU to prevent CUDA context conflicts with TensorFlow/PyTorch.
40+ # If JAX initialization is deferred, TensorFlow/PyTorch (imported during test collection)
41+ # might initialize CUDA first, causing JAX's subsequent NCCL communicator creation to fail
42+ # with 'corrupted comm object detected'.
43+ # Detect GPU environment using standard JAX env vars, GHA runner device types,
44+ # and nvidia-docker visible device markers.
45+ _jax_platforms = os .getenv ("JAX_PLATFORMS" , "" ).lower ()
46+ _device_type = os .getenv ("INPUTS_DEVICE_TYPE" , "" ).lower ()
47+ _has_gpu = (
48+ "cuda" in _jax_platforms
49+ or "gpu" in _jax_platforms
50+ or "cuda" in _device_type
51+ or "gpu" in _device_type
52+ or os .getenv ("CUDA_VISIBLE_DEVICES" ) is not None
53+ or os .getenv ("NVIDIA_VISIBLE_DEVICES" ) is not None
54+ )
55+ if _has_gpu :
56+ try :
57+ _ = jax .devices ()
58+ except Exception : # pylint: disable=broad-exception-caught
59+ pass
60+
2761# --- Monkeypatch for absl.testing.parameterized ---
2862# Context: Decorating a test method with @parameterized.named_parameters returns a custom
2963# iterable container (_ParameterizedTestIter) instead of a standard function object.
@@ -66,22 +100,11 @@ def _custom_iter(self):
66100except AttributeError :
67101 pass
68102
69- import os
70103
71104if os .getenv ("JAX_PLATFORMS" ) == "proxy" :
72105 # Import maxtext early to register the pathways proxy backend before JAX is queried.
73106 import maxtext # pylint: disable=unused-import
74107
75- try :
76- _HAS_TPU = any (d .platform == "tpu" for d in jax .devices ())
77- except Exception : # pragma: no cover pylint: disable=broad-exception-caught
78- _HAS_TPU = False
79-
80- try :
81- _HAS_GPU = any (d .platform == "gpu" for d in jax .devices ())
82- except Exception : # pragma: no cover pylint: disable=broad-exception-caught
83- _HAS_GPU = False
84-
85108from maxtext .common .gcloud_stub import is_decoupled
86109
87110# Configure JAX to use unsafe_rbg PRNG implementation to match main scripts.
@@ -121,15 +144,7 @@ def pytest_collection_modifyitems(config, items):
121144 remaining = []
122145 deselected = []
123146
124- skip_no_tpu = None
125- skip_no_gpu = None
126147 skip_no_tpu_backend = None
127- if not _HAS_TPU :
128- skip_no_tpu = pytest .mark .skip (reason = "Skipped: requires TPU hardware, none detected" )
129-
130- if not _HAS_GPU :
131- skip_no_gpu = pytest .mark .skip (reason = "Skipped: requires GPU hardware, none detected" )
132-
133148 if not _has_tpu_backend_support ():
134149 skip_no_tpu_backend = pytest .mark .skip (
135150 reason = (
@@ -139,20 +154,8 @@ def pytest_collection_modifyitems(config, items):
139154 )
140155
141156 for item in items :
142- # Iterate thru the markers of every test.
143157 cur_test_markers = {m .name for m in item .iter_markers ()}
144158
145- # Hardware skip retains skip semantics.
146- if skip_no_tpu and "tpu_only" in cur_test_markers :
147- item .add_marker (skip_no_tpu )
148- remaining .append (item )
149- continue
150-
151- if skip_no_gpu and "gpu_only" in cur_test_markers :
152- item .add_marker (skip_no_gpu )
153- remaining .append (item )
154- continue
155-
156159 if skip_no_tpu_backend and "tpu_backend" in cur_test_markers :
157160 item .add_marker (skip_no_tpu_backend )
158161 remaining .append (item )
@@ -177,12 +180,73 @@ def pytest_collection_modifyitems(config, items):
177180
178181
179182def pytest_configure (config ):
183+ """Registers custom pytest markers dynamically."""
180184 for m in [
181185 "gpu_only: tests that require GPU hardware" ,
182186 "tpu_only: tests that require TPU hardware" ,
187+ "cpu_only: tests that require CPU-only environment (skipped on active accelerator hardware)" ,
183188 "tpu_backend: tests that require a TPU-enabled JAX install (TPU PJRT plugin), but not TPU hardware" ,
184189 "external_serving: JetStream / serving / decode server components" ,
185190 "external_training: goodput integrations" ,
186191 "decoupled: marked on tests that are not skipped due to GCP deps, when DECOUPLE_GCLOUD=TRUE" ,
192+ "skip_on_tpu7x: skip test if running on TPU7x platform" ,
187193 ]:
188194 config .addinivalue_line ("markers" , m )
195+
196+
197+ def _get_system_hardware_platform () -> str :
198+ """Determines the system hardware platform strictly from environment variables without JAX init."""
199+ # 1. Check JAX_PLATFORMS env var
200+ jax_platforms = os .getenv ("JAX_PLATFORMS" , "" ).lower ()
201+ if "tpu" in jax_platforms :
202+ return "tpu"
203+ if "cuda" in jax_platforms or "gpu" in jax_platforms :
204+ return "gpu"
205+
206+ # 2. Check active CUDA visible devices
207+ if os .getenv ("CUDA_VISIBLE_DEVICES" ) is not None :
208+ return "gpu"
209+
210+ # 3. Check TPU runtime variables
211+ if os .getenv ("TPU_NAME" ) is not None or os .getenv ("TPU_CHIPS" ) is not None :
212+ return "tpu"
213+
214+ # Default to CPU
215+ return "cpu"
216+
217+
218+ @pytest .fixture (autouse = True )
219+ def handle_skip_on_tpu7x (request ):
220+ """Dynamically skip tests marked with skip_on_tpu7x if running on TPU7x."""
221+ if request .node .get_closest_marker ("skip_on_tpu7x" ):
222+ if _get_system_hardware_platform () == "tpu" :
223+ try :
224+ is_tpu7x = any ("TPU7x" in d .device_kind for d in jax .devices ())
225+ except Exception : # pylint: disable=broad-exception-caught
226+ is_tpu7x = False
227+ if is_tpu7x :
228+ pytest .skip ("AOT tests do not support TPU7x platform" )
229+
230+
231+ @pytest .fixture (autouse = True )
232+ def handle_cpu_only (request ):
233+ """Dynamically skip cpu_only tests on TPU or GPU hardware."""
234+ if request .node .get_closest_marker ("cpu_only" ):
235+ if _get_system_hardware_platform () in ("tpu" , "gpu" ):
236+ pytest .skip ("Skipped: cpu_only test bypassed on hardware accelerator testbeds" )
237+
238+
239+ @pytest .fixture (autouse = True )
240+ def handle_tpu_only (request ):
241+ """Dynamically skip tpu_only tests if running on non-TPU hardware."""
242+ if request .node .get_closest_marker ("tpu_only" ):
243+ if _get_system_hardware_platform () != "tpu" :
244+ pytest .skip ("Skipped: requires TPU hardware, none detected" )
245+
246+
247+ @pytest .fixture (autouse = True )
248+ def handle_gpu_only (request ):
249+ """Dynamically skip gpu_only tests if running on non-GPU hardware."""
250+ if request .node .get_closest_marker ("gpu_only" ):
251+ if _get_system_hardware_platform () != "gpu" :
252+ pytest .skip ("Skipped: requires GPU hardware, none detected" )
0 commit comments