Skip to content

Commit eba1d24

Browse files
committed
resolve nitpick comments
1 parent 23eb6a5 commit eba1d24

2 files changed

Lines changed: 9 additions & 4 deletions

File tree

deepmd/utils/data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,13 @@ def get_single_frame(self, index: int, num_worker: int) -> dict:
502502
if self.modifier is not None:
503503
with ThreadPoolExecutor(max_workers=num_worker) as executor:
504504
# Apply modifier if it exists
505-
executor.submit(
505+
future = executor.submit(
506506
self.modifier.modify_data,
507507
frame_data,
508508
self,
509509
)
510+
# Wait for completion and propagate any exceptions
511+
future.result()
510512
if self.use_modifier_cache:
511513
# Cache the modified frame to avoid recomputation
512514
self._modified_frame_cache[index] = copy.deepcopy(frame_data)

source/tests/pt/test_data_modifier.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262

6363

6464
@modifier_args_plugin.register("random_tester", doc=doc_random_tester)
65-
def modifier_random_tester() -> list:
65+
def modifier_random_tester() -> list[Argument]:
6666
doc_seed = "Random seed used to initialize the random number generator for deterministic scaling factors."
6767
doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation."
6868
return [
@@ -72,7 +72,7 @@ def modifier_random_tester() -> list:
7272

7373

7474
@modifier_args_plugin.register("zero_tester", doc=doc_zero_tester)
75-
def modifier_zero_tester() -> list:
75+
def modifier_zero_tester() -> list[Argument]:
7676
doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation."
7777
return [
7878
Argument("use_cache", bool, optional=True, doc=doc_use_cache),
@@ -377,7 +377,10 @@ def test_inference(self):
377377
# expected: output_model - sfactor * output_modifier
378378
for ii in range(3):
379379
np.testing.assert_allclose(
380-
model_pred[ii], model_pred_ref[ii] - sfactor * modifier_pred[ii]
380+
model_pred[ii],
381+
model_pred_ref[ii] - sfactor * modifier_pred[ii],
382+
rtol=1e-5,
383+
atol=1e-8,
381384
)
382385

383386
def tearDown(self) -> None:

0 commit comments

Comments
 (0)