@@ -153,14 +153,21 @@ def _trace_and_compile(
153153) -> torch .nn .Module :
154154 """Trace ``forward_lower`` with ``make_fx`` and compile with ``torch.compile``.
155155
156+ Uses symbolic tracing (``tracing_mode="symbolic"``) so the resulting
157+ FX graph captures shape-polymorphic operations. The graph is then
158+ compiled with ``torch.compile(dynamic=True)`` and the inductor
159+ backend, which automatically pads tensor shapes for efficient kernel
160+ execution (``shape_padding=True``).
161+
156162 Parameters
157163 ----------
158164 model : torch.nn.Module
159- The (uncompiled) model. Temporarily set to eval mode for tracing.
165+ The (uncompiled) model.
160166 ext_coord, ext_atype, nlist, mapping, fparam, aparam
161-        Sample tensors (already padded to the desired max_nall).
167+        Sample tensors used to drive the symbolic trace.
162168 compile_opts : dict
163- Options forwarded to ``torch.compile`` (excluding ``dynamic``).
169+ Options forwarded to ``torch.compile``. Keys ``dynamic`` and
170+ ``backend`` are set internally and ignored if provided.
164171
165172 Returns
166173 -------
@@ -197,84 +204,52 @@ def fn(
197204 aparam = aparam ,
198205 )
199206
200- # Use default tracing_mode="real" (concrete shapes) for best
201- # runtime performance. If data-dependent intermediate shapes
202- # change at runtime, the caller catches the error and retraces.
203- traced_lower = make_fx (fn )(ext_coord , ext_atype , nlist , mapping , fparam , aparam )
207+ # Symbolic tracing captures shape-polymorphic ops, pairing with
208+ # dynamic=True in torch.compile to handle varying nall without
209+ # manual padding or recompilation.
210+ traced_lower = make_fx (
211+ fn ,
212+ tracing_mode = "symbolic" ,
213+ _allow_non_fake_inputs = True ,
214+ )(ext_coord , ext_atype , nlist , mapping , fparam , aparam )
204215
205216 if not was_training :
206217 model .eval ()
207218
208- # The inductor backend does not propagate gradients through the
209- # make_fx-decomposed autograd.grad ops (second-order gradients for
210- # force training). Use "aot_eager" which correctly preserves the
211- # gradient chain while still benefiting from make_fx decomposition.
212- if "backend" not in compile_opts :
213- compile_opts ["backend" ] = "aot_eager"
214- compiled_lower = torch .compile (traced_lower , dynamic = False , ** compile_opts )
219+ # Override backend and dynamic — the inductor backend with
220+ # dynamic=True handles varying shapes automatically.
221+ compile_opts .pop ("dynamic" , None )
222+ compile_opts .pop ("backend" , None )
223+ if "options" not in compile_opts :
224+ compile_opts ["options" ] = {}
225+ compile_opts ["options" ].setdefault ("shape_padding" , True )
226+
227+ compiled_lower = torch .compile (
228+ traced_lower ,
229+ backend = "inductor" ,
230+ dynamic = True ,
231+ ** compile_opts ,
232+ )
215233 return compiled_lower
216234
217235
218236class _CompiledModel (torch .nn .Module ):
219- """Coord extension (eager) -> pad nall -> compiled forward_lower.
237+ """Coord extension (eager) -> compiled forward_lower.
220238
221- If a batch's ``nall`` exceeds the current ``max_nall``, the model is
222- automatically re-traced and recompiled with a larger pad size.
239+ Coord extension and neighbor list construction involve data-dependent
240+ control flow and are kept in eager mode. The compiled ``forward_lower``
241+ handles varying ``nall`` via ``dynamic=True`` — no manual padding or
242+ recompilation needed.
223243 """
224244
225245 def __init__ (
226246 self ,
227247 original_model : torch .nn .Module ,
228248 compiled_forward_lower : torch .nn .Module ,
229- max_nall : int ,
230- compile_opts : dict [str , Any ],
231249 ) -> None :
232250 super ().__init__ ()
233251 self .original_model = original_model
234252 self .compiled_forward_lower = compiled_forward_lower
235- self ._max_nall = max_nall
236- self ._compile_opts = compile_opts
237-
238- def _recompile (
239- self ,
240- ext_coord : torch .Tensor ,
241- ext_atype : torch .Tensor ,
242- nlist : torch .Tensor ,
243- mapping : torch .Tensor ,
244- fparam : torch .Tensor | None ,
245- aparam : torch .Tensor | None ,
246- new_max_nall : int ,
247- ) -> None :
248- """Re-trace and recompile for the given inputs.
249-
250- If *new_max_nall* differs from the current ``_max_nall``, the
251- inputs are padded (or already padded by the caller).
252- """
253- # Pad if the caller provides unpadded tensors (nall growth case)
254- actual_nall = ext_coord .shape [1 ]
255- pad_n = new_max_nall - actual_nall
256- if pad_n > 0 :
257- ext_coord = torch .nn .functional .pad (ext_coord , (0 , 0 , 0 , pad_n ))
258- ext_atype = torch .nn .functional .pad (ext_atype , (0 , pad_n ))
259- mapping = torch .nn .functional .pad (mapping , (0 , pad_n ))
260-
261- ext_coord = ext_coord .detach ()
262-
263- self .compiled_forward_lower = _trace_and_compile (
264- self .original_model ,
265- ext_coord ,
266- ext_atype ,
267- nlist ,
268- mapping ,
269- fparam ,
270- aparam ,
271- self ._compile_opts ,
272- )
273- self ._max_nall = new_max_nall
274- log .info (
275- "Recompiled model with max_nall=%d." ,
276- new_max_nall ,
277- )
278253
279254 def forward (
280255 self ,
@@ -318,27 +293,6 @@ def forward(
318293 distinguish_types = False ,
319294 )
320295 ext_coord = ext_coord .reshape (nframes , - 1 , 3 )
321-
322- # Grow max_nall if needed (retrace + recompile)
323- actual_nall = ext_coord .shape [1 ]
324- if actual_nall > self ._max_nall :
325- new_max_nall = ((int (actual_nall * 1.2 ) + 7 ) // 8 ) * 8
326- log .info (
327- "nall=%d exceeds max_nall=%d; recompiling with max_nall=%d." ,
328- actual_nall ,
329- self ._max_nall ,
330- new_max_nall ,
331- )
332- self ._recompile (
333- ext_coord , ext_atype , nlist , mapping , fparam , aparam , new_max_nall
334- )
335-
336- # Pad to max_nall so compiled graph sees a fixed shape
337- pad_n = self ._max_nall - actual_nall
338- if pad_n > 0 :
339- ext_coord = torch .nn .functional .pad (ext_coord , (0 , 0 , 0 , pad_n ))
340- ext_atype = torch .nn .functional .pad (ext_atype , (0 , pad_n ))
341- mapping = torch .nn .functional .pad (mapping , (0 , pad_n ))
342296 ext_coord = ext_coord .detach ().requires_grad_ (True )
343297
344298 result = self .compiled_forward_lower (
@@ -350,22 +304,18 @@ def forward(
350304 # Ghost-atom forces must be scatter-summed back to local atoms
351305 # via ``mapping`` — the same operation ``communicate_extended_output``
352306 # performs in the uncompiled path.
307+ actual_nall = ext_coord .shape [1 ]
353308 out : dict [str , torch .Tensor ] = {}
354309 out ["atom_energy" ] = result ["atom_energy" ]
355310 out ["energy" ] = result ["energy" ]
356311 if "extended_force" in result :
357- ext_force = result ["extended_force" ] # (nf, nall_padded, 3)
358- # mapping may be padded; only use actual_nall entries
359- map_actual = mapping [:, :actual_nall ] # (nf, actual_nall)
360- ext_force_actual = ext_force [:, :actual_nall , :] # (nf, actual_nall, 3)
312+ ext_force = result ["extended_force" ] # (nf, nall, 3)
361313 # scatter-sum extended forces onto local atoms
362- idx = map_actual .unsqueeze (- 1 ).expand_as (
363- ext_force_actual
364- ) # (nf, actual_nall, 3)
314+ idx = mapping .unsqueeze (- 1 ).expand_as (ext_force ) # (nf, nall, 3)
365315 force = torch .zeros (
366316 nframes , nloc , 3 , dtype = ext_force .dtype , device = ext_force .device
367317 )
368- force .scatter_add_ (1 , idx , ext_force_actual )
318+ force .scatter_add_ (1 , idx , ext_force )
369319 out ["force" ] = force
370320 if "virial" in result :
371321 out ["virial" ] = result ["virial" ]
@@ -642,21 +592,19 @@ def get_sample() -> list[dict[str, np.ndarray]]:
642592 def _compile_model (self , compile_opts : dict [str , Any ]) -> None :
643593 """Replace ``self.model`` with a compiled version.
644594
645- The model's ``forward`` uses ``torch.autograd.grad`` (for force
646- computation) with ``create_graph=True``, which creates a "double
647- backward" that ``torch.compile`` cannot handle.
648-
649- Solution: use ``make_fx`` to trace ``forward_lower``, decomposing
650- ``torch.autograd.grad`` into primitive ops. The coord extension +
651- nlist build (data-dependent control flow) are kept outside the
652- compiled region.
653-
654- To avoid the overhead of symbolic tracing and dynamic shapes, the
655- extended-atom dimension (nall) is padded to a fixed maximum
656- estimated from the training data. This allows concrete-shape
657- tracing and ``dynamic=False``. If a batch exceeds the current
658- max_nall at runtime, the model is automatically re-traced and
659- recompiled with a larger pad size.
595+ The model's ``forward`` uses ``torch.autograd.grad`` (for forces)
596+ with ``create_graph=True``, which creates a "double backward" that
597+ ``torch.compile`` cannot handle.
598+
599+ Solution: use ``make_fx`` with ``tracing_mode="symbolic"`` to trace
600+ ``forward_lower``, decomposing ``torch.autograd.grad`` into
601+ primitive ops with symbolic shapes. The traced graph is compiled
602+ with ``torch.compile(dynamic=True, backend="inductor")`` so
603+ varying ``nall`` across batches is handled automatically — no
604+ manual padding or recompilation needed.
605+
606+ Coord extension + nlist build (data-dependent control flow) are
607+ kept outside the compiled region.
660608 """
661609 from deepmd .dpmodel .utils .nlist import (
662610 build_neighbor_list ,
@@ -668,105 +616,53 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None:
668616
669617 model = self .model
670618
671- # --- Estimate max_nall by sampling multiple batches ---
672- n_sample = 20
673- max_nall = 0
674- best_sample : (
675- tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray , int , dict ] | None
676- ) = None
677-
678- for _ii in range (n_sample ):
679- inp , _ = self .get_data (is_train = True )
680- coord = inp ["coord" ].detach ()
681- atype = inp ["atype" ].detach ()
682- box = inp .get ("box" )
683- if box is not None :
684- box = box .detach ()
685-
686- nframes , nloc = atype .shape [:2 ]
687- coord_np = coord .cpu ().numpy ().reshape (nframes , nloc , 3 )
688- atype_np = atype .cpu ().numpy ()
689- box_np = box .cpu ().numpy ().reshape (nframes , 9 ) if box is not None else None
690-
691- if box_np is not None :
692- coord_norm = normalize_coord (coord_np , box_np .reshape (nframes , 3 , 3 ))
693- else :
694- coord_norm = coord_np
619+ # --- Get one sample batch to drive the symbolic trace ---
620+ inp , _ = self .get_data (is_train = True )
621+ coord = inp ["coord" ].detach ()
622+ atype = inp ["atype" ].detach ()
623+ box = inp .get ("box" )
624+ if box is not None :
625+ box = box .detach ()
695626
696- ext_coord_np , ext_atype_np , mapping_np = extend_coord_with_ghosts (
697- coord_norm , atype_np , box_np , model .get_rcut ()
698- )
699- nlist_np = build_neighbor_list (
700- ext_coord_np ,
701- ext_atype_np ,
702- nloc ,
703- model .get_rcut (),
704- model .get_sel (),
705- distinguish_types = False ,
706- )
707- ext_coord_np = ext_coord_np .reshape (nframes , - 1 , 3 )
708- nall = ext_coord_np .shape [1 ]
709- if nall > max_nall :
710- max_nall = nall
711- best_sample = (
712- ext_coord_np ,
713- ext_atype_np ,
714- mapping_np ,
715- nlist_np ,
716- nloc ,
717- inp ,
718- )
627+ nframes , nloc = atype .shape [:2 ]
628+ coord_3d = coord .reshape (nframes , nloc , 3 )
629+ box_flat = box .reshape (nframes , 9 ) if box is not None else None
719630
720-        # Add 20% margin and round up to a multiple of 8.
721- max_nall = ((int (max_nall * 1.2 ) + 7 ) // 8 ) * 8
722- log .info (
723- "Estimated max_nall=%d for compiled model (sampled %d batches)." ,
724- max_nall ,
725- n_sample ,
726- )
631+ if box_flat is not None :
632+ coord_norm = normalize_coord (coord_3d , box_flat .reshape (nframes , 3 , 3 ))
633+ else :
634+ coord_norm = coord_3d
727635
728- # --- Pad the largest sample to max_nall and trace ---
729- assert best_sample is not None
730- ext_coord_np , ext_atype_np , mapping_np , nlist_np , nloc , sample_input = (
731- best_sample
636+ ext_coord , ext_atype , mapping = extend_coord_with_ghosts (
637+ coord_norm , atype , box_flat , model .get_rcut ()
732638 )
733- nframes = ext_coord_np .shape [0 ]
734- actual_nall = ext_coord_np .shape [1 ]
735- pad_n = max_nall - actual_nall
736-
737- if pad_n > 0 :
738- ext_coord_np = np .pad (ext_coord_np , ((0 , 0 ), (0 , pad_n ), (0 , 0 )))
739- ext_atype_np = np .pad (ext_atype_np , ((0 , 0 ), (0 , pad_n )))
740- mapping_np = np .pad (mapping_np , ((0 , 0 ), (0 , pad_n )))
741-
742- ext_coord = torch .tensor (
743- ext_coord_np , dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE
639+ nlist_t = build_neighbor_list (
640+ ext_coord ,
641+ ext_atype ,
642+ nloc ,
643+ model .get_rcut (),
644+ model .get_sel (),
645+ distinguish_types = False ,
744646 )
745- ext_atype = torch .tensor (ext_atype_np , dtype = torch .int64 , device = DEVICE )
746- nlist_t = torch .tensor (nlist_np , dtype = torch .int64 , device = DEVICE )
747- mapping_t = torch .tensor (mapping_np , dtype = torch .int64 , device = DEVICE )
748- fparam = sample_input .get ("fparam" )
749- aparam = sample_input .get ("aparam" )
647+ ext_coord = ext_coord .reshape (nframes , - 1 , 3 )
750648
751- compile_opts .pop ("dynamic" , None ) # always False for padded approach
649+ fparam = inp .get ("fparam" )
650+ aparam = inp .get ("aparam" )
752651
753652 compiled_lower = _trace_and_compile (
754653 model ,
755654 ext_coord ,
756655 ext_atype ,
757656 nlist_t ,
758- mapping_t ,
657+ mapping ,
759658 fparam ,
760659 aparam ,
761660 compile_opts ,
762661 )
763662
764- self .wrapper .model = _CompiledModel (
765- model , compiled_lower , max_nall , compile_opts
766- )
663+ self .wrapper .model = _CompiledModel (model , compiled_lower )
767664 log .info (
768- "Model compiled with padded nall=%d (tracing_mode=real, dynamic=False)." ,
769- max_nall ,
665+ "Model compiled (tracing_mode=symbolic, dynamic=True, backend=inductor)." ,
770666 )
771667
772668 # ------------------------------------------------------------------
0 commit comments