File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -439,10 +439,9 @@ def __init__(
439439 self .original_model = original_model
440440 self .compiled_forward_lower = compiled_forward_lower
441441 self ._task_buf_order = task_buf_order
442- if task_buf_order and task_buffers :
443- for name in task_buf_order :
444- if name in task_buffers :
445- self .register_buffer (f"_task_{ name } " , task_buffers [name ])
442+ # task_buffers is intentionally not stored: buffers are read from
443+ # original_model.get_fitting_net() at forward time so that weight
444+ # updates (load_state_dict, optimiser steps) are always reflected.
446445
447446 def forward (
448447 self ,
@@ -489,9 +488,16 @@ def forward(
489488 ext_coord = ext_coord .reshape (nframes , - 1 , 3 )
490489 ext_coord = ext_coord .detach ().requires_grad_ (True )
491490
492- task_buf_vals = tuple (
493- getattr (self , f"_task_{ name } " ) for name in self ._task_buf_order
494- )
491+ if self ._task_buf_order :
492+ try :
493+ _fitting = self .original_model .get_fitting_net ()
494+ task_buf_vals : tuple = tuple (
495+ getattr (_fitting , name ) for name in self ._task_buf_order
496+ )
497+ except AttributeError :
498+ task_buf_vals = ()
499+ else :
500+ task_buf_vals = ()
495501 result = self .compiled_forward_lower (
496502 ext_coord ,
497503 ext_atype ,
You can’t perform that action at this time.
0 commit comments