Skip to content

Commit 98417ad

Browse files
committed
fix uts
1 parent 51bb8ea commit 98417ad

4 files changed

Lines changed: 21 additions & 0 deletions

File tree

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,8 @@ def __setitem__(self, key, value) -> None:
292292
self.case_embd = value
293293
elif key in ["scale"]:
294294
self.scale = value
295+
elif key in ["default_fparam_tensor"]:
296+
self.default_fparam_tensor = value
295297
else:
296298
raise KeyError(key)
297299

@@ -310,6 +312,8 @@ def __getitem__(self, key):
310312
return self.case_embd
311313
elif key in ["scale"]:
312314
return self.scale
315+
elif key in ["default_fparam_tensor"]:
316+
return self.default_fparam_tensor
313317
else:
314318
raise KeyError(key)
315319

deepmd/jax/fitting/fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
3535
"fparam_inv_std",
3636
"aparam_avg",
3737
"aparam_inv_std",
38+
"default_fparam_tensor",
3839
}:
3940
value = to_jax_array(value)
4041
if value is not None:

deepmd/pt/model/task/fitting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,8 @@ def __setitem__(self, key, value) -> None:
495495
self.case_embd = value
496496
elif key in ["scale"]:
497497
self.scale = value
498+
elif key in ["default_fparam_tensor"]:
499+
self.default_fparam_tensor = value
498500
else:
499501
raise KeyError(key)
500502

@@ -513,6 +515,8 @@ def __getitem__(self, key):
513515
return self.case_embd
514516
elif key in ["scale"]:
515517
return self.scale
518+
elif key in ["default_fparam_tensor"]:
519+
return self.default_fparam_tensor
516520
else:
517521
raise KeyError(key)
518522

source/tests/consistent/fitting/test_ener.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,18 @@ def skip_pd(self) -> bool:
138138
# so skip this in CI test
139139
return not INSTALLED_PD or precision == "bfloat16" or default_fparam is not None
140140

141+
@property
142+
def skip_tf(self) -> bool:
143+
(
144+
resnet_dt,
145+
precision,
146+
mixed_types,
147+
(numb_fparam, default_fparam),
148+
(numb_aparam, use_aparam_as_mask),
149+
atom_ener,
150+
) = self.param
151+
return not INSTALLED_TF or default_fparam is not None
152+
141153
tf_class = EnerFittingTF
142154
dp_class = EnerFittingDP
143155
pt_class = EnerFittingPT

0 commit comments

Comments
 (0)