7070
7171log = logging .getLogger (__name__ )
7272
73- # Buffer names that differ per task after share_params; everything else in the
74- # fitting net is literally the same Python object across shared tasks.
73+ # Buffer names in the fitting net that differ per task after share_params;
74+ # everything else in the fitting net is the same Python object across tasks.
7575_TASK_SPECIFIC_BUFFER_NAMES : tuple [str , ...] = ("bias_atom_e" , "case_embd" )
7676
77+ # Buffer names in atomic_model that are per-task (energy/output statistics).
78+ # These live one level above the fitting net and are not reached by
79+ # fitting-net share_params, so they must also be promoted to FX placeholders.
80+ _ATOMIC_MODEL_TASK_BUFFER_NAMES : tuple [str , ...] = ("out_bias" , "out_std" )
81+
82+ # Prefix used in task_buf_order keys to distinguish atomic_model buffers
83+ # from fitting-net buffers.
84+ _AM_PREFIX = "am/"
85+
7786
7887def _get_task_buffers (model : torch .nn .Module ) -> dict [str , torch .Tensor ]:
79- """Return per-task fitting-net buffers that vary across shared tasks."""
88+ """Return per-task buffers (fitting net + atomic model) that vary across shared tasks."""
89+ result : dict [str , torch .Tensor ] = {}
90+ # fitting-net task buffers
8091 try :
8192 fitting = model .get_fitting_net ()
93+ for name in _TASK_SPECIFIC_BUFFER_NAMES :
94+ val = fitting ._buffers .get (name )
95+ if val is not None and torch .is_tensor (val ):
96+ result [name ] = val .detach ().clone ()
8297 except AttributeError :
83- return {}
84- result : dict [str , torch .Tensor ] = {}
85- for name in _TASK_SPECIFIC_BUFFER_NAMES :
86- val = getattr (fitting , name , None )
87- if val is not None and torch .is_tensor (val ):
88- result [name ] = val .detach ().clone ()
98+ pass
99+ # atomic_model task buffers (out_bias, out_std)
100+ try :
101+ am = model .atomic_model
102+ for name in _ATOMIC_MODEL_TASK_BUFFER_NAMES :
103+ val = am ._buffers .get (name )
104+ if val is not None and torch .is_tensor (val ):
105+ result [_AM_PREFIX + name ] = val .detach ().clone ()
106+ except AttributeError :
107+ pass
89108 return result
90109
91110
@@ -292,13 +311,18 @@ def _trace_and_compile(
292311 tuple (task_buffers [k ] for k in task_buf_order ) if task_buffers else ()
293312 )
294313
295- # Resolve fitting net once for buffer patching inside fn.
314+ # Resolve fitting net and atomic_model once for buffer patching inside fn.
296315 _fitting : torch .nn .Module | None = None
316+ _atomic_model : torch .nn .Module | None = None
297317 if task_buf_order :
298318 try :
299319 _fitting = model .get_fitting_net ()
300320 except AttributeError :
301- pass
321+ pass # no fitting net → no fitting-net buffers to patch
322+ try :
323+ _atomic_model = model .atomic_model
324+ except AttributeError :
325+ pass # no atomic_model → no atomic-model buffers to patch
302326
303327 def fn (
304328 extended_coord : torch .Tensor ,
@@ -313,12 +337,20 @@ def fn(
313337 extended_coord = extended_coord .detach ().requires_grad_ (True )
314338 # Temporarily patch task-specific buffers with the proxy tensors so
315339 # make_fx records them as FX placeholders rather than baked-in constants.
316- # This makes the compiled graph reusable for any buffer values.
340+ # Keys prefixed with _AM_PREFIX are atomic_model buffers; the rest are
341+ # fitting-net buffers.
317342 originals : dict [str , torch .Tensor | None ] = {}
318- if _fitting is not None and task_buf_order :
343+ if task_buf_order :
319344 for name , val in zip (task_buf_order , task_buf_vals ):
320- originals [name ] = _fitting ._buffers .get (name )
321- _fitting ._buffers [name ] = val
345+ if name .startswith (_AM_PREFIX ):
346+ actual = name [len (_AM_PREFIX ):]
347+ if _atomic_model is not None :
348+ originals [name ] = _atomic_model ._buffers .get (actual )
349+ _atomic_model ._buffers [actual ] = val
350+ else :
351+ if _fitting is not None :
352+ originals [name ] = _fitting ._buffers .get (name )
353+ _fitting ._buffers [name ] = val
322354 try :
323355 return model .forward_lower (
324356 extended_coord ,
@@ -331,7 +363,13 @@ def fn(
331363 )
332364 finally :
333365 for name , orig in originals .items ():
334- _fitting ._buffers [name ] = orig
366+ if name .startswith (_AM_PREFIX ):
367+ actual = name [len (_AM_PREFIX ):]
368+ if _atomic_model is not None :
369+ _atomic_model ._buffers [actual ] = orig
370+ else :
371+ if _fitting is not None :
372+ _fitting ._buffers [name ] = orig
335373
336374 # Pick a trace-time nframes that's unlikely to collide with any other
337375 # tensor dim in the graph. The symbolic tracer merges symbols that
@@ -491,9 +529,15 @@ def forward(
491529 if self ._task_buf_order :
492530 try :
493531 _fitting = self .original_model .get_fitting_net ()
494- task_buf_vals : tuple = tuple (
495- getattr (_fitting , name ) for name in self ._task_buf_order
496- )
532+ _am = getattr (self .original_model , "atomic_model" , None )
533+ _vals : list [torch .Tensor ] = []
534+ for _name in self ._task_buf_order :
535+ if _name .startswith (_AM_PREFIX ):
536+ _actual = _name [len (_AM_PREFIX ):]
537+ _vals .append (_am ._buffers [_actual ])
538+ else :
539+ _vals .append (getattr (_fitting , _name ))
540+ task_buf_vals : tuple = tuple (_vals )
497541 except AttributeError :
498542 task_buf_vals = ()
499543 else :
0 commit comments