Skip to content

Commit baab3e8

Browse files
authored
test(consistent): replace Cartesian product with curated matrices for descriptor tests (#5378)
Problem - Several descriptor consistency suites use full Cartesian products over many feature toggles. - That creates large default PR-CI matrices whose cost is dominated by combinatorial expansion rather than a few isolated slow cases. - In this group alone, the main suite case counts go from 48 -> 7 for `test_se_e2_a.py`, 64 -> 9 for `test_dpa1.py`, and 512 -> 12 for `test_dpa2.py`. Change - Add `parameterized_cases(*cases)` in `source/tests/consistent/common.py` for explicit curated matrices without changing existing `parameterized(*attrs)` semantics. - Replace Cartesian-product decorators with curated case lists in: - `source/tests/consistent/descriptor/test_se_e2_a.py` - `source/tests/consistent/descriptor/test_dpa1.py` - `source/tests/consistent/descriptor/test_dpa2.py` - Reuse the same curated matrices for descriptor API tests where applicable. - Keep existing skip logic unchanged; only the default generated case sets are reduced. Validation - `python -m pytest --collect-only source/tests/consistent/descriptor/test_se_e2_a.py source/tests/consistent/descriptor/test_dpa1.py source/tests/consistent/descriptor/test_dpa2.py` - collected 488 tests after the reduction - `python -m pytest source/tests/consistent/descriptor/test_se_e2_a.py source/tests/consistent/descriptor/test_dpa1.py source/tests/consistent/descriptor/test_dpa2.py -q` - `91 passed, 397 skipped` - `python3 -m py_compile` on the touched files Related - Closes #5372 Authored by OpenClaw (model: gpt-5.4) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a decorator to specify explicit test cases (no Cartesian-product expansion). * **Tests** * Refactored multiple test suites to use curated, explicit case sets for clearer, deterministic parametrization. * Added reusable case-construction helpers and baseline/curated case collections to simplify test definitions. * Improved deterministic test-class naming with robust sanitization and uniqueness handling. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 67a30f2 commit baab3e8

File tree

4 files changed

+306
-132
lines changed

4 files changed

+306
-132
lines changed

source/tests/consistent/common.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import itertools
44
import os
5+
import re
56
import sys
67
import unittest
78
from abc import (
@@ -75,7 +76,9 @@
7576
"INSTALLED_PT_EXPT",
7677
"INSTALLED_TF",
7778
"CommonTest",
78-
"CommonTest",
79+
"parameterize_func",
80+
"parameterized",
81+
"parameterized_cases",
7982
]
8083

8184
SKIP_FLAG = object()
@@ -670,6 +673,46 @@ def tearDown(self) -> None:
670673
clear_session()
671674

672675

676+
def _parameterized_with_cases(full_parameterized: list[tuple]) -> Callable:
677+
def decorator(base_class: type):
678+
class_module = sys.modules[base_class.__module__].__dict__
679+
used_names: set[str] = set()
680+
for pp in full_parameterized:
681+
682+
class TestClass(base_class):
683+
param: ClassVar = pp
684+
685+
# generate a safe name for the class
686+
parts = []
687+
for x in pp:
688+
s = str(x)
689+
# replace non-alnum with underscore, collapse multiple underscores
690+
s = re.sub(r"[^a-zA-Z0-9_]", "_", s)
691+
s = re.sub(r"_+", "_", s)
692+
# remove leading/trailing underscores
693+
s = s.strip("_")
694+
if s == "":
695+
s = "empty"
696+
parts.append(s)
697+
base_name = f"{base_class.__name__}_{'_'.join(parts)}"
698+
name = base_name
699+
suffix = 1
700+
while name in used_names or name in class_module:
701+
name = f"{base_name}_{suffix}"
702+
suffix += 1
703+
704+
TestClass.__name__ = name
705+
TestClass.__qualname__ = name
706+
TestClass.__module__ = base_class.__module__
707+
708+
used_names.add(name)
709+
class_module[name] = TestClass
710+
# make unittest module happy by ignoring the original one
711+
return object
712+
713+
return decorator
714+
715+
673716
def parameterized(*attrs: tuple, **subblock_attrs: tuple) -> Callable:
674717
"""Parameterized test.
675718
@@ -733,6 +776,27 @@ class TestClass(base_class):
733776
return decorator
734777

735778

779+
def parameterized_cases(*cases: tuple) -> Callable:
780+
"""Parameterized test with explicit case tuples.
781+
782+
This variant behaves like :func:`parameterized` but takes a curated list of
783+
case tuples directly instead of computing their Cartesian product.
784+
785+
Parameters
786+
----------
787+
*cases : tuple
788+
Explicit case tuples.
789+
790+
Returns
791+
-------
792+
object
793+
The decorator.
794+
"""
795+
if not cases:
796+
raise ValueError("parameterized_cases requires at least one case tuple")
797+
return _parameterized_with_cases(list(cases))
798+
799+
736800
def parameterize_func(
737801
func: Callable,
738802
param_dict_list: dict[str, tuple],

source/tests/consistent/descriptor/test_dpa1.py

Lines changed: 70 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
INSTALLED_PT_EXPT,
2323
INSTALLED_TF,
2424
CommonTest,
25-
parameterized,
25+
parameterized_cases,
2626
)
2727
from .common import (
2828
DescriptorAPITest,
@@ -57,28 +57,76 @@
5757
descrpt_se_atten_args,
5858
)
5959

60+
DPA1_CASE_FIELDS = (
61+
"tebd_dim",
62+
"tebd_input_mode",
63+
"resnet_dt",
64+
"type_one_side",
65+
"attn",
66+
"attn_layer",
67+
"attn_dotr",
68+
"excluded_types",
69+
"env_protection",
70+
"set_davg_zero",
71+
"scaling_factor",
72+
"normalize",
73+
"temperature",
74+
"ln_eps",
75+
"smooth_type_embedding",
76+
"concat_output_tebd",
77+
"precision",
78+
"use_econf_tebd",
79+
"use_tebd_bias",
80+
)
81+
6082

61-
@parameterized(
62-
(4,), # tebd_dim
63-
("concat", "strip"), # tebd_input_mode
64-
(True,), # resnet_dt
65-
(True,), # type_one_side
66-
(20,), # attn
67-
(0, 2), # attn_layer
68-
(True,), # attn_dotr
69-
([], [[0, 1]]), # excluded_types
70-
(0.0,), # env_protection
71-
(True, False), # set_davg_zero
72-
(1.0,), # scaling_factor
73-
(True,), # normalize
74-
(None, 1.0), # temperature
75-
(1e-5,), # ln_eps
76-
(True,), # smooth_type_embedding
77-
(True,), # concat_output_tebd
78-
("float64",), # precision
79-
(True, False), # use_econf_tebd
80-
(False,), # use_tebd_bias
83+
DPA1_BASELINE_CASE = {
84+
"tebd_dim": 4,
85+
"tebd_input_mode": "concat",
86+
"resnet_dt": True,
87+
"type_one_side": True,
88+
"attn": 20,
89+
"attn_layer": 2,
90+
"attn_dotr": True,
91+
"excluded_types": [],
92+
"env_protection": 0.0,
93+
"set_davg_zero": True,
94+
"scaling_factor": 1.0,
95+
"normalize": True,
96+
"temperature": 1.0,
97+
"ln_eps": 1e-5,
98+
"smooth_type_embedding": True,
99+
"concat_output_tebd": True,
100+
"precision": "float64",
101+
"use_econf_tebd": False,
102+
"use_tebd_bias": False,
103+
}
104+
105+
106+
def dpa1_case(**overrides: Any) -> tuple:
107+
case = DPA1_BASELINE_CASE | overrides
108+
return tuple(case[field] for field in DPA1_CASE_FIELDS)
109+
110+
111+
DPA1_CURATED_CASES = (
112+
# Baseline coverage.
113+
dpa1_case(),
114+
# Alternate tebd input plumbing.
115+
dpa1_case(tebd_input_mode="strip"),
116+
# High-risk descriptor toggles.
117+
dpa1_case(excluded_types=[[0, 1]]),
118+
dpa1_case(set_davg_zero=False),
119+
dpa1_case(normalize=False),
120+
# Attention edge cases: disabled temperature path vs zero-layer path.
121+
dpa1_case(temperature=None),
122+
dpa1_case(attn_layer=0, temperature=None),
123+
# econf-specific path with both tebd input modes.
124+
dpa1_case(use_econf_tebd=True),
125+
dpa1_case(tebd_input_mode="strip", use_econf_tebd=True),
81126
)
127+
128+
129+
@parameterized_cases(*DPA1_CURATED_CASES)
82130
class TestDPA1(CommonTest, DescriptorTest, unittest.TestCase):
83131
@property
84132
def data(self) -> dict:
@@ -556,27 +604,7 @@ def atol(self) -> float:
556604
raise ValueError(f"Unknown precision: {precision}")
557605

558606

559-
@parameterized(
560-
(4,), # tebd_dim
561-
("concat", "strip"), # tebd_input_mode
562-
(True,), # resnet_dt
563-
(True,), # type_one_side
564-
(20,), # attn
565-
(0, 2), # attn_layer
566-
(True,), # attn_dotr
567-
([], [[0, 1]]), # excluded_types
568-
(0.0,), # env_protection
569-
(True, False), # set_davg_zero
570-
(1.0,), # scaling_factor
571-
(True,), # normalize
572-
(None, 1.0), # temperature
573-
(1e-5,), # ln_eps
574-
(True,), # smooth_type_embedding
575-
(True,), # concat_output_tebd
576-
("float64",), # precision
577-
(True, False), # use_econf_tebd
578-
(False,), # use_tebd_bias
579-
)
607+
@parameterized_cases(*DPA1_CURATED_CASES)
580608
class TestDPA1DescriptorAPI(DescriptorAPITest, unittest.TestCase):
581609
"""Test consistency of BaseDescriptor API methods across backends."""
582610

0 commit comments

Comments
 (0)