Skip to content

Commit e857454

Browse files
Copilotnjzjz
andcommitted
refactor: merge TestDipoleSelType with TestDipole using parameterized sel_type
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent b9e5bed commit e857454

1 file changed

Lines changed: 13 additions & 206 deletions

File tree

source/tests/consistent/fitting/test_dipole.py

Lines changed: 13 additions & 206 deletions
Original file line numberDiff line numberDiff line change
@@ -61,176 +61,9 @@
6161
(True, False), # resnet_dt
6262
("float64", "float32"), # precision
6363
(True, False), # mixed_types
64+
([], [0, 1]), # sel_type
6465
)
6566
class TestDipole(CommonTest, DipoleFittingTest, unittest.TestCase):
66-
@property
67-
def data(self) -> dict:
68-
(
69-
resnet_dt,
70-
precision,
71-
mixed_types,
72-
) = self.param
73-
return {
74-
"neuron": [5, 5, 5],
75-
"resnet_dt": resnet_dt,
76-
"precision": precision,
77-
"seed": 20240217,
78-
}
79-
80-
@property
81-
def skip_pt(self) -> bool:
82-
(
83-
resnet_dt,
84-
precision,
85-
mixed_types,
86-
) = self.param
87-
return CommonTest.skip_pt
88-
89-
tf_class = DipoleFittingTF
90-
dp_class = DipoleFittingDP
91-
pt_class = DipoleFittingPT
92-
jax_class = DipoleFittingJAX
93-
array_api_strict_class = DipoleFittingArrayAPIStrict
94-
args = fitting_dipole()
95-
skip_jax = not INSTALLED_JAX
96-
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
97-
98-
def setUp(self) -> None:
99-
CommonTest.setUp(self)
100-
101-
self.ntypes = 2
102-
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
103-
self.inputs = np.ones((1, 6, 20), dtype=GLOBAL_NP_FLOAT_PRECISION)
104-
self.gr = np.ones((1, 6, 30, 3), dtype=GLOBAL_NP_FLOAT_PRECISION)
105-
self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
106-
# inconsistent if not sorted
107-
self.atype.sort()
108-
109-
@property
110-
def additional_data(self) -> dict:
111-
(
112-
resnet_dt,
113-
precision,
114-
mixed_types,
115-
) = self.param
116-
return {
117-
"ntypes": self.ntypes,
118-
"dim_descrpt": self.inputs.shape[-1],
119-
"mixed_types": mixed_types,
120-
"embedding_width": 30,
121-
}
122-
123-
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
124-
(
125-
resnet_dt,
126-
precision,
127-
mixed_types,
128-
) = self.param
129-
return self.build_tf_fitting(
130-
obj,
131-
self.inputs.ravel(),
132-
self.gr,
133-
self.natoms,
134-
self.atype,
135-
None,
136-
suffix,
137-
)
138-
139-
def eval_pt(self, pt_obj: Any) -> Any:
140-
(
141-
resnet_dt,
142-
precision,
143-
mixed_types,
144-
) = self.param
145-
return (
146-
pt_obj(
147-
torch.from_numpy(self.inputs).to(device=PT_DEVICE),
148-
torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE),
149-
torch.from_numpy(self.gr).to(device=PT_DEVICE),
150-
None,
151-
)["dipole"]
152-
.detach()
153-
.cpu()
154-
.numpy()
155-
)
156-
157-
def eval_dp(self, dp_obj: Any) -> Any:
158-
(
159-
resnet_dt,
160-
precision,
161-
mixed_types,
162-
) = self.param
163-
return dp_obj(
164-
self.inputs,
165-
self.atype.reshape(1, -1),
166-
self.gr,
167-
None,
168-
)["dipole"]
169-
170-
def eval_jax(self, jax_obj: Any) -> Any:
171-
return np.asarray(
172-
jax_obj(
173-
jnp.asarray(self.inputs),
174-
jnp.asarray(self.atype.reshape(1, -1)),
175-
jnp.asarray(self.gr),
176-
None,
177-
)["dipole"]
178-
)
179-
180-
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
181-
return to_numpy_array(
182-
array_api_strict_obj(
183-
array_api_strict.asarray(self.inputs),
184-
array_api_strict.asarray(self.atype.reshape(1, -1)),
185-
array_api_strict.asarray(self.gr),
186-
None,
187-
)["dipole"]
188-
)
189-
190-
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
191-
if backend == self.RefBackend.TF:
192-
# shape is not same
193-
ret = ret[0].reshape(-1, self.natoms[0], 1)
194-
return (ret,)
195-
196-
@property
197-
def rtol(self) -> float:
198-
"""Relative tolerance for comparing the return value."""
199-
(
200-
resnet_dt,
201-
precision,
202-
mixed_types,
203-
) = self.param
204-
if precision == "float64":
205-
return 1e-10
206-
elif precision == "float32":
207-
return 1e-4
208-
else:
209-
raise ValueError(f"Unknown precision: {precision}")
210-
211-
@property
212-
def atol(self) -> float:
213-
"""Absolute tolerance for comparing the return value."""
214-
(
215-
resnet_dt,
216-
precision,
217-
mixed_types,
218-
) = self.param
219-
if precision == "float64":
220-
return 1e-10
221-
elif precision == "float32":
222-
return 1e-4
223-
else:
224-
raise ValueError(f"Unknown precision: {precision}")
225-
226-
227-
@parameterized(
228-
(True, False), # resnet_dt
229-
("float64", "float32"), # precision
230-
(True, False), # mixed_types
231-
([0, 1],), # sel_type - only test with all types selected for consistency
232-
)
233-
class TestDipoleSelType(CommonTest, DipoleFittingTest, unittest.TestCase):
23467
@property
23568
def data(self) -> dict:
23669
(
@@ -239,13 +72,16 @@ def data(self) -> dict:
23972
mixed_types,
24073
sel_type,
24174
) = self.param
242-
return {
75+
data = {
24376
"neuron": [5, 5, 5],
24477
"resnet_dt": resnet_dt,
24578
"precision": precision,
24679
"seed": 20240217,
247-
"sel_type": sel_type, # For TF backend
24880
}
81+
# Only add sel_type if it's not empty (for TF backend compatibility)
82+
if sel_type:
83+
data["sel_type"] = sel_type
84+
return data
24985

25086
@property
25187
def skip_pt(self) -> bool:
@@ -285,16 +121,18 @@ def additional_data(self) -> dict:
285121
mixed_types,
286122
sel_type,
287123
) = self.param
288-
# For DP/PT backends, use exclude_types instead of sel_type
289-
all_types = list(range(self.ntypes))
290-
exclude_types = [t for t in all_types if t not in sel_type]
291-
return {
124+
additional = {
292125
"ntypes": self.ntypes,
293126
"dim_descrpt": self.inputs.shape[-1],
294127
"mixed_types": mixed_types,
295128
"embedding_width": 30,
296-
"exclude_types": exclude_types, # For DP/PT backends
297129
}
130+
# For DP/PT backends, use exclude_types instead of sel_type
131+
if sel_type:
132+
all_types = list(range(self.ntypes))
133+
exclude_types = [t for t in all_types if t not in sel_type]
134+
additional["exclude_types"] = exclude_types
135+
return additional
298136

299137
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
300138
(
@@ -404,37 +242,6 @@ def atol(self) -> float:
404242
else:
405243
raise ValueError(f"Unknown precision: {precision}")
406244

407-
def test_sel_type_behavior(self):
408-
"""Test that sel_type parameter works correctly across backends."""
409-
(
410-
resnet_dt,
411-
precision,
412-
mixed_types,
413-
sel_type,
414-
) = self.param
415-
416-
# Test TF backend if available
417-
if INSTALLED_TF:
418-
tf_obj = self.tf_class(**{**self.data, **self.additional_data})
419-
420-
# Verify that only selected types have fitting nets
421-
if hasattr(tf_obj, "sel_type"):
422-
self.assertEqual(set(tf_obj.sel_type), set(sel_type))
423-
if hasattr(tf_obj, "sel_mask"):
424-
expected_mask = np.array([i in sel_type for i in range(self.ntypes)])
425-
np.testing.assert_array_equal(tf_obj.sel_mask, expected_mask)
426-
427-
# Test DP backend
428-
all_types = list(range(self.ntypes))
429-
exclude_types = [t for t in all_types if t not in sel_type]
430-
dp_data = {**self.data}
431-
dp_data.pop("sel_type", None) # Remove sel_type for DP backend
432-
dp_obj = self.dp_class(**{**dp_data, **self.additional_data})
433-
434-
# Verify that exclude_types is set correctly
435-
if hasattr(dp_obj, "exclude_types"):
436-
self.assertEqual(set(dp_obj.exclude_types), set(exclude_types))
437-
438245

439246
class TestDipoleSelTypeBehavior(unittest.TestCase):
440247
"""Test sel_type behavior specifically, without cross-backend consistency."""

0 commit comments

Comments
 (0)