Skip to content

Commit 932223d

Browse files
Copilot and njzjz committed
Changes before error encountered
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 610c6fa commit 932223d

15 files changed

Lines changed: 625 additions & 18 deletions

File tree

deepmd/tf/entrypoints/train.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,29 @@ def _do_work(
229229
# setup data modifier
230230
modifier = get_modifier(jdata["model"].get("modifier", None))
231231

232+
# extract stat_file from training parameters
233+
stat_file_path = None
234+
if not is_compress:
235+
stat_file_raw = jdata["training"].get("stat_file", None)
236+
if stat_file_raw is not None and run_opt.is_chief:
237+
from pathlib import (
238+
Path,
239+
)
240+
241+
from deepmd.utils.path import (
242+
DPPath,
243+
)
244+
245+
if not Path(stat_file_raw).exists():
246+
if stat_file_raw.endswith((".h5", ".hdf5")):
247+
import h5py
248+
249+
with h5py.File(stat_file_raw, "w") as f:
250+
pass
251+
else:
252+
Path(stat_file_raw).mkdir()
253+
stat_file_path = DPPath(stat_file_raw, "a")
254+
232255
# decouple the training data from the model compress process
233256
train_data = None
234257
valid_data = None
@@ -261,7 +284,12 @@ def _do_work(
261284
origin_type_map = get_data(
262285
jdata["training"]["training_data"], rcut, None, modifier
263286
).get_type_map()
264-
model.build(train_data, stop_batch, origin_type_map=origin_type_map)
287+
model.build(
288+
train_data,
289+
stop_batch,
290+
origin_type_map=origin_type_map,
291+
stat_file_path=stat_file_path,
292+
)
265293

266294
if not is_compress:
267295
# train the model with the provided systems in a cyclic way

deepmd/tf/model/dos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def get_numb_aparam(self) -> int:
9090
"""Get the number of atomic parameters."""
9191
return self.numb_aparam
9292

93-
def data_stat(self, data) -> None:
93+
def data_stat(self, data, stat_file_path=None) -> None:
9494
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
9595
m_all_stat = merge_sys_stat(all_stat)
9696
self._compute_input_stat(

deepmd/tf/model/ener.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,15 @@ def get_numb_aparam(self) -> int:
135135
"""Get the number of atomic parameters."""
136136
return self.numb_aparam
137137

138-
def data_stat(self, data) -> None:
138+
def data_stat(self, data, stat_file_path=None) -> None:
139139
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
140140
m_all_stat = merge_sys_stat(all_stat)
141141
self._compute_input_stat(
142142
m_all_stat, protection=self.data_stat_protect, mixed_type=data.mixed_type
143143
)
144-
self._compute_output_stat(all_stat, mixed_type=data.mixed_type)
144+
self._compute_output_stat(
145+
all_stat, mixed_type=data.mixed_type, stat_file_path=stat_file_path
146+
)
145147
# self.bias_atom_e = data.compute_energy_shift(self.rcond)
146148

147149
def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> None:
@@ -167,11 +169,37 @@ def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> No
167169
)
168170
self.fitting.compute_input_stats(all_stat, protection=protection)
169171

170-
def _compute_output_stat(self, all_stat, mixed_type=False) -> None:
171-
if mixed_type:
172-
self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type)
172+
def _compute_output_stat(
173+
self, all_stat, mixed_type=False, stat_file_path=None
174+
) -> None:
175+
if stat_file_path is not None:
176+
# Use the new stat functionality with file save/load
177+
from deepmd.tf.utils.stat import (
178+
compute_output_stats,
179+
)
180+
181+
# Merge system stats for compatibility
182+
m_all_stat = merge_sys_stat(all_stat)
183+
184+
bias_out, std_out = compute_output_stats(
185+
m_all_stat,
186+
self.ntypes,
187+
keys=["energy"],
188+
stat_file_path=stat_file_path,
189+
rcond=getattr(self, "rcond", None),
190+
mixed_type=mixed_type,
191+
)
192+
193+
# Set the computed bias and std in the fitting object
194+
if "energy" in bias_out:
195+
self.fitting.bias_atom_e = bias_out["energy"]
196+
173197
else:
174-
self.fitting.compute_output_stats(all_stat)
198+
# Use the original computation method
199+
if mixed_type:
200+
self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type)
201+
else:
202+
self.fitting.compute_output_stats(all_stat)
175203

