Skip to content

Commit 8a382dc

Browse files
Merge pull request #3888 from AI-Hypercomputer:fix-parameterized-marker-ordering
PiperOrigin-RevId: 915544286
2 parents fb0fdce + f805d4d commit 8a382dc

2 files changed

Lines changed: 102 additions & 0 deletions

File tree

tests/conftest.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,48 @@
2424
import jax
2525
import importlib.util
2626

27+
# --- Monkeypatch for absl.testing.parameterized ---
28+
# Context: Decorating a test method with @parameterized.named_parameters returns a custom
29+
# iterable container (_ParameterizedTestIter) instead of a standard function object.
30+
# Problem: When pytest markers are applied above @parameterized in the decorator stack:
31+
#
32+
# @pytest.mark.cpu_only
33+
# @parameterized.named_parameters(...)
34+
# def test_foo(self, ...):
35+
#
36+
# pytest attaches the marker attributes exclusively to the outer iterable container object.
37+
# During class initialization, the test metaclass unwraps the base function to generate
38+
# individual test methods, omitting the outer container entirely. Consequently, marker
39+
# attributes attached to the outer container are dropped and lost before pytest collection.
40+
# Solution: Intercept _ParameterizedTestIter.__iter__ to dynamically propagate any discovered
41+
# pytestmark attributes from the outer container object down to all generated test methods.
42+
from absl.testing import parameterized
43+
44+
try:
45+
# pylint: disable=protected-access
46+
_orig_iter = parameterized._ParameterizedTestIter.__iter__
47+
48+
def _custom_iter(self):
49+
"""Custom iterator propagating outer pytestmark attributes to generated test methods."""
50+
outer_marks = getattr(self, "pytestmark", None)
51+
if outer_marks is None:
52+
yield from _orig_iter(self)
53+
else:
54+
if not isinstance(outer_marks, list):
55+
outer_marks = [outer_marks]
56+
57+
for func in _orig_iter(self):
58+
existing_marks = getattr(func, "pytestmark", [])
59+
if not isinstance(existing_marks, list):
60+
existing_marks = [existing_marks]
61+
func.pytestmark = existing_marks + outer_marks
62+
yield func
63+
64+
parameterized._ParameterizedTestIter.__iter__ = _custom_iter
65+
# pylint: enable=protected-access
66+
except AttributeError:
67+
pass
68+
2769
try:
2870
_HAS_TPU = any(d.platform == "tpu" for d in jax.devices())
2971
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests validating pytest marker propagation through decorator stacks."""
16+
17+
import functools
18+
import unittest
19+
20+
from absl.testing import parameterized
21+
import jax
22+
import pytest
23+
24+
25+
def dummy_decorator(func):
26+
"""Standard transparent wrapper decorator preserving function metadata."""
27+
28+
@functools.wraps(func)
29+
def wrapper(*args, **kwargs):
30+
return func(*args, **kwargs)
31+
32+
return wrapper
33+
34+
35+
class MarkerPropagationTest(parameterized.TestCase):
36+
"""Validates that pytest markers propagate correctly through decorator stacks."""
37+
38+
@pytest.mark.cpu_only
39+
@parameterized.named_parameters(
40+
{"testcase_name": "default", "unused": None},
41+
)
42+
def test_parameterized_cpu_only_marker_propagation(self, unused):
43+
"""Verifies cpu_only marker above @parameterized propagates to generated methods."""
44+
has_tpu = any(d.platform == "tpu" for d in jax.devices())
45+
has_gpu = any(d.platform == "gpu" for d in jax.devices())
46+
assert not has_tpu, "cpu_only parameterized test accidentally executed on TPU hardware"
47+
assert not has_gpu, "cpu_only parameterized test accidentally executed on GPU hardware"
48+
49+
@pytest.mark.cpu_only
50+
@dummy_decorator
51+
def test_standard_decorator_cpu_only_marker_propagation(self):
52+
"""Verifies cpu_only marker above standard decorators propagates correctly."""
53+
has_tpu = any(d.platform == "tpu" for d in jax.devices())
54+
has_gpu = any(d.platform == "gpu" for d in jax.devices())
55+
assert not has_tpu, "cpu_only standard decorated test accidentally executed on TPU hardware"
56+
assert not has_gpu, "cpu_only standard decorated test accidentally executed on GPU hardware"
57+
58+
59+
if __name__ == "__main__":
60+
unittest.main()

0 commit comments

Comments
 (0)