Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions deepmd/pt_expt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
# as it's a stateless utility class
register_dpmodel_mapping(EnvMat, lambda v: v)

# Register opaque deepmd_export::border_op wrapper (used by GNN MPI
# parallel inference; see comm.py module docstring).
# Register fake tensor implementations for custom tabulate ops
from deepmd.pt_expt.utils import comm # noqa: F401
# Register fake tensor implementations for custom tabulate ops.
# comm.py (border_op fake/autograd) is NOT imported here — its
# ensure_comm_registered() is called lazily from the with_comm_dict
# export path in serialization.py to avoid eager libdeepmd_op_pt.so
# loading that breaks fake-op registration order in tests.
Comment thread
anyangml marked this conversation as resolved.
from deepmd.pt_expt.utils import tabulate_ops # noqa: F401

__all__ = [
Expand Down
46 changes: 36 additions & 10 deletions deepmd/pt_expt/utils/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

import torch

_registered: bool = False


def _check_underlying_ops_loaded() -> None:
"""Surface a clearer error when libdeepmd_op_pt.so isn't loaded.
Expand Down Expand Up @@ -76,15 +78,11 @@
)


_check_underlying_ops_loaded()


# ---------------------------------------------------------------------------
# Fake (meta) impls — let make_fx / torch.export trace through.
# ---------------------------------------------------------------------------


@torch.library.register_fake("deepmd_export::border_op")
def _border_op_fake(
sendlist: torch.Tensor,
sendproc: torch.Tensor,
Expand All @@ -99,7 +97,6 @@
return torch.empty_like(g1)


@torch.library.register_fake("deepmd_export::border_op_backward")
def _border_op_backward_fake(
sendlist: torch.Tensor,
sendproc: torch.Tensor,
Expand Down Expand Up @@ -180,8 +177,37 @@
)


torch.library.register_autograd(
"deepmd_export::border_op",
_border_op_backward,
setup_context=_border_op_setup_context,
)
def ensure_comm_registered() -> None:
"""Load libdeepmd_op_pt.so and register fake/autograd metadata for border_op.

Idempotent — safe to call multiple times. Must be called before any
``make_fx`` / ``torch.export`` trace that passes through border_op (i.e.
before the ``with_comm_dict=True`` export path in serialization.py).

Kept lazy (not called at import time) so that merely importing
``deepmd.pt_expt.utils`` does not force-load libdeepmd_op_pt.so and
disrupt fake-op registration order in tests that don't exercise the comm
path at all.
Comment thread
anyangml marked this conversation as resolved.
"""
global _registered
if _registered:
return
_check_underlying_ops_loaded()
try:
torch.library.register_fake("deepmd_export::border_op")(_border_op_fake)
except RuntimeError as e:
if "already has" not in str(e) and "already registered" not in str(e):
raise
try:
torch.library.register_fake("deepmd_export::border_op_backward")(
_border_op_backward_fake
)
except RuntimeError as e:
if "already has" not in str(e) and "already registered" not in str(e):
raise
torch.library.register_autograd(
"deepmd_export::border_op",
_border_op_backward,
setup_context=_border_op_setup_context,
)
Comment thread
anyangml marked this conversation as resolved.
_registered = True

Check notice

Code scanning / CodeQL

Unused global variable Note

The global variable '_registered' is not used.
Comment thread
anyangml marked this conversation as resolved.
Comment thread
anyangml marked this conversation as resolved.
Dismissed
8 changes: 8 additions & 0 deletions deepmd/pt_expt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,14 @@ def _trace_and_export(
# matter for tracing — only that they're valid tensors of the right
# shape and dtype. See ``_make_comm_sample_inputs``.
if with_comm_dict:
# Load libdeepmd_op_pt.so and register border_op fake/autograd
# metadata now — deferred from import time so normal utils imports
# don't force-load the op library and break fake-op ordering.
from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)

ensure_comm_registered()
if not _needs_with_comm_artifact(model):
Comment thread
anyangml marked this conversation as resolved.
raise ValueError(
Comment thread
anyangml marked this conversation as resolved.
"with_comm_dict=True requested but the model's descriptor "
Expand Down
7 changes: 4 additions & 3 deletions source/tests/pt_expt/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
_get_current_function_mode_stack,
)

# ``deepmd.pt_expt.utils.comm`` self-bootstraps libdeepmd_op_pt.so via
# ``_check_underlying_ops_loaded()``, so we no longer need to preload
# ``deepmd.pt`` here.
# ``deepmd.pt_expt.utils.comm`` is now lazy: libdeepmd_op_pt.so is only
# loaded when ``ensure_comm_registered()`` is explicitly called from the
# with_comm_dict export path. Tests that don't exercise that path never
# load the op library, preserving fake-op registration order.


def _pop_device_contexts() -> list:
Expand Down
Loading