Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
geometry-latent kNN distance as a continuous score for downstream
consumers (e.g. AL acquisition) without the boolean thresholding /
warning emission of `OODGuard.check()`.
- Adds configurable normalization to ``physicsnemo.nn.Mlp`` via the
``norm_layer`` parameter (``"batchnorm"``, PyTorch ``"layernorm"``,
TE-only ``"te_layernorm"``, or a user-supplied norm factory). ``use_batchnorm``
remains supported for backward compatibility.

### Changed

Expand Down
88 changes: 82 additions & 6 deletions physicsnemo/nn/module/mlp_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,78 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multi-layer perceptron (MLP) module with optional Transformer Engine support."""
r"""Multi-layer perceptron (MLP) module with optional Transformer Engine support."""

import itertools
from collections.abc import Callable

import torch
from torch import nn

from physicsnemo.core.version_check import OptionalImport

from .activations import get_activation
from .layer_norm import TE_AVAILABLE

# Check for Transformer Engine availability
te = OptionalImport("transformer_engine.pytorch")

NormLayerSpec = type[nn.Module] | Callable[[int], nn.Module] | str | None


def _require_te_layernorm() -> None:
    """Validate the prerequisites for ``norm_layer='te_layernorm'``.

    The Transformer Engine availability flag is consulted first; CUDA
    availability is only probed when that flag is truthy.

    Raises
    ------
    RuntimeError
        If ``TE_AVAILABLE`` is falsy, or if ``torch.cuda.is_available()``
        returns ``False``.
    """
    if not TE_AVAILABLE:
        reason = (
            "norm_layer='te_layernorm' requires transformer_engine to be installed."
        )
    elif not torch.cuda.is_available():
        reason = "norm_layer='te_layernorm' requires a CUDA device."
    else:
        # Both prerequisites hold; nothing to report.
        return
    raise RuntimeError(reason)


def _resolve_norm_factory(
norm_layer: NormLayerSpec,
use_batchnorm: bool,
) -> Callable[[int], nn.Module] | None:
"""Resolve normalization configuration to a per-layer factory."""
if norm_layer is not None and use_batchnorm:
raise ValueError(
"Cannot specify both norm_layer and use_batchnorm=True. "
"Use norm_layer='batchnorm' or norm_layer=nn.BatchNorm1d instead."
)

if norm_layer is not None:
if isinstance(norm_layer, str):
key = norm_layer.lower().replace("-", "_")
if key in {"batchnorm", "batch_norm", "bn"}:
return nn.BatchNorm1d
if key in {"layernorm", "layer_norm", "ln"}:
return nn.LayerNorm
if key in {"te_layernorm", "te_layer_norm"}:
_require_te_layernorm()
return te.LayerNorm
raise ValueError(
f"Unknown norm_layer string {norm_layer!r}. "
"Expected one of 'batchnorm', 'layernorm', or 'te_layernorm'."
)

if isinstance(norm_layer, nn.Module):
raise ValueError(
"norm_layer must be a class or callable factory, not a module instance."
)

return norm_layer

if use_batchnorm:
return nn.BatchNorm1d

return None


class Mlp(nn.Module):
"""Multi-layer perceptron with configurable architecture.
r"""Multi-layer perceptron with configurable architecture.

Supports arbitrary depth, dropout, batch normalization, spectral
Supports arbitrary depth, dropout, configurable normalization, spectral
normalization, bias control, and optional Transformer Engine linear
layers.

Expand Down Expand Up @@ -63,7 +116,18 @@ class Mlp(nn.Module):
Whether to include bias terms in the linear layers. Default is ``True``.
use_batchnorm : bool, optional
If ``True``, applies ``BatchNorm1d`` after each linear layer
(including the output layer). Default is ``False``.
(including the output layer). Default is ``False``. Mutually
exclusive with ``norm_layer``.
norm_layer : type[nn.Module] | Callable[[int], nn.Module] | str | None, optional
Normalization applied after each linear layer. Can be:
- ``None``: no normalization (unless ``use_batchnorm=True``)
- ``str``: ``"batchnorm"`` for ``BatchNorm1d``; ``"layernorm"`` for PyTorch
``LayerNorm``; ``"te_layernorm"`` for Transformer Engine ``LayerNorm``
(requires ``transformer_engine`` and CUDA)
- ``type`` or callable: factory invoked as ``norm_layer(out_features)``
(for example ``nn.LayerNorm`` or ``get_layer_norm_class()`` for TE-aware
auto selection)
Default is ``None``.
spectral_norm : bool, optional
If ``True``, applies spectral normalization to all linear layer
weights, constraining the spectral norm to 1. Default is ``False``.
Expand Down Expand Up @@ -96,6 +160,15 @@ class Mlp(nn.Module):
... )
>>> mlp(torch.randn(8, 10)).shape
torch.Size([8, 4])

>>> mlp = Mlp(
... in_features=10,
... hidden_features=20,
... out_features=5,
... norm_layer="layernorm",
... )
>>> mlp(torch.randn(4, 10)).shape
torch.Size([4, 5])
"""

