File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -366,7 +366,7 @@ def _eval_model(
366366 out = np .full (shape , np .nan , dtype = GLOBAL_NP_FLOAT_PRECISION )
367367 results .append (out )
368368 else :
369- shape = self ._get_output_shape (odef , 1 , 1 )
369+ shape = self ._get_output_shape (odef , nframes , natoms )
370370 results .append (
371371 np .full (np .abs (shape ), np .nan , dtype = GLOBAL_NP_FLOAT_PRECISION )
372372 ) # this is kinda hacky
Original file line number Diff line number Diff line change @@ -389,7 +389,7 @@ def _eval_model(
389389 out = np .full (shape , np .nan , dtype = GLOBAL_NP_FLOAT_PRECISION )
390390 results .append (out )
391391 else :
392- shape = self ._get_output_shape (odef , 1 , 1 )
392+ shape = self ._get_output_shape (odef , nframes , natoms )
393393 results .append (
394394 np .full (np .abs (shape ), np .nan , dtype = GLOBAL_NP_FLOAT_PRECISION )
395395 ) # this is kinda hacky
Original file line number Diff line number Diff line change @@ -577,7 +577,7 @@ def _eval_model(
577577 out = out .numpy ()
578578 results .append (out )
579579 else :
580- shape = self ._get_output_shape (odef , 1 , 1 )
580+ shape = self ._get_output_shape (odef , nframes , natoms )
581581 results .append (
582582 np .full (np .abs (shape ), np .nan , dtype = prec )
583583 ) # this is kinda hacky
@@ -657,7 +657,7 @@ def _eval_model_spin(
657657 out = batch_output [pd_name ].reshape (shape ).detach ().cpu ().numpy ()
658658 results .append (out )
659659 else :
660- shape = self ._get_output_shape (odef , 1 , 1 )
660+ shape = self ._get_output_shape (odef , nframes , natoms )
661661 results .append (
662662 np .full (
663663 np .abs (shape ),
Original file line number Diff line number Diff line change @@ -528,7 +528,7 @@ def _eval_model(
528528 out = batch_output [pt_name ].reshape (shape ).detach ().cpu ().numpy ()
529529 results .append (out )
530530 else :
531- shape = self ._get_output_shape (odef , 1 , 1 )
531+ shape = self ._get_output_shape (odef , nframes , natoms )
532532 results .append (
533533 np .full (np .abs (shape ), np .nan , dtype = prec )
534534 ) # this is kinda hacky
@@ -608,7 +608,7 @@ def _eval_model_spin(
608608 out = batch_output [pt_name ].reshape (shape ).detach ().cpu ().numpy ()
609609 results .append (out )
610610 else :
611- shape = self ._get_output_shape (odef , 1 , 1 )
611+ shape = self ._get_output_shape (odef , nframes , natoms )
612612 results .append (
613613 np .full (
614614 np .abs (shape ),
You can’t perform that action at this time.
0 commit comments