From 21964112b377003e1c730e19ba3f9673d8c9a086 Mon Sep 17 00:00:00 2001 From: Thabhelo <50872400+Thabhelo@users.noreply.github.com> Date: Fri, 5 Jun 2026 22:56:08 -0400 Subject: [PATCH 1/2] feat(nn): add configurable norm_layer to Mlp Add a norm_layer parameter supporting LayerNorm, TE-aware norms, and custom factories while keeping use_batchnorm backward compatible. Closes #1451 Signed-off-by: Thabhelo <50872400+Thabhelo@users.noreply.github.com> --- CHANGELOG.md | 4 + physicsnemo/nn/module/mlp_layers.py | 72 +++++++++++++-- test/nn/module/test_mlp_layers.py | 136 ++++++++++++++++++++++++++++ 3 files changed, 206 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b4455f438f..203ac90024 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 geometry-latent kNN distance as a continuous score for downstream consumers (e.g. AL acquisition) without the boolean thresholding / warning emission of `OODGuard.check()`. +- Adds configurable normalization to ``physicsnemo.nn.Mlp`` via the + ``norm_layer`` parameter (``"batchnorm"``, ``"layernorm"``, + ``"te_layernorm"``, or a user-supplied norm factory). ``use_batchnorm`` + remains supported for backward compatibility. ### Changed diff --git a/physicsnemo/nn/module/mlp_layers.py b/physicsnemo/nn/module/mlp_layers.py index 9c0febb663..029bf4c00f 100644 --- a/physicsnemo/nn/module/mlp_layers.py +++ b/physicsnemo/nn/module/mlp_layers.py @@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Multi-layer perceptron (MLP) module with optional Transformer Engine support.""" +r"""Multi-layer perceptron (MLP) module with optional Transformer Engine support.""" import itertools +from collections.abc import Callable import torch from torch import nn @@ -24,15 +25,55 @@ from physicsnemo.core.version_check import OptionalImport from .activations import get_activation +from .layer_norm import get_layer_norm_class # Check for Transformer Engine availability te = OptionalImport("transformer_engine.pytorch") +NormLayerSpec = type[nn.Module] | Callable[[int], nn.Module] | str | None + + +def _resolve_norm_factory( + norm_layer: NormLayerSpec, + use_batchnorm: bool, +) -> Callable[[int], nn.Module] | None: + """Resolve normalization configuration to a per-layer factory.""" + if norm_layer is not None and use_batchnorm: + raise ValueError( + "Cannot specify both norm_layer and use_batchnorm=True. " + "Use norm_layer='batchnorm' or norm_layer=nn.BatchNorm1d instead." + ) + + if norm_layer is not None: + if isinstance(norm_layer, str): + key = norm_layer.lower().replace("-", "_") + if key in {"batchnorm", "batch_norm", "bn"}: + return nn.BatchNorm1d + if key in {"layernorm", "layer_norm", "ln", "te_layernorm"}: + layer_norm = get_layer_norm_class() + return layer_norm + raise ValueError( + f"Unknown norm_layer string {norm_layer!r}. " + "Expected one of 'batchnorm', 'layernorm', or 'te_layernorm'." + ) + + if isinstance(norm_layer, nn.Module): + raise ValueError( + "norm_layer must be a class or callable factory, not a module instance." + ) + + return norm_layer + + if use_batchnorm: + return nn.BatchNorm1d + + return None + class Mlp(nn.Module): - """Multi-layer perceptron with configurable architecture. + r"""Multi-layer perceptron with configurable architecture. - Supports arbitrary depth, dropout, batch normalization, spectral + Supports arbitrary depth, dropout, configurable normalization, spectral normalization, bias control, and optional Transformer Engine linear layers. @@ -63,7 +104,15 @@ class Mlp(nn.Module): Whether to include bias terms in the linear layers. Default is ``True``. use_batchnorm : bool, optional If ``True``, applies ``BatchNorm1d`` after each linear layer - (including the output layer). Default is ``False``. + (including the output layer). Default is ``False``. Mutually + exclusive with ``norm_layer``. + norm_layer : type[nn.Module] | Callable[[int], nn.Module] | str | None, optional + Normalization applied after each linear layer. Can be: + - ``None``: no normalization (unless ``use_batchnorm=True``) + - ``str``: ``"batchnorm"``, ``"layernorm"``, or ``"te_layernorm"`` + - ``type`` or callable: factory invoked as ``norm_layer(out_features)`` + (for example ``nn.LayerNorm`` or ``get_layer_norm_class()``) + Default is ``None``. spectral_norm : bool, optional If ``True``, applies spectral normalization to all linear layer weights, constraining the spectral norm to 1. Default is ``False``. @@ -96,6 +145,15 @@ class Mlp(nn.Module): ... ) >>> mlp(torch.randn(8, 10)).shape torch.Size([8, 4]) + + >>> mlp = Mlp( + ... in_features=10, + ... hidden_features=20, + ... out_features=5, + ... norm_layer="layernorm", + ... ) + >>> mlp(torch.randn(4, 10)).shape + torch.Size([4, 5]) """ def __init__( @@ -108,12 +166,14 @@ def __init__( final_dropout: bool = True, bias: bool = True, use_batchnorm: bool = False, + norm_layer: NormLayerSpec = None, spectral_norm: bool = False, use_te: bool = False, ): super().__init__() self.use_te = use_te + norm_factory = _resolve_norm_factory(norm_layer, use_batchnorm) out_features = out_features or in_features @@ -147,8 +207,8 @@ def __init__( if spectral_norm: linear = nn.utils.parametrizations.spectral_norm(linear, name="weight") layers.append(linear) - if use_batchnorm: - layers.append(nn.BatchNorm1d(out_dim)) + if norm_factory is not None: + layers.append(norm_factory(out_dim)) if not is_last: layers.append(act_layer) if drop != 0 and (not is_last or final_dropout): diff --git a/test/nn/module/test_mlp_layers.py b/test/nn/module/test_mlp_layers.py index 60a24c6393..8bc8a3b918 100644 --- a/test/nn/module/test_mlp_layers.py +++ b/test/nn/module/test_mlp_layers.py @@ -18,6 +18,7 @@ import torch from physicsnemo.nn import Mlp +from physicsnemo.nn.module.layer_norm import get_layer_norm_class from test.common import ( validate_forward_accuracy, ) @@ -177,6 +178,141 @@ def test_mlp_batchnorm(device): assert output.shape == torch.Size([4, 5]) +def test_mlp_norm_layer_batchnorm_string(device): + """Test that norm_layer='batchnorm' matches use_batchnorm=True.""" + target_device = torch.device(device) + model = Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer="batchnorm", + ).to(target_device) + + bn_count = sum(1 for m in model.modules() if isinstance(m, torch.nn.BatchNorm1d)) + assert bn_count == 2 + + output = model(torch.randn(4, 10, device=target_device)) + assert output.shape == torch.Size([4, 5]) + + +def test_mlp_norm_layer_layernorm(device): + """Test that norm_layer='layernorm' inserts layer norm after each linear.""" + target_device = torch.device(device) + model = Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer="layernorm", + ).to(target_device) + + ln_count = sum( + 1 + for m in model.modules() + if m.__class__.__name__ == "LayerNorm" or isinstance(m, torch.nn.LayerNorm) + ) + assert ln_count == 2 + + output = model(torch.randn(4, 10, device=target_device)) + assert output.shape == torch.Size([4, 5]) + + +def test_mlp_norm_layer_callable(device): + """Test that a custom norm factory can be supplied.""" + target_device = torch.device(device) + + def norm_factory(out_dim: int) -> torch.nn.Module: + return torch.nn.LayerNorm(out_dim, elementwise_affine=False) + + model = Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer=norm_factory, + ).to(target_device) + + ln_count = sum( + 1 + for m in model.modules() + if isinstance(m, torch.nn.LayerNorm) and not m.elementwise_affine + ) + assert ln_count == 2 + + output = model(torch.randn(2, 10, device=target_device)) + assert output.shape == torch.Size([2, 5]) + + +def test_mlp_norm_layer_type(device): + """Test that norm_layer=nn.LayerNorm works.""" + target_device = torch.device(device) + model = Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer=torch.nn.LayerNorm, + ).to(target_device) + + ln_count = sum(1 for m in model.modules() if isinstance(m, torch.nn.LayerNorm)) + assert ln_count == 2 + + output = model(torch.randn(2, 10, device=target_device)) + assert output.shape == torch.Size([2, 5]) + + +def test_mlp_norm_layer_te_layernorm_class(device): + """Test that get_layer_norm_class() works as a user-supplied norm factory.""" + target_device = torch.device(device) + model = Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer=get_layer_norm_class(), + ).to(target_device) + + ln_count = sum( + 1 + for m in model.modules() + if m.__class__.__name__ == "LayerNorm" or isinstance(m, torch.nn.LayerNorm) + ) + assert ln_count == 2 + + output = model(torch.randn(2, 10, device=target_device)) + assert output.shape == torch.Size([2, 5]) + + +def test_mlp_norm_layer_conflicts_with_use_batchnorm(): + """Test that norm_layer and use_batchnorm cannot both be set.""" + with pytest.raises(ValueError, match="Cannot specify both norm_layer"): + Mlp( + in_features=10, + hidden_features=20, + out_features=5, + use_batchnorm=True, + norm_layer="layernorm", + ) + + +def test_mlp_norm_layer_rejects_module_instance(): + """Test that a pre-instantiated norm module is rejected.""" + with pytest.raises(ValueError, match="not a module instance"): + Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer=torch.nn.LayerNorm(5), + ) + + +def test_mlp_unknown_norm_layer_string(): + """Test that an unknown norm_layer string raises ValueError.""" + with pytest.raises(ValueError, match="Unknown norm_layer string"): + Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer="groupnorm", + ) + + def test_mlp_spectral_norm(device): """Test that spectral_norm wraps linear layers with spectral normalization.""" target_device = torch.device(device) From 8896888ab262968ba58d0c2ad76ea9345dcedce0 Mon Sep 17 00:00:00 2001 From: Thabhelo <50872400+Thabhelo@users.noreply.github.com> Date: Fri, 5 Jun 2026 23:04:16 -0400 Subject: [PATCH 2/2] fix(nn): address Greptile review on Mlp norm_layer Split layernorm (PyTorch LayerNorm) from te_layernorm (requires TE and CUDA), store norm_layer on the module, and extend tests accordingly. Signed-off-by: Thabhelo <50872400+Thabhelo@users.noreply.github.com> --- CHANGELOG.md | 4 +- physicsnemo/nn/module/mlp_layers.py | 28 +++++++++++--- test/nn/module/test_mlp_layers.py | 57 ++++++++++++++++++++++++++--- 3 files changed, 75 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 203ac90024..ca954c8ba6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,8 +45,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 consumers (e.g. AL acquisition) without the boolean thresholding / warning emission of `OODGuard.check()`. - Adds configurable normalization to ``physicsnemo.nn.Mlp`` via the - ``norm_layer`` parameter (``"batchnorm"``, ``"layernorm"``, - ``"te_layernorm"``, or a user-supplied norm factory). ``use_batchnorm`` + ``norm_layer`` parameter (``"batchnorm"``, PyTorch ``"layernorm"``, + TE-only ``"te_layernorm"``, or a user-supplied norm factory). ``use_batchnorm`` remains supported for backward compatibility. ### Changed diff --git a/physicsnemo/nn/module/mlp_layers.py b/physicsnemo/nn/module/mlp_layers.py index 029bf4c00f..ccc8946bf0 100644 --- a/physicsnemo/nn/module/mlp_layers.py +++ b/physicsnemo/nn/module/mlp_layers.py @@ -25,7 +25,7 @@ from physicsnemo.core.version_check import OptionalImport from .activations import get_activation -from .layer_norm import get_layer_norm_class +from .layer_norm import TE_AVAILABLE # Check for Transformer Engine availability te = OptionalImport("transformer_engine.pytorch") @@ -33,6 +33,16 @@ NormLayerSpec = type[nn.Module] | Callable[[int], nn.Module] | str | None +def _require_te_layernorm() -> None: + """Raise if Transformer Engine LayerNorm cannot be used.""" + if not TE_AVAILABLE: + raise RuntimeError( + "norm_layer='te_layernorm' requires transformer_engine to be installed." + ) + if not torch.cuda.is_available(): + raise RuntimeError("norm_layer='te_layernorm' requires a CUDA device.") + + def _resolve_norm_factory( norm_layer: NormLayerSpec, use_batchnorm: bool, @@ -49,9 +59,11 @@ def _resolve_norm_factory( key = norm_layer.lower().replace("-", "_") if key in {"batchnorm", "batch_norm", "bn"}: return nn.BatchNorm1d - if key in {"layernorm", "layer_norm", "ln", "te_layernorm"}: - layer_norm = get_layer_norm_class() - return layer_norm + if key in {"layernorm", "layer_norm", "ln"}: + return nn.LayerNorm + if key in {"te_layernorm", "te_layer_norm"}: + _require_te_layernorm() + return te.LayerNorm raise ValueError( f"Unknown norm_layer string {norm_layer!r}. " "Expected one of 'batchnorm', 'layernorm', or 'te_layernorm'." @@ -109,9 +121,12 @@ class Mlp(nn.Module): norm_layer : type[nn.Module] | Callable[[int], nn.Module] | str | None, optional Normalization applied after each linear layer. Can be: - ``None``: no normalization (unless ``use_batchnorm=True``) - - ``str``: ``"batchnorm"``, ``"layernorm"``, or ``"te_layernorm"`` + - ``str``: ``"batchnorm"`` for ``BatchNorm1d``; ``"layernorm"`` for PyTorch + ``LayerNorm``; ``"te_layernorm"`` for Transformer Engine ``LayerNorm`` + (requires ``transformer_engine`` and CUDA) - ``type`` or callable: factory invoked as ``norm_layer(out_features)`` - (for example ``nn.LayerNorm`` or ``get_layer_norm_class()``) + (for example ``nn.LayerNorm`` or ``get_layer_norm_class()`` for TE-aware + auto selection) Default is ``None``. spectral_norm : bool, optional If ``True``, applies spectral normalization to all linear layer @@ -173,6 +188,7 @@ def __init__( super().__init__() self.use_te = use_te + self.norm_layer = norm_layer norm_factory = _resolve_norm_factory(norm_layer, use_batchnorm) out_features = out_features or in_features diff --git a/test/nn/module/test_mlp_layers.py b/test/nn/module/test_mlp_layers.py index 8bc8a3b918..8bc24dd51a 100644 --- a/test/nn/module/test_mlp_layers.py +++ b/test/nn/module/test_mlp_layers.py @@ -196,7 +196,7 @@ def test_mlp_norm_layer_batchnorm_string(device): def test_mlp_norm_layer_layernorm(device): - """Test that norm_layer='layernorm' inserts layer norm after each linear.""" + """Test that norm_layer='layernorm' inserts PyTorch LayerNorm layers.""" target_device = torch.device(device) model = Mlp( in_features=10, @@ -205,17 +205,24 @@ def test_mlp_norm_layer_layernorm(device): norm_layer="layernorm", ).to(target_device) - ln_count = sum( - 1 - for m in model.modules() - if m.__class__.__name__ == "LayerNorm" or isinstance(m, torch.nn.LayerNorm) - ) + ln_count = sum(1 for m in model.modules() if isinstance(m, torch.nn.LayerNorm)) assert ln_count == 2 output = model(torch.randn(4, 10, device=target_device)) assert output.shape == torch.Size([4, 5]) +def test_mlp_stores_norm_layer(): + """Test that the norm_layer argument is stored on the module.""" + model = Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer="layernorm", + ) + assert model.norm_layer == "layernorm" + + def test_mlp_norm_layer_callable(device): """Test that a custom norm factory can be supplied.""" target_device = torch.device(device) @@ -279,6 +286,44 @@ def test_mlp_norm_layer_te_layernorm_class(device): assert output.shape == torch.Size([2, 5]) +@requires_module(["transformer_engine"]) +def test_mlp_te_layernorm_string(device): + """Test norm_layer='te_layernorm' requires Transformer Engine and CUDA.""" + import importlib.util + + if "cuda" not in device: + with pytest.raises(RuntimeError, match="te_layernorm"): + Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer="te_layernorm", + ) + return + + te_available = importlib.util.find_spec("transformer_engine") is not None + if not te_available: + with pytest.raises(RuntimeError, match="transformer_engine"): + Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer="te_layernorm", + ) + return + + target_device = torch.device(device) + model = Mlp( + in_features=10, + hidden_features=20, + out_features=5, + norm_layer="te_layernorm", + ).to(target_device) + assert model.norm_layer == "te_layernorm" + output = model(torch.randn(2, 10, device=target_device)) + assert output.shape == torch.Size([2, 5]) + + def test_mlp_norm_layer_conflicts_with_use_batchnorm(): """Test that norm_layer and use_batchnorm cannot both be set.""" with pytest.raises(ValueError, match="Cannot specify both norm_layer"):