Skip to content

Commit 2beac89

Browse files
committed
Revert "fix: unnecessary memory alloc for large systems"
This reverts commit d8269cb.
1 parent d8269cb commit 2beac89

4 files changed

Lines changed: 6 additions & 6 deletions

File tree

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def _eval_model(
366366
out = np.full(shape, np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION)
367367
results.append(out)
368368
else:
369-
shape = self._get_output_shape(odef, 1, 1)
369+
shape = self._get_output_shape(odef, nframes, natoms)
370370
results.append(
371371
np.full(np.abs(shape), np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION)
372372
) # this is kinda hacky

deepmd/jax/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def _eval_model(
389389
out = np.full(shape, np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION)
390390
results.append(out)
391391
else:
392-
shape = self._get_output_shape(odef, 1, 1)
392+
shape = self._get_output_shape(odef, nframes, natoms)
393393
results.append(
394394
np.full(np.abs(shape), np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION)
395395
) # this is kinda hacky

deepmd/pd/infer/deep_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def _eval_model(
577577
out = out.numpy()
578578
results.append(out)
579579
else:
580-
shape = self._get_output_shape(odef, 1, 1)
580+
shape = self._get_output_shape(odef, nframes, natoms)
581581
results.append(
582582
np.full(np.abs(shape), np.nan, dtype=prec)
583583
) # this is kinda hacky
@@ -657,7 +657,7 @@ def _eval_model_spin(
657657
out = batch_output[pd_name].reshape(shape).detach().cpu().numpy()
658658
results.append(out)
659659
else:
660-
shape = self._get_output_shape(odef, 1, 1)
660+
shape = self._get_output_shape(odef, nframes, natoms)
661661
results.append(
662662
np.full(
663663
np.abs(shape),

deepmd/pt/infer/deep_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def _eval_model(
528528
out = batch_output[pt_name].reshape(shape).detach().cpu().numpy()
529529
results.append(out)
530530
else:
531-
shape = self._get_output_shape(odef, 1, 1)
531+
shape = self._get_output_shape(odef, nframes, natoms)
532532
results.append(
533533
np.full(np.abs(shape), np.nan, dtype=prec)
534534
) # this is kinda hacky
@@ -608,7 +608,7 @@ def _eval_model_spin(
608608
out = batch_output[pt_name].reshape(shape).detach().cpu().numpy()
609609
results.append(out)
610610
else:
611-
shape = self._get_output_shape(odef, 1, 1)
611+
shape = self._get_output_shape(odef, nframes, natoms)
612612
results.append(
613613
np.full(
614614
np.abs(shape),

0 commit comments

Comments
 (0)