Skip to content

Commit 432be58

Browse files
committed
Add `use_cache` as a modifier variable so that the user can choose whether to cache the modified data before training or to perform the modification on the fly.
1 parent 8fd8f65 commit 432be58

3 files changed

Lines changed: 40 additions & 20 deletions

File tree

deepmd/pt/modifier/base_modifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@
2828

2929

3030
class BaseModifier(torch.nn.Module, make_base_modifier()):
31-
def __init__(self) -> None:
31+
def __init__(self, use_cache: bool = True) -> None:
3232
"""Construct a base modifier for data modification tasks."""
3333
torch.nn.Module.__init__(self)
3434
self.modifier_type = "base"
3535
self.jitable = True
3636

37+
self.use_cache = use_cache
38+
3739
def serialize(self) -> dict:
3840
"""Serialize the modifier.
3941

deepmd/utils/data.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,11 @@ def __init__(
140140
# The prefix sum stores the range of indices contained in each directory, which is needed by get_item method
141141
self.prefix_sum = np.cumsum(frames_list).tolist()
142142

143-
self.apply_modifier_at_load = True
143+
self.use_modifier_cache = True
144144
if self.modifier is not None:
145-
if hasattr(self.modifier, "apply_modifier_at_load"):
146-
self.apply_modifier_at_load = self.modifier.apply_modifier_at_load
147-
# Cache for modified frames when apply_modifier_at_load is True
145+
if hasattr(self.modifier, "use_cache"):
146+
self.use_modifier_cache = self.modifier.use_cache
147+
# Cache for modified frames when use_modifier_cache is True
148148
self._modified_frame_cache = {}
149149

150150
def add(
@@ -385,9 +385,9 @@ def get_natoms_vec(self, ntypes: int) -> np.ndarray:
385385

386386
def get_single_frame(self, index: int) -> dict:
387387
"""Orchestrates loading a single frame efficiently using memmap."""
388-
# Check if we have a cached modified frame and apply_modifier_at_load is True
388+
# Check if we have a cached modified frame and use_modifier_cache is True
389389
if (
390-
self.apply_modifier_at_load
390+
self.use_modifier_cache
391391
and self.modifier is not None
392392
and index in self._modified_frame_cache
393393
):
@@ -490,19 +490,18 @@ def get_single_frame(self, index: int) -> dict:
490490
if self.modifier is not None:
491491
# Apply modifier if it exists
492492
self.modifier.modify_data(frame_data, self)
493-
if self.apply_modifier_at_load:
493+
if self.use_modifier_cache:
494494
# Cache the modified frame to avoid recomputation
495495
self._modified_frame_cache[index] = copy.deepcopy(frame_data)
496-
497496
return frame_data
498497

499498
def preload_and_modify_all_data(self) -> None:
500499
"""Preload all frames and apply modifier to cache them.
501500
502-
This method is useful when apply_modifier_at_load is True and you want to
501+
This method is useful when use_modifier_cache is True and you want to
503502
avoid applying the modifier repeatedly during training.
504503
"""
505-
if not self.apply_modifier_at_load or self.modifier is None:
504+
if not self.use_modifier_cache or self.modifier is None:
506505
return
507506

508507
log.info("Preloading and modifying all data frames...")

source/tests/pt/test_data_modifier.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,30 @@
6464
@modifier_args_plugin.register("random_tester", doc=doc_random_tester)
6565
def modifier_random_tester() -> list:
6666
doc_seed = "Random seed used to initialize the random number generator for deterministic scaling factors."
67+
doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation."
6768
return [
6869
Argument("seed", int, optional=True, doc=doc_seed),
70+
Argument("use_cache", bool, optional=True, doc=doc_use_cache),
6971
]
7072

7173

7274
@modifier_args_plugin.register("zero_tester", doc=doc_zero_tester)
7375
def modifier_zero_tester() -> list:
74-
return []
76+
doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation."
77+
return [
78+
Argument("use_cache", bool, optional=True, doc=doc_use_cache),
79+
]
7580

7681

7782
@modifier_args_plugin.register("scaling_tester", doc=doc_scaling_tester)
7883
def modifier_scaling_tester() -> list[Argument]:
7984
doc_model_name = "The name of the frozen energy model file."
8085
doc_sfactor = "The scaling factor for correction."
86+
doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation."
8187
return [
8288
Argument("model_name", str, optional=False, doc=doc_model_name),
8389
Argument("sfactor", float, optional=False, doc=doc_sfactor),
90+
Argument("use_cache", bool, optional=True, doc=doc_use_cache),
8491
]
8592

8693

@@ -92,12 +99,14 @@ def __new__(cls, *args, **kwargs):
9299
def __init__(
93100
self,
94101
seed: int = 1,
102+
use_cache: bool = True,
95103
) -> None:
96104
"""Construct a random_tester modifier that scales data by deterministic random factors for testing."""
97-
super().__init__()
105+
super().__init__(use_cache)
98106
self.modifier_type = "random_tester"
99107
# Use a fixed seed for deterministic behavior
100108
self.rng = np.random.default_rng(seed)
109+
self.sfactor = self.rng.random()
101110

102111
def forward(
103112
self,
@@ -121,21 +130,24 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N
121130
return
122131

123132
if "find_energy" in data and data["find_energy"] == 1.0:
124-
data["energy"] = data["energy"] * self.rng.random()
133+
data["energy"] = data["energy"] * self.sfactor
125134
if "find_force" in data and data["find_force"] == 1.0:
126-
data["force"] = data["force"] * self.rng.random()
135+
data["force"] = data["force"] * self.sfactor
127136
if "find_virial" in data and data["find_virial"] == 1.0:
128-
data["virial"] = data["virial"] * self.rng.random()
137+
data["virial"] = data["virial"] * self.sfactor
129138

130139

131140
@BaseModifier.register("zero_tester")
132141
class ModifierZeroTester(BaseModifier):
133142
def __new__(cls, *args, **kwargs):
134143
return super().__new__(cls)
135144

136-
def __init__(self) -> None:
145+
def __init__(
146+
self,
147+
use_cache: bool = True,
148+
) -> None:
137149
"""Construct a modifier that zeros out data for testing."""
138-
super().__init__()
150+
super().__init__(use_cache)
139151
self.modifier_type = "zero_tester"
140152

141153
def forward(
@@ -176,9 +188,10 @@ def __init__(
176188
self,
177189
model_name: str,
178190
sfactor: float = 1.0,
191+
use_cache: bool = True,
179192
) -> None:
180193
"""Initialize a test modifier that applies scaled model predictions using a frozen model."""
181-
super().__init__()
194+
super().__init__(use_cache)
182195
self.modifier_type = "scaling_tester"
183196
self.model_name = model_name
184197
self.sfactor = sfactor
@@ -212,6 +225,7 @@ def forward(
212225
@parameterized(
213226
(1, 2), # training data batch_size
214227
(1, 2), # validation data batch_size
228+
(True, False), # use_cache
215229
)
216230
class TestDataModifier(unittest.TestCase):
217231
def setUp(self) -> None:
@@ -240,7 +254,10 @@ def test_init_modify_data(self):
240254
"""Ensure modify_data applied."""
241255
tmp_config = self.config.copy()
242256
# add tester data modifier
243-
tmp_config["model"]["modifier"] = {"type": "zero_tester"}
257+
tmp_config["model"]["modifier"] = {
258+
"type": "zero_tester",
259+
"use_cache": self.param[2],
260+
}
244261

245262
# data modification is finished in __init__
246263
trainer = get_trainer(tmp_config)
@@ -262,6 +279,7 @@ def test_full_modify_data(self):
262279
tmp_config["model"]["modifier"] = {
263280
"type": "random_tester",
264281
"seed": 1024,
282+
"use_cache": self.param[2],
265283
}
266284

267285
# data modification is finished in __init__
@@ -307,6 +325,7 @@ def test_inference(self):
307325
"type": "scaling_tester",
308326
"model_name": "frozen_model_dm.pth",
309327
"sfactor": sfactor,
328+
"use_cache": True,
310329
}
311330

312331
trainer = get_trainer(tmp_config)

0 commit comments

Comments
 (0)