def __init__(
Expand All @@ -108,12 +181,15 @@ def __init__(
final_dropout: bool = True,
bias: bool = True,
use_batchnorm: bool = False,
norm_layer: NormLayerSpec = None,
spectral_norm: bool = False,
use_te: bool = False,
):
super().__init__()

self.use_te = use_te
self.norm_layer = norm_layer
norm_factory = _resolve_norm_factory(norm_layer, use_batchnorm)
Comment thread
Thabhelo marked this conversation as resolved.

out_features = out_features or in_features

Expand Down Expand Up @@ -147,8 +223,8 @@ def __init__(
if spectral_norm:
linear = nn.utils.parametrizations.spectral_norm(linear, name="weight")
layers.append(linear)
if use_batchnorm:
layers.append(nn.BatchNorm1d(out_dim))
if norm_factory is not None:
layers.append(norm_factory(out_dim))
if not is_last:
layers.append(act_layer)
if drop != 0 and (not is_last or final_dropout):
Expand Down
181 changes: 181 additions & 0 deletions test/nn/module/test_mlp_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch

from physicsnemo.nn import Mlp
from physicsnemo.nn.module.layer_norm import get_layer_norm_class
from test.common import (
validate_forward_accuracy,
)
Expand Down Expand Up @@ -177,6 +178,186 @@ def test_mlp_batchnorm(device):
assert output.shape == torch.Size([4, 5])


def test_mlp_norm_layer_batchnorm_string(device):
    """Check that ``norm_layer="batchnorm"`` yields ``BatchNorm1d`` modules.

    Builds a 10 -> 20 -> 5 ``Mlp`` on ``device``, expects exactly two
    ``torch.nn.BatchNorm1d`` submodules, and verifies the output shape for a
    batch of four samples.
    """
    dev = torch.device(device)
    mlp = Mlp(
        in_features=10,
        hidden_features=20,
        out_features=5,
        norm_layer="batchnorm",
    )
    mlp = mlp.to(dev)

    batchnorm_layers = [
        layer for layer in mlp.modules() if isinstance(layer, torch.nn.BatchNorm1d)
    ]
    assert len(batchnorm_layers) == 2

    batch = torch.randn(4, 10, device=dev)
    result = mlp(batch)
    assert result.shape == torch.Size([4, 5])


def test_mlp_norm_layer_layernorm(device):
    """Check that ``norm_layer="layernorm"`` yields PyTorch ``LayerNorm`` modules.

    Builds a 10 -> 20 -> 5 ``Mlp`` on ``device``, expects exactly two
    ``torch.nn.LayerNorm`` submodules, and verifies the output shape for a
    batch of four samples.
    """
    dev = torch.device(device)
    mlp = Mlp(
        in_features=10,
        hidden_features=20,
        out_features=5,
        norm_layer="layernorm",
    )
    mlp = mlp.to(dev)

    layernorm_layers = [
        layer for layer in mlp.modules() if isinstance(layer, torch.nn.LayerNorm)
    ]
    assert len(layernorm_layers) == 2

    batch = torch.randn(4, 10, device=dev)
    result = mlp(batch)
    assert result.shape == torch.Size([4, 5])


def test_mlp_stores_norm_layer():
    """Check that ``Mlp`` exposes the ``norm_layer`` argument it was given.

    The value passed to the constructor must compare equal to the
    ``norm_layer`` attribute of the resulting module.
    """
    requested = "layernorm"
    mlp = Mlp(
        in_features=10,
        hidden_features=20,
        out_features=5,
        norm_layer=requested,
    )
    assert mlp.norm_layer == requested


def test_mlp_norm_layer_callable(device):
    """Check that a user-defined factory function is accepted as ``norm_layer``.

    The factory builds ``LayerNorm`` modules without affine parameters, which
    lets the test tell them apart from default-constructed ones. Exactly two
    such modules are expected in the 10 -> 20 -> 5 ``Mlp``.
    """
    dev = torch.device(device)

    def make_affine_free_norm(width: int) -> torch.nn.Module:
        # elementwise_affine=False is the marker the assertion below looks for.
        return torch.nn.LayerNorm(width, elementwise_affine=False)

    mlp = Mlp(
        in_features=10,
        hidden_features=20,
        out_features=5,
        norm_layer=make_affine_free_norm,
    )
    mlp = mlp.to(dev)

    affine_free_norms = [
        layer
        for layer in mlp.modules()
        if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine
    ]
    assert len(affine_free_norms) == 2

    batch = torch.randn(2, 10, device=dev)
    result = mlp(batch)
    assert result.shape == torch.Size([2, 5])


def test_mlp_norm_layer_type(device):
    """Check that passing the ``torch.nn.LayerNorm`` class as ``norm_layer`` works.

    Builds a 10 -> 20 -> 5 ``Mlp`` on ``device``, expects exactly two
    ``torch.nn.LayerNorm`` submodules, and verifies the output shape for a
    batch of two samples.
    """
    dev = torch.device(device)
    mlp = Mlp(
        in_features=10,
        hidden_features=20,
        out_features=5,
        norm_layer=torch.nn.LayerNorm,
    )
    mlp = mlp.to(dev)

    layernorm_layers = [
        layer for layer in mlp.modules() if isinstance(layer, torch.nn.LayerNorm)
    ]
    assert len(layernorm_layers) == 2

    batch = torch.randn(2, 10, device=dev)
    result = mlp(batch)
    assert result.shape == torch.Size([2, 5])


def test_mlp_norm_layer_te_layernorm_class(device):
    """Check that ``get_layer_norm_class()`` can be passed as ``norm_layer``.

    The returned class may not be ``torch.nn.LayerNorm`` itself, so a module
    counts as a layer norm if it is a ``torch.nn.LayerNorm`` instance or its
    class is simply named ``"LayerNorm"``. Exactly two such modules are
    expected in the 10 -> 20 -> 5 ``Mlp``.
    """
    dev = torch.device(device)
    mlp = Mlp(
        in_features=10,
        hidden_features=20,
        out_features=5,
        norm_layer=get_layer_norm_class(),
    )
    mlp = mlp.to(dev)

    def looks_like_layer_norm(layer) -> bool:
        # Match by class name as well, to cover non-PyTorch LayerNorm classes.
        if isinstance(layer, torch.nn.LayerNorm):
            return True
        return layer.__class__.__name__ == "LayerNorm"

    layernorm_layers = [
        layer for layer in mlp.modules() if looks_like_layer_norm(layer)
    ]
    assert len(layernorm_layers) == 2

    batch = torch.randn(2, 10, device=dev)
    result = mlp(batch)
    assert result.shape == torch.Size([2, 5])


@requires_module(["transformer_engine"])
def test_mlp_te_layernorm_string(device):
    """Test norm_layer='te_layernorm' requires Transformer Engine and CUDA.

    Three situations are covered:

    - No CUDA on the host: constructing the ``Mlp`` must raise a
      ``RuntimeError`` mentioning ``te_layernorm``.
    - CUDA present but ``transformer_engine`` not importable: construction
      must raise a ``RuntimeError`` mentioning ``transformer_engine``.
    - Both present: the model is built, moved to ``device``, and run on a
      batch of two samples. Non-CUDA parametrizations are skipped here.
    """
    import importlib.util

    mlp_kwargs = {
        "in_features": 10,
        "hidden_features": 20,
        "out_features": 5,
        "norm_layer": "te_layernorm",
    }

    # The guard in mlp_layers consults torch.cuda.is_available(), not the
    # device this test is parametrized with. Keying the expected failure on
    # the device string made the "cpu" parametrization fail with
    # "DID NOT RAISE" on hosts that have both CUDA and Transformer Engine.
    if not torch.cuda.is_available():
        with pytest.raises(RuntimeError, match="te_layernorm"):
            Mlp(**mlp_kwargs)
        return

    te_available = importlib.util.find_spec("transformer_engine") is not None
    if not te_available:
        with pytest.raises(RuntimeError, match="transformer_engine"):
            Mlp(**mlp_kwargs)
        return

    if "cuda" not in device:
        # Construction would succeed on this host, so there is no error to
        # assert, and the forward pass is only exercised on CUDA devices.
        pytest.skip("te_layernorm forward pass is only exercised on CUDA devices.")

    target_device = torch.device(device)
    model = Mlp(**mlp_kwargs).to(target_device)
    assert model.norm_layer == "te_layernorm"
    output = model(torch.randn(2, 10, device=target_device))
    assert output.shape == torch.Size([2, 5])


def test_mlp_norm_layer_conflicts_with_use_batchnorm():
    """Check that combining ``use_batchnorm=True`` with ``norm_layer`` is rejected.

    Construction must raise a ``ValueError`` whose message contains
    ``"Cannot specify both norm_layer"``.
    """
    conflicting_kwargs = {
        "in_features": 10,
        "hidden_features": 20,
        "out_features": 5,
        "use_batchnorm": True,
        "norm_layer": "layernorm",
    }
    with pytest.raises(ValueError, match="Cannot specify both norm_layer"):
        Mlp(**conflicting_kwargs)


def test_mlp_norm_layer_rejects_module_instance():
    """Check that an already-constructed norm module is rejected as ``norm_layer``.

    Construction must raise a ``ValueError`` whose message contains
    ``"not a module instance"``.
    """
    prebuilt_norm = torch.nn.LayerNorm(5)
    with pytest.raises(ValueError, match="not a module instance"):
        Mlp(
            in_features=10,
            hidden_features=20,
            out_features=5,
            norm_layer=prebuilt_norm,
        )


def test_mlp_unknown_norm_layer_string():
    """Check that an unrecognized ``norm_layer`` string is rejected.

    ``"groupnorm"`` is not one of the supported aliases, so construction must
    raise a ``ValueError`` whose message contains
    ``"Unknown norm_layer string"``.
    """
    unsupported_alias = "groupnorm"
    with pytest.raises(ValueError, match="Unknown norm_layer string"):
        Mlp(
            in_features=10,
            hidden_features=20,
            out_features=5,
            norm_layer=unsupported_alias,
        )


def test_mlp_spectral_norm(device):
"""Test that spectral_norm wraps linear layers with spectral normalization."""
target_device = torch.device(device)
Expand Down