fix: try fix dpa4 compile #5483
        val = getattr(fitting, aname, None)
        if val is not None and torch.is_tensor(val):
            names.append(_FIT_ATTR_PREFIX + aname)
    except AttributeError:

            names.append(_FIT_ATTR_PREFIX + aname)
    except AttributeError:
        pass
    except AttributeError:
Pull request overview
This PR attempts to repair the PyTorch-compiled execution path for the SeZM/DPA4 model, primarily by reducing recompilation and out-of-memory (OOM) failures in multi-task setups and by addressing symbolic-shape tracing issues in `make_fx`.
Changes:
- Add module-level compile sharing and promote selected per-task buffers (e.g., `out_bias`, `bias_atom_e`, `case_embd`) as FX inputs to enable compiled-graph reuse across shared-parameter tasks.
- Add additional symbolic-shape anti-aliasing logic for trace inputs and temporarily disable `ShapeEnv` duck sizing during tracing.
- Change edge-list construction to append a single masked dummy edge (instead of two) and adjust related documentation/behavior.
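The first bullet can be illustrated with a toy, pure-Python sketch. It is not the PR's code: `fake_compile`, `energy`, and `get_compiled` are hypothetical names, and the "compile" step only counts calls. The point it tries to show is that when a per-task value such as `out_bias` is passed as an explicit input rather than captured as a constant, tasks that share parameters can reuse one compiled callable.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for an expensive trace/compile step; it only counts calls.
compile_calls = 0


def fake_compile(fn: Callable) -> Callable:
    global compile_calls
    compile_calls += 1
    return fn


def energy(features: List[float], weight: float, out_bias: float) -> float:
    """Shared computation; ``out_bias`` is an explicit input, not a captured constant."""
    return weight * sum(features) + out_bias


# One cache entry per group of tasks that share parameters (assumed keying scheme).
_compiled_cache: Dict[int, Callable] = {}


def get_compiled(shared_key: int) -> Callable:
    """Return the compiled callable for a shared-parameter group, compiling once."""
    if shared_key not in _compiled_cache:
        _compiled_cache[shared_key] = fake_compile(energy)
    return _compiled_cache[shared_key]


# Two tasks share the weight but carry different per-task biases.
task_bias = {"task_a": 0.5, "task_b": -1.0}
shared_weight = 2.0
results = {
    name: get_compiled(shared_key=0)([1.0, 2.0], shared_weight, bias)
    for name, bias in task_bias.items()
}
```

In this sketch both tasks run through the same callable, so the compile counter stays at one regardless of how many tasks share the group.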
        aparam: torch.Tensor | None = None,
        charge_spin: torch.Tensor | None = None,
        *,
        do_atomic_virial: bool = False,
        charge_spin: torch.Tensor | None = None,
    ) -> torch.nn.Module:
    _ss_mod = None
    _orig_se_init = None
    try:
        import torch.fx.experimental.symbolic_shapes as _ss_mod  # type: ignore[no-redef]
    except Exception:
        _ss_mod = None
    if _ss_mod is not None and hasattr(_ss_mod, "ShapeEnv"):
        _orig_se_init = _ss_mod.ShapeEnv.__init__

        def _no_duck_shapeenv_init(self, *args, **kwargs):  # type: ignore[no-untyped-def]
            kwargs.setdefault("duck_shape", False)
            return _orig_se_init(self, *args, **kwargs)

        _ss_mod.ShapeEnv.__init__ = _no_duck_shapeenv_init
    try:
        traced = make_fx(
            compute_fn,
            tracing_mode="symbolic",
            _allow_non_fake_inputs=True,
            decomposition_table=decomp_table,
        )(*trace_args)
    finally:
        if _orig_se_init is not None:
            _ss_mod.ShapeEnv.__init__ = _orig_se_init
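The patch-and-restore pattern in the hunk above could also be written as a context manager. The following is a minimal sketch, not the PR's implementation: `FakeShapeEnv` is a hypothetical stand-in class so the example runs without PyTorch, and `override_init_default` is an illustrative helper name. As in the hunk, `setdefault` only supplies the value when the caller did not pass one, and the original constructor is restored in a `finally` block.

```python
from contextlib import contextmanager


class FakeShapeEnv:
    """Hypothetical stand-in for a class whose constructor accepts ``duck_shape``."""

    def __init__(self, *, duck_shape=True):
        self.duck_shape = duck_shape


@contextmanager
def override_init_default(cls, name, value):
    """Temporarily make ``cls(...)`` default keyword ``name`` to ``value``."""
    orig_init = cls.__init__

    def patched_init(self, *args, **kwargs):
        # Only fill in the keyword if the caller did not set it explicitly.
        kwargs.setdefault(name, value)
        return orig_init(self, *args, **kwargs)

    cls.__init__ = patched_init
    try:
        yield
    finally:
        # Restore the original constructor even if the body raised.
        cls.__init__ = orig_init


with override_init_default(FakeShapeEnv, "duck_shape", False):
    inside = FakeShapeEnv().duck_shape                   # patched default applies
    explicit = FakeShapeEnv(duck_shape=True).duck_shape  # explicit value wins
outside = FakeShapeEnv().duck_shape                      # original default is back
```

Wrapping the patch this way keeps the override and its restoration in one place, which may make it harder to leave the class patched if tracing raises.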
    # === Step 3. Compact edges + append one masked dummy ===
    # NOTE: Always append exactly one masked dummy edge.
    # ``torch.nonzero(edge_mask_actual)`` produces a data-dependent
    # number of valid edges, which can be zero on sparse or
    # single-type systems. make_fx cannot trace an
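The compaction step described in this comment can be sketched in plain Python. This is an illustration under assumptions, not the PR's tensor code: `compact_edges_with_dummy` is a hypothetical helper, edges are `(src, dst)` tuples, and the dummy edge is chosen here as `(0, 0)`. It keeps only the valid edges and then appends exactly one dummy edge flagged as masked, so the returned edge list is never empty.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def compact_edges_with_dummy(
    edges: List[Edge], edge_mask: List[bool]
) -> Tuple[List[Edge], List[bool]]:
    """Keep edges whose mask is True, then append one masked dummy edge.

    Returns the compacted edge list and a validity mask of the same length;
    the final entry is always the dummy edge with mask False.
    """
    if len(edges) != len(edge_mask):
        raise ValueError("edges and edge_mask must have the same length")
    kept = [edge for edge, keep in zip(edges, edge_mask) if keep]
    # Assumed dummy: a self-edge on node 0, excluded downstream via the mask.
    kept.append((0, 0))
    valid = [True] * (len(kept) - 1) + [False]
    return kept, valid


# Some edges valid: two real edges plus the dummy.
some_edges, some_valid = compact_edges_with_dummy(
    [(0, 1), (1, 2), (2, 0)], [True, False, True]
)

# No edges valid: only the dummy remains, so the result is still non-empty.
no_edges, no_valid = compact_edges_with_dummy([(0, 1)], [False])
```

Downstream code in such a scheme would be expected to multiply or filter by the validity mask so the dummy edge contributes nothing.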
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5483      +/-   ##
==========================================
- Coverage   81.34%   80.19%   -1.16%
==========================================
  Files         868      868
  Lines       96373    96522     +149
  Branches     4233     4235       +2
==========================================
- Hits        78399    77410     -989
- Misses      16675    17809    +1134
- Partials     1299     1303       +4
No description provided.