 function:

 * ``core_compute`` rebuilds a compact, GPU-friendly edge list from the
-  padded DeePMD neighbor list (``build_edge_list_from_nlist``), with a
-  single masked dummy edge appended so the edge tensor is never empty
-  (NOTE 10). Edge vectors come from ``index_select`` on the extended
-  coordinate tensor, which keeps the gradient path back to coordinates
-  explicit and safe under symbolic shapes (NOTE 11).
+  padded DeePMD neighbor list (``build_edge_list_from_nlist``), with
+  masked dummy edges appended so the edge tensor has a non-singular
+  symbolic lower bound (NOTE 10). Edge vectors come from
+  ``index_select`` on the extended coordinate tensor, which keeps the
+  gradient path back to coordinates explicit and safe under symbolic
+  shapes (NOTE 11).
 * The SeZM descriptor consumes the edge list and produces per-atom
   features.
 * The fitting network predicts per-atom energy; ``apply_out_stat`` adds
   cudagraphs capture autograd metadata only once. Higher-order
   gradients need fresh metadata per call, so cudagraphs would feed
   stale autograd state into the second backward.
-* ``max_fusion_size`` -- mode-dependent
+* ``max_fusion_size=8``
   Caps kernel fusion complexity so Inductor's scheduler does not
   time out on the large edge-level reductions inside the
-  descriptor when nsel is big. Training uses ``64`` (the long-
-  standing default, observed stable on every training run so far);
-  inference uses the tighter ``8`` to dodge the Triton lowering
-  failure described by the next bullet.
-* ``triton.persistent_reductions=False`` -- inference only
+  descriptor when nsel is big. The tighter value keeps both
+  training and inference fusions small enough for Triton IR
+  generation on GPU backends that are sensitive to large dynamic
+  edge graphs.
+* ``triton.persistent_reductions=False``
   Inductor's persistent-reduction scheduler fuses a ``sum`` with
   *all* neighbouring pointwise ops (``tanh_backward``, ``pow``,
   ``exp``, ``mul``, ``select``, ``slice``, ``view`` ...) into one
-  ``triton_per_fused_...`` kernel. On the graph emitted by
-  inference (``create_graph=False``, no double-detach stripping,
-  different fused topology than training) this kernel hits Triton
-  bug ``PassManager::run failed`` inside ``make_ttgir``. Training
-  never produces the same fused shape and does not benefit from
-  disabling the optimisation, so the flag is left on for training
-  to preserve kernel quality.
+  ``triton_per_fused_...`` kernel. On SeZM's dynamic edge graph
+  this can hit Triton bug ``PassManager::run failed`` inside
+  ``make_ttgir``. Disabling it forces the reduction into its own
+  kernel before either training or inference can form the
+  pathological fused IR.
 * ``triton.mix_order_reduction=False``
   Workaround for PyTorch <=2.11 bugs pytorch/pytorch#174379,
   #178080, #179494. All three manifest only under data-dependent
 In eval mode we merely detach; no ``create_graph`` is requested, so the
 compiled kernel never has to build a backward graph.

-NOTE 10 -- Tail dummy edge
---------------------------
+NOTE 10 -- Tail dummy edges
+---------------------------

-``build_edge_list_from_nlist`` appends exactly one masked edge at the
-end of every batch. Real edge compaction happens via
+``build_edge_list_from_nlist`` appends two masked edges at the end of
+every batch. Real edge compaction happens via
 ``torch.nonzero(valid_mask)``, whose output length is data-dependent
 and can be zero in sparse or single-type systems. make_fx cannot trace
 an "if n_edges == 0: skip" branch symbolically; without the dummy it
 would fall back to concrete shape specialization and break
-``dynamic=True``. The dummy's ``edge_mask`` is ``False`` so it
-contributes exactly zero to every downstream sum or gather.
+``dynamic=True``. A pair of dummy slots also gives Inductor's batched
+matmul lowering a static ``E >= 2`` edge-axis bound, avoiding
+data-dependent layout guards on ``E == 1``. Each dummy's ``edge_mask``
+is ``False`` so it contributes exactly zero to every downstream sum or
+gather.

 NOTE 11 -- ``index_select`` for coordinate gradients
 ----------------------------------------------------
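The masking guarantee stated in NOTE 10 can be illustrated without PyTorch. What follows is a minimal plain-Python sketch under stated assumptions, not the SeZM implementation: ``masked_scatter_sum`` is a hypothetical helper introduced only for this illustration, and lists stand in for tensors. It shows that an edge whose mask is ``False`` adds exactly zero to a per-atom sum, even though it carries a valid, in-range destination index.

```python
def masked_scatter_sum(dst, values, edge_mask, n_atoms):
    """Sum per-edge values into per-atom slots, ignoring masked edges.

    Hypothetical helper for illustration only; lists stand in for tensors.
    """
    out = [0.0] * n_atoms
    for d, v, keep in zip(dst, values, edge_mask):
        # Multiply by the mask instead of branching on it, mirroring how a
        # masked tensor reduction avoids data-dependent control flow.
        out[d] += v * (1.0 if keep else 0.0)
    return out


# Three real edges followed by two masked dummies that copy entry 0.
dst = [0, 1, 1, 0, 0]
values = [0.5, 1.5, 2.0, 0.5, 0.5]
edge_mask = [True, True, True, False, False]

per_atom = masked_scatter_sum(dst, values, edge_mask, n_atoms=2)
```

Because the dummies are neutralised by the mask rather than by removal, the edge axis keeps its padded length while the reduction result matches what the real edges alone would give.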
@@ -1690,44 +1692,23 @@ def compute_fn(
         # fresh graph is cheap and a segfault is fatal.
         traced = _rebuild_graph_module(traced)

-        # NOTE: Inductor options are mode-dependent. Training has been
-        # running cleanly with ``max_fusion_size=64`` for a while, so we
-        # keep that path untouched to avoid destabilising it. Inference
-        # (``self.training is False``) has shown a Triton
-        # ``make_ttgir`` / ``PassManager::run failed`` on the fused
-        # per-reduction kernel
-        # ``triton_per_fused_clone_exp_mul_pow_select_slice_sum_tanh_...``;
-        # the kernel itself is fine, but the *fused* IR is too big /
-        # too complex for Triton's lowering pipeline on this version.
-        # So inference:
-        #   * disables ``triton.persistent_reductions`` -- persistent
-        #     reduction is what lets Inductor pull a ``sum`` together
-        #     with all surrounding pointwise ops (including the
-        #     activation-backward pointwise chain) into one
-        #     ``per_fused_...`` kernel; turning it off forces the sum
-        #     to emit its own kernel and stops the pathological fuse.
-        #   * tightens ``max_fusion_size`` from 64 to 8, so even
-        #     non-persistent fusions stay small enough for Triton IR
-        #     generation to succeed.
-        # Training does not hit this path in practice (different graph
-        # topology under ``create_graph=True``), so we keep the looser
-        # options there to preserve kernel quality.
+        # NOTE: Conservative Inductor options keep SeZM's dynamic edge
+        # graph from forming overly large Triton reduction kernels
+        # (``make_ttgir`` / ``PassManager::run failed``) on some
+        # GPU/Triton combinations.
         compile_options: dict[str, Any] = {
             "max_autotune": False,
             "shape_padding": True,
             "epilogue_fusion": False,
             "triton.cudagraphs": False,
+            "max_fusion_size": 8,
+            "triton.persistent_reductions": False,
             # NOTE: ``mix_order_reduction`` hits multiple bugs under
             # data-dependent symbolic shapes on PyTorch <=2.11
             # (pytorch/pytorch#174379, #178080, #179494) -- our edge
             # count is exactly that kind of shape.
             "triton.mix_order_reduction": False,
         }
-        if self.training:
-            compile_options["max_fusion_size"] = 64
-        else:
-            compile_options["max_fusion_size"] = 8
-            compile_options["triton.persistent_reductions"] = False
         try:
             from torch._inductor import config as inductor_config

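To make the behavioural effect of the hunk above easier to review, here is a small plain-Python sketch that rebuilds the old and new option sets and reports what changes for training. It is an illustration only: ``BASE_OPTIONS``, ``old_compile_options`` and ``new_compile_options`` are hypothetical names that do not appear in the patch, and the sketch does not call PyTorch or show how the dictionary is ultimately handed to Inductor.

```python
from typing import Any

# Options shared by the old and the new code path, copied from the hunk.
BASE_OPTIONS: dict[str, Any] = {
    "max_autotune": False,
    "shape_padding": True,
    "epilogue_fusion": False,
    "triton.cudagraphs": False,
    "triton.mix_order_reduction": False,
}


def old_compile_options(training: bool) -> dict[str, Any]:
    """Pre-change behaviour: options depended on the training flag."""
    options = dict(BASE_OPTIONS)
    if training:
        options["max_fusion_size"] = 64
    else:
        options["max_fusion_size"] = 8
        options["triton.persistent_reductions"] = False
    return options


def new_compile_options() -> dict[str, Any]:
    """Post-change behaviour: one conservative set for every mode."""
    options = dict(BASE_OPTIONS)
    options["max_fusion_size"] = 8
    options["triton.persistent_reductions"] = False
    return options


# Keys whose value differs for training after the change.
changed_for_training = {
    key: value
    for key, value in new_compile_options().items()
    if old_compile_options(training=True).get(key) != value
}
```

Under these assumptions, inference options are unchanged by the patch, while training picks up the two settings that were previously inference-only.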
@@ -1979,9 +1960,10 @@ def build_edge_list_from_nlist(
     Build a compact edge list from DeePMD padded neighbor list.

     Edge vectors are computed via ``index_select`` on ``extended_coord``
-    so they remain differentiable w.r.t. the input coordinates. One
-    masked dummy edge is always appended to avoid data-dependent empty-edge
-    branches that ``make_fx`` cannot trace.
+    so they remain differentiable w.r.t. the input coordinates. Two
+    masked dummy edges are always appended to avoid data-dependent empty-edge
+    branches that ``make_fx`` cannot trace and singular edge-axis guards
+    in Inductor's batched matmul lowering.

     Parameters
     ----------
@@ -1995,11 +1977,11 @@ def build_edge_list_from_nlist(
     Returns
     -------
     edge_index
-        Edge indices with shape (2, E+1) where E is valid edge count.
+        Edge indices with shape (2, E+2) where E is valid edge count.
     edge_vec
-        Edge vectors with shape (E+1, 3).
+        Edge vectors with shape (E+2, 3).
     edge_mask
-        Boolean mask with shape (E+1,). The trailing element is ``False``.
+        Boolean mask with shape (E+2). The trailing elements are ``False``.
     """
     nf, nloc, nsel = nlist.shape
     n_actual = nf * nloc
@@ -2051,19 +2033,22 @@ def build_edge_list_from_nlist(

     valid_idx = torch.nonzero(edge_mask_actual, as_tuple=False).flatten()

-    # === Step 3. Compact edges + append one masked dummy ===
-    # NOTE: Always append exactly one masked dummy edge.
+    # === Step 3. Compact edges + append masked dummies ===
+    # NOTE: Always append two masked dummy edges.
     # ``torch.nonzero(edge_mask_actual)`` produces a data-dependent
     # number of valid edges, which can be zero on sparse or
     # single-type systems. make_fx cannot trace an
     # ``if n_edges == 0: skip`` branch symbolically; without the
     # dummy it would fall back to concrete shape specialisation and
-    # break ``torch.compile(dynamic=True)`` for later batches. The
+    # break ``torch.compile(dynamic=True)`` for later batches. Two
+    # dummy edges keep the symbolic edge axis statically above one,
+    # which avoids Inductor bmm layout guards on ``E == 1``. Each
     # dummy edge copies entry 0 (any in-range index is fine) and
     # carries ``edge_mask=False`` so every downstream sum, gather
     # or scatter ignores it.
+    dummy_count = 2
     padded_idx = torch.cat(
-        [valid_idx, torch.zeros(1, dtype=torch.long, device=device)]
+        [valid_idx, torch.zeros(dummy_count, dtype=torch.long, device=device)]
     )
     src_sel = src_actual.index_select(0, padded_idx)
     dst_sel = dst_actual.index_select(0, padded_idx)
@@ -2072,7 +2057,7 @@ def build_edge_list_from_nlist(
     edge_mask = torch.cat(
         [
             torch.ones(valid_idx.shape[0], dtype=torch.bool, device=device),
-            torch.zeros(1, dtype=torch.bool, device=device),
+            torch.zeros(dummy_count, dtype=torch.bool, device=device),
         ]
     )
     return edge_index, edge_vec_sel, edge_mask
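The padding step in the last two hunks can be sketched in plain Python to show the resulting shapes, including the empty-edge case. This is a simplified stand-in rather than the patched function: ``pad_valid_indices`` is a hypothetical helper, a list comprehension plays the role of ``torch.nonzero(...).flatten()``, and list concatenation plays the role of ``torch.cat``.

```python
def pad_valid_indices(edge_mask_actual, dummy_count=2):
    """Compact valid slots and append masked dummy slots.

    Hypothetical plain-Python stand-in for the nonzero + cat sequence.
    """
    # Positions of valid slots; the length depends on the data.
    valid_idx = [i for i, ok in enumerate(edge_mask_actual) if ok]
    # Dummies copy entry 0; any in-range index works because they are masked.
    padded_idx = valid_idx + [0] * dummy_count
    edge_mask = [True] * len(valid_idx) + [False] * dummy_count
    return padded_idx, edge_mask


# A padded neighbor list with three valid slots out of six.
idx, mask = pad_valid_indices([True, False, True, False, False, True])

# A system with no valid edges still yields an edge axis of length two.
empty_idx, empty_mask = pad_valid_indices([False, False, False])
```

With ``dummy_count = 2`` the edge axis has length ``E + 2``, so it is never 0 or 1 regardless of how many real edges survive compaction, which is the lower bound the patch relies on.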