176204
def build(
177205
self,

deepmd/tf/model/frozen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def get_rcut(self):
200200
def get_ntypes(self) -> int:
201201
return self.model.get_ntypes()
202202

203-
def data_stat(self, data) -> None:
203+
def data_stat(self, data, stat_file_path=None) -> None:
204204
pass
205205

206206
def init_variables(

deepmd/tf/model/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def get_ntypes(self) -> int:
9090
raise ValueError("Models have different ntypes")
9191
return self.models[0].get_ntypes()
9292

93-
def data_stat(self, data) -> None:
93+
def data_stat(self, data, stat_file_path=None) -> None:
9494
for model in self.models:
9595
model.data_stat(data)
9696

deepmd/tf/model/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def get_ntypes(self) -> int:
458458
"""Get the number of types."""
459459

460460
@abstractmethod
461-
def data_stat(self, data: dict):
461+
def data_stat(self, data: dict, stat_file_path=None):
462462
"""Data statistics."""
463463

464464
def get_feed_dict(

deepmd/tf/model/pairwise_dprc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def get_rcut(self):
317317
def get_ntypes(self) -> int:
318318
return self.ntypes
319319

320-
def data_stat(self, data) -> None:
320+
def data_stat(self, data, stat_file_path=None) -> None:
321321
self.qm_model.data_stat(data)
322322
self.qmmm_model.data_stat(data)
323323

deepmd/tf/model/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def get_sel_type(self):
8282
def get_out_size(self):
8383
return self.fitting.get_out_size()
8484

85-
def data_stat(self, data) -> None:
85+
def data_stat(self, data, stat_file_path=None) -> None:
8686
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
8787
m_all_stat = merge_sys_stat(all_stat)
8888
self._compute_input_stat(m_all_stat, protection=self.data_stat_protect)

deepmd/tf/train/trainer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,14 @@ def get_lr_and_coef(lr_param):
170170
self.ckpt_meta = None
171171
self.model_type = None
172172

173-
def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> None:
173+
def build(
174+
self,
175+
data=None,
176+
stop_batch=0,
177+
origin_type_map=None,
178+
suffix="",
179+
stat_file_path=None,
180+
) -> None:
174181
self.ntypes = self.model.get_ntypes()
175182
self.stop_batch = stop_batch
176183

@@ -209,7 +216,7 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> Non
209216
# self.saver.restore (in self._init_session) will restore avg and std variables, so data_stat is useless
210217
# init_from_frz_model will restore data_stat variables in `init_variables` method
211218
log.info("data stating... (this step may take long time)")
212-
self.model.data_stat(data)
219+
self.model.data_stat(data, stat_file_path=stat_file_path)
213220

214221
# config the init_frz_model command
215222
if self.run_opt.init_mode == "init_from_frz_model":

deepmd/tf/utils/stat.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import logging
3+
from typing import (
4+
Optional,
5+
)
6+
7+
import numpy as np
8+
9+
from deepmd.utils.path import (
10+
DPPath,
11+
)
12+
13+
log = logging.getLogger(__name__)
14+
15+
16+
def _restore_from_file(
    stat_file_path: Optional[DPPath],
    keys: Optional[list[str]] = None,
) -> tuple[Optional[dict], Optional[dict]]:
    """Restore bias and std from stat file.

    For every key the entries ``bias_atom_<key>`` and ``std_atom_<key>`` below
    ``stat_file_path`` are probed with ``is_file()``; existing entries are read
    with ``load_numpy()``.

    Parameters
    ----------
    stat_file_path : DPPath or None
        Path to the stat file directory/file. ``None`` means there is nothing
        to restore from.
    keys : list[str], optional
        Keys to restore statistics for. ``None`` (the default) is treated as
        ``["energy"]``.

    Returns
    -------
    ret_bias : dict or None
        Bias values for each key whose ``bias_atom_<key>`` entry exists.
    ret_std : dict or None
        Standard deviation values for each key whose ``std_atom_<key>`` entry
        exists.

    Notes
    -----
    ``(None, None)`` is returned when ``stat_file_path`` is ``None``, when no
    bias entry exists for any key, or when no std entry exists for any key.
    Otherwise both dicts are returned and may cover only a subset of ``keys``.
    """
    if stat_file_path is None:
        return None, None
    if keys is None:
        keys = ["energy"]

    # Probe each entry exactly once and remember the ones that exist, so the
    # load step below does not have to query the file system again.
    bias_files = {}
    for kk in keys:
        fp = stat_file_path / f"bias_atom_{kk}"
        if fp.is_file():
            bias_files[kk] = fp
    if not bias_files:
        return None, None

    std_files = {}
    for kk in keys:
        fp = stat_file_path / f"std_atom_{kk}"
        if fp.is_file():
            std_files[kk] = fp
    if not std_files:
        return None, None

    # Nothing is loaded until both existence checks have passed.
    ret_bias = {kk: fp.load_numpy() for kk, fp in bias_files.items()}
    ret_std = {kk: fp.load_numpy() for kk, fp in std_files.items()}
    return ret_bias, ret_std
58+
59+
60+
def _save_to_file(
    stat_file_path: DPPath,
    bias_out: dict,
    std_out: dict,
) -> None:
    """Write bias and std arrays into the stat file.

    Every value of ``bias_out`` is stored under ``bias_atom_<key>`` and every
    value of ``std_out`` under ``std_atom_<key>``, both relative to
    ``stat_file_path``.

    Parameters
    ----------
    stat_file_path : DPPath
        Path to the stat file directory/file; must not be ``None``
    bias_out : dict
        Bias values for each key
    std_out : dict
        Standard deviation values for each key
    """
    assert stat_file_path is not None
    # Create the target location (and any missing parents) before writing.
    stat_file_path.mkdir(parents=True, exist_ok=True)
    for name, array in bias_out.items():
        (stat_file_path / f"bias_atom_{name}").save_numpy(array)
    for name, array in std_out.items():
        (stat_file_path / f"std_atom_{name}").save_numpy(array)
84+
85+
86+
def compute_output_stats(
    all_stat: dict,
    ntypes: int,
    keys: Optional[list[str]] = None,
    stat_file_path: Optional[DPPath] = None,
    rcond: Optional[float] = None,
    mixed_type: bool = False,
) -> tuple[dict, dict]:
    """Compute output statistics for TensorFlow models.

    The statistics are first looked up in ``stat_file_path``. If that yields
    nothing, they are computed from ``all_stat`` with
    ``deepmd.utils.out_stat.compute_stats_from_redu`` and, when a stat file
    path is given and at least one key was computed, written back to it.

    Parameters
    ----------
    all_stat : dict
        Dictionary containing statistical data. Must provide ``"natoms_vec"``
        whenever at least one of ``keys`` is present in it.
    ntypes : int
        Number of atom types. Currently not read by this function.
    keys : list[str], optional
        Keys to compute statistics for. ``None`` (the default) is treated as
        ``["energy"]``. Keys missing from ``all_stat`` are skipped.
    stat_file_path : DPPath, optional
        Path to save/load statistics
    rcond : float, optional
        Forwarded unchanged to ``compute_stats_from_redu``
    mixed_type : bool
        Whether mixed type format is used. Currently not read by this
        function.

    Returns
    -------
    bias_out : dict
        Bias values, one flattened array per key
    std_out : dict
        Standard deviation values, one flattened array per key
    """
    # NOTE(review): `ntypes` and `mixed_type` are accepted but never used, so
    # mixed-type data takes exactly the same path as regular data. Confirm
    # that "natoms_vec" is the right atom-count source in that case.
    if keys is None:
        keys = ["energy"]

    # Try to restore from file first
    bias_out, std_out = _restore_from_file(stat_file_path, keys)

    if bias_out is not None and std_out is not None:
        log.info("Successfully restored statistics from stat file")
        return bias_out, std_out

    # If restore failed, compute from data
    log.info("Computing statistics from training data")

    # Imported lazily, as in the original implementation.
    from deepmd.utils.out_stat import (
        compute_stats_from_redu,
    )

    bias_out = {}
    std_out = {}
    # The atom-count matrix does not depend on the key; build it at most once
    # and only when some key is actually present in `all_stat`.
    natoms_data = None

    for key in keys:
        if key not in all_stat:
            continue

        output_data = np.concatenate(all_stat[key])
        if natoms_data is None:
            # Skip the first 2 columns; assumes the remaining columns hold
            # per-type atom counts -- TODO confirm against the data layout.
            natoms_data = np.concatenate(all_stat["natoms_vec"])[:, 2:]

        # Compute statistics using existing utility
        bias, std = compute_stats_from_redu(
            output_data.reshape(-1, 1),  # Reshape to column vector
            natoms_data,
            rcond=rcond,
        )

        bias_out[key] = bias.reshape(-1)  # Flatten to 1D
        std_out[key] = std.reshape(-1)  # Flatten to 1D

        # Lazy %-style arguments: the message is only formatted when the
        # INFO level is enabled.
        log.info(
            "Statistics computed for %s: bias shape %s, std shape %s",
            key,
            bias_out[key].shape,
            std_out[key].shape,
        )

    # Save to file if path provided
    if stat_file_path is not None and bias_out:
        _save_to_file(stat_file_path, bias_out, std_out)
        log.info("Statistics saved to stat file")

    return bias_out, std_out

0 commit comments

Comments
 (0)