Skip to content

Commit f0cb8a9

Browse files
Align NAT functionals with FunctionSpec
1 parent ed855da commit f0cb8a9

17 files changed

Lines changed: 638 additions & 232 deletions

File tree

CODING_STANDARDS/FUNCTIONAL_APIS.md

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ This document is structured in two main sections:
5353

5454
| Rule ID | Summary | Apply When |
5555
|---------|---------|------------|
56-
| [`FNC-000`](#fnc-000-functionals-must-use-functionspec) | Functionals must use FunctionSpec | Creating new functional APIs |
56+
| [`FNC-000`](#fnc-000-functionals-must-use-functionspec) | Functionals must use FunctionSpec unless they are lightweight tensor helpers | Creating new functional APIs |
5757
| [`FNC-001`](#fnc-001-functional-location-and-public-api) | Functional location and public API | Organizing or exporting functionals |
5858
| [`FNC-002`](#fnc-002-file-layout-for-functionals) | File layout for functionals | Adding or refactoring functional files |
5959
| [`FNC-003`](#fnc-003-registration-and-dispatch-rules) | Registration and dispatch rules | Registering implementations |
@@ -71,15 +71,30 @@ This document is structured in two main sections:
7171

7272
**Description:**
7373

74-
All functionals must be implemented with `FunctionSpec`, even if only a single
75-
implementation exists. This ensures the operation participates in validation
76-
and benchmarking through input generators and `compare_forward` (and
74+
All functionals with backend dispatch, optional accelerated implementations, or
75+
meaningful benchmark coverage must be implemented with `FunctionSpec`, even if
76+
only a single implementation exists. This ensures the operation participates in
77+
validation and benchmarking through input generators and `compare_forward` (and
7778
`compare_backward` where needed).
7879

80+
Small pure-PyTorch tensor helpers may remain plain functions when all of the
81+
following are true:
82+
83+
- The implementation is a thin composition of PyTorch tensor operations.
84+
- There is no optional backend, custom kernel, or dispatch-selection behavior.
85+
- Benchmarking the helper independently would not provide actionable
86+
performance data.
87+
- The function has focused tests or coverage through its owning feature area.
88+
89+
When a helper later grows an alternate backend, optional dependency, or
90+
performance-sensitive implementation, convert it to `FunctionSpec`.
91+
7992
**Rationale:**
8093

8194
`FunctionSpec` provides a consistent structure for backend registration,
82-
selection, benchmarking and verification across the codebase.
95+
selection, benchmarking and verification across the codebase. The lightweight
96+
helper exception avoids adding ceremony to simple tensor algebra that has no
97+
backend-selection or benchmark surface.
8398

8499
**Example:**
85100

benchmarks/physicsnemo/nn/functional/registry.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
"""Registry of FunctionSpec classes to benchmark with ASV."""
1818

1919
from physicsnemo.core.function_spec import FunctionSpec
20+
from physicsnemo.nn.functional.attention.neighborhood_attention import (
21+
NeighborhoodAttention1D,
22+
NeighborhoodAttention2D,
23+
NeighborhoodAttention3D,
24+
)
2025
from physicsnemo.nn.functional.derivatives import (
2126
MeshGreenGaussGradient,
2227
MeshlessFDDerivatives,
@@ -34,7 +39,11 @@
3439
Real,
3540
ViewAsComplex,
3641
)
37-
from physicsnemo.nn.functional.geometry import SignedDistanceField
42+
from physicsnemo.nn.functional.geometry import (
43+
MeshPoissonDiskSample,
44+
MeshToVoxelFraction,
45+
SignedDistanceField,
46+
)
3847
from physicsnemo.nn.functional.interpolation import (
3948
GridToPointInterpolation,
4049
PointToGridInterpolation,
@@ -62,6 +71,8 @@
6271
SpectralGridGradient,
6372
MeshlessFDDerivatives,
6473
# Geometry.
74+
MeshPoissonDiskSample,
75+
MeshToVoxelFraction,
6576
SignedDistanceField,
6677
# Interpolation.
6778
GridToPointInterpolation,
@@ -74,6 +85,10 @@
7485
ViewAsComplex,
7586
Real,
7687
Imag,
88+
# Neighborhood attention.
89+
NeighborhoodAttention1D,
90+
NeighborhoodAttention2D,
91+
NeighborhoodAttention3D,
7792
)
7893

7994
__all__ = ["FUNCTIONAL_SPECS"]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
Neighborhood Attention Functionals
2+
==================================
3+
4+
NATTEN 1D
5+
---------
6+
7+
.. autofunction:: physicsnemo.nn.functional.na1d
8+
9+
NATTEN 2D
10+
---------
11+
12+
.. autofunction:: physicsnemo.nn.functional.na2d
13+
14+
NATTEN 3D
15+
---------
16+
17+
.. autofunction:: physicsnemo.nn.functional.na3d

docs/api/physicsnemo.nn.functionals.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ in the documentation for performance comparisons.
2424
nn/functionals/fourier_spectral
2525
nn/functionals/regularization_parameterization
2626
nn/functionals/interpolation
27+
nn/functionals/neighborhood_attention

physicsnemo/domain_parallel/shard_utils/natten_patches.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
MissingShardPatch,
3333
UndeterminedShardingError,
3434
)
35-
from physicsnemo.nn.functional.natten import na1d, na2d, na3d
35+
from physicsnemo.nn.functional.attention.neighborhood_attention import na1d, na2d, na3d
3636

3737
_natten = OptionalImport("natten")
3838
_raw_func_map = {
@@ -221,9 +221,9 @@ def _natten_wrapper(
221221
r"""Shared wrapper for natten functions to support sharded tensors.
222222
223223
Registered with :meth:`ShardTensor.register_function_handler` so that calls
224-
to :func:`~physicsnemo.nn.functional.natten.na1d`,
225-
:func:`~physicsnemo.nn.functional.natten.na2d`, or
226-
:func:`~physicsnemo.nn.functional.natten.na3d` automatically route through
224+
to :func:`~physicsnemo.nn.functional.attention.neighborhood_attention.na1d`,
225+
:func:`~physicsnemo.nn.functional.attention.neighborhood_attention.na2d`, or
226+
:func:`~physicsnemo.nn.functional.attention.neighborhood_attention.na3d` automatically route through
227227
this handler when any argument is a :class:`ShardTensor`.
228228
229229
Parameters
@@ -250,7 +250,16 @@ def _natten_wrapper(
250250
q, k, v, kernel_size = args[0], args[1], args[2], args[3]
251251

252252
dilation = kwargs.get("dilation", 1)
253-
natten_kwargs = {_k: _v for _k, _v in kwargs.items() if _k != "dilation"}
253+
implementation = kwargs.get("implementation")
254+
if implementation not in (None, "natten"):
255+
raise KeyError(
256+
f"No implementation named '{implementation}' for neighborhood attention"
257+
)
258+
natten_kwargs = {
259+
_k: _v
260+
for _k, _v in kwargs.items()
261+
if _k not in ("dilation", "implementation")
262+
}
254263

255264
if all(type(_t) is torch.Tensor for _t in (q, k, v)):
256265
return func(

physicsnemo/experimental/models/globe/field_kernel.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
split_by_leaf_rank,
4141
)
4242
from physicsnemo.nn import Mlp, Pade
43-
from physicsnemo.nn.functional.equivariant_ops import (
43+
from physicsnemo.nn.functional.equivariant.ops import (
4444
legendre_polynomials,
4545
polar_and_dipole_basis,
4646
smooth_log,
@@ -896,11 +896,7 @@ def forward(
896896
TensorDict[str, Float[torch.Tensor, "n_targets ..."]]
897897
Kernel output fields at target points.
898898
"""
899-
from physicsnemo.experimental.models.globe.cluster_tree import (
900-
ClusterTree,
901-
DualInteractionPlan,
902-
SourceAggregates,
903-
)
899+
from physicsnemo.experimental.models.globe.cluster_tree import ClusterTree
904900
from physicsnemo.mesh.spatial._ragged import _ragged_arange
905901

906902
n_sources = source_points.shape[0]

physicsnemo/nn/functional/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from .attention import na1d, na2d, na3d
1718
from .derivatives import (
1819
mesh_green_gauss_gradient,
1920
mesh_lsq_gradient,
@@ -22,7 +23,7 @@
2223
spectral_grid_gradient,
2324
uniform_grid_gradient,
2425
)
25-
from .equivariant_ops import (
26+
from .equivariant import (
2627
legendre_polynomials,
2728
polar_and_dipole_basis,
2829
smooth_log,
@@ -40,7 +41,6 @@
4041
interpolation,
4142
point_to_grid_interpolation,
4243
)
43-
from .natten import na1d, na2d, na3d
4444
from .neighbors import knn, radius_search
4545
from .regularization_parameterization import drop_path, weight_fact
4646

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from .neighborhood_attention import (
18+
NeighborhoodAttention1D,
19+
NeighborhoodAttention2D,
20+
NeighborhoodAttention3D,
21+
na1d,
22+
na2d,
23+
na3d,
24+
)
25+
26+
__all__ = [
27+
"NeighborhoodAttention1D",
28+
"NeighborhoodAttention2D",
29+
"NeighborhoodAttention3D",
30+
"na1d",
31+
"na2d",
32+
"na3d",
33+
]

0 commit comments

Comments
 (0)