Add AeroJEPA model + SuperWing tutorial recipe (experimental) #1690

Open
fgiral000 wants to merge 52 commits into NVIDIA:main from fgiral000:aerojepa-integration

Conversation

@fgiral000

PhysicsNeMo Pull Request

Description

Adds the AeroJEPA model and a SuperWing tutorial recipe under
physicsnemo.experimental and examples/cfd/external_aerodynamics/.
AeroJEPA is a Joint-Embedding Predictive Architecture for 3D
aerodynamic surrogate modeling: instead of mapping geometry directly to
a flow field, it predicts a latent representation of the flow from a
latent representation of the geometry and operating conditions, and
reconstructs the field through a continuous implicit decoder when
needed (Giral et al., arXiv:2605.05586).

What this PR delivers:

  • Model at physicsnemo.experimental.models.aerojepa.
    AeroJEPA composes a context encoder, a target encoder, a query-token
    field decoder (collectively AeroJEPATrunk), and a JEPA predictor
    head (PrototypeTokenJEPAHead) into a single
    physicsnemo.core.module.Module. The training path takes context
    positions/features, independent target encoder surface/volume inputs,
    and operating conditions; the predictor predicts target tokens, and
    the decoder evaluates the field at user-supplied query points.
    predict is a no-grad inference wrapper; decode_field_chunked
    supports memory-bounded evaluation over very large query sets.
    Concrete encoders (ContextTransformer, TargetTransformer,
    PointTransformer), the QueryTokenDecoder, and the encoder ABCs
    are all exposed as composable components.
  • Building blocks at
    physicsnemo.experimental.models.aerojepa.layers. TokenSet and
    EncoderOutput token dataclasses, a deterministic
    FourierPositionalEncoding, ResidualMLP, the
    LocalPointTransformerBlock / LocalTokenCrossAttentionBlock
    attention blocks (with optional AdaLN / AdaLN-Zero conditioning), the
    PointCloudTokenizer (seven center-selection strategies with k-NN
    cluster pooling), token batching / mask / k-NN helpers, and prototype
    anchor build / load utilities. TokenSet and EncoderOutput are
    re-exported from the model package for convenience.
  • Losses at physicsnemo.experimental.models.aerojepa.losses.
    SIGReg and TokenLatentSIGReg (a sketched isotropic-Gaussian
    regularizer for latent-token distributions, with a padding-aware
    wrapper), the flatten_valid_token_features /
    reshape_token_features_for_sigreg masking helpers, and the
    reconstruction loss family (MSELoss / RelativeL2Loss /
    RelativeMSELoss / RelativeL2MSELoss, each with functional and
    nn.Module forms, optional per-channel weights stored as a
    persistent buffer, optional per-point weights, and an optional
    validity mask).
  • Tutorial recipe at
    examples/cfd/external_aerodynamics/aerojepa. End-to-end Hydra-driven
    workflow on the public SuperWing dataset (Yang et al.,
    arXiv:2512.14397): dataset download via the Hugging Face Hub
    (yunplus/SuperWing), automatic split-by-geometry manifest and
    per-channel normalization stats, JEPA training (reconstruction +
    latent + SIGReg with linear warmups; AdamW +
    warmup-cosine; optional EMA), checkpointed inference with chunked
    decoding, three-panel GT | Pred | |Error| field plots for the three
    surface channels (Cp, Cf_tau, Cf_z), per-channel relative-L2 /
    RMSE / MAE metrics on the test split, and a pressure-only CL/CD
    post-processor that integrates the surface field and emits a per-case
    CSV plus a parity scatter.

Checklist

Tests

  • 193 unit tests under test/experimental/models/aerojepa/
    (constructor + attribute checks, non-regression shape checks on the
    encoders, decoder, predictor, trunk, top-level model, layers, and
    losses). pytest test/experimental/models/aerojepa/ -q passes
    locally on CPU (~20 s).
  • Full SuperWing end-to-end smoke-tested on a single GPU:
    train.py -> inference.py -> superwing_metrics -> superwing_forces.
    Training losses decrease monotonically; inference produces field
    plots, per-case field-error metrics, and a force-coefficient parity
    scatter.

Dependencies

No new core dependencies. The example recipe adds optional
example-side dependencies in
examples/cfd/external_aerodynamics/aerojepa/requirements.txt
(Hugging Face Hub for the dataset download, plotting and
post-processing utilities). Pre-commit hooks, ruff, interrogate,
markdownlint, and the SPDX license check pass on every file in the
PR.

fgiral000 added 30 commits June 1, 2026 13:15
Create an empty subpackage for the AeroJEPA reusable building blocks
(attention blocks, geometry tokenizer, context/target encoders, decoder,
predictor) that land in subsequent commits. Establishes the SPDX
license header, module docstring, and ``__all__`` placeholder so that
follow-up commits only need to register new public symbols.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
TokenSet bundles token features with their geometric coordinates and
optional mask, global token, and auxiliary side data; EncoderOutput is
a thin wrapper used by context and target encoders to surface a global
summary alongside the per-token output. Includes raw-string docstrings
with Parameters/Examples sections (three executable doctests), modern
union syntax, and the SPDX header.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A deterministic log-frequency sinusoidal positional encoding used to
lift continuous query coordinates into a high-dimensional feature
space before the implicit decoder consumes them. Distinct from
physicsnemo.nn.FourierEmbedding (random Gaussian frequencies on
scalar timesteps); this variant uses fixed log-powers of pi on
multi-dim coordinates with the standard sin/cos band layout.
Includes an out_dim property, jaxtyping on forward, and an
executable doctest.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A private _gpu_knn module bundling chunked torch.cdist plus topk for
building homogeneous (gpu_knn_self) and bipartite (gpu_knn_bipartite)
k-NN graphs and inverse-distance interpolation (gpu_knn_interpolate).
Pure PyTorch, no warp or custom CUDA — works on CPU too, just slower.
The leading underscore on the filename makes the module package-
private; callers live inside the aerojepa subpackage only.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
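
As a rough illustration of the chunked cdist-plus-topk pattern described in the commit above, the sketch below builds a bipartite k-NN index in pure PyTorch. The name `chunked_knn` and its `chunk_size` default are assumptions made for this example; they are not the signatures of the private `_gpu_knn` helpers.

```python
import torch


def chunked_knn(query, keys, k, chunk_size=4096):
    """Return (Nq, k) indices of the k nearest rows of `keys` for each query row.

    Hypothetical helper for illustration; the real _gpu_knn functions may differ.
    """
    k = min(int(k), keys.shape[0])
    chunks = []
    for start in range(0, query.shape[0], chunk_size):
        # Only a (chunk, Nk) distance block is materialized at any time.
        dist = torch.cdist(query[start : start + chunk_size], keys)
        chunks.append(dist.topk(k, dim=-1, largest=False).indices)
    if not chunks:  # empty query set: nothing to concatenate
        return torch.empty((0, k), dtype=torch.long, device=keys.device)
    return torch.cat(chunks, dim=0)


# Homogeneous graph: query and keys are the same point set.
points = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
neighbors = chunked_knn(points, points, k=2, chunk_size=2)
```

Peak memory scales with `chunk_size * Nk` rather than `Nq * Nk`, which is the point of chunking; the result does not depend on the chunk size.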
A token_utils module with the helpers used by the AeroJEPA tokenizer,
encoders and attention blocks: gather_rows, counts_to_mask,
flatten_padded_batch / unflatten_to_padded, compute_batch_offset_step,
flatten_batched_coords, chunked_knn_indices (CPU/GPU dispatcher with
the AE_KNN_BACKEND env override), masked_mean, trim_batched_tokens,
and pad_token_sets. Behavior preserved; types modernized and the
TokenSet import is package-relative.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A reusable trio of attention building blocks: ResidualMLP (pre-norm
residual MLP with optional AdaLN / AdaLN-Zero conditioning),
LocalPointTransformerBlock (local self-attention over a per-point
k-NN graph with learned relative-position bias), and
LocalTokenCrossAttentionBlock (cross-attention from queries to a
per-query k-NN of context tokens, with a 5-way conditioning MLP that
modulates query and key/value sides independently). Behavior
preserved: zero-init conditioning MLPs give an identity transform at
construction time, and the N<=1 / empty-input fallbacks short-circuit
the same way they do upstream.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A tokenizer module that reduces a raw point set to a bounded token
budget before attention. Seven strategies: identity, random, FPS,
random/FPS/voxel-FPS cluster pooling, and prototype-anchored
clustering. The cluster strategies return the kNN indices that link
each token center back to the source points, allowing a downstream
encoder to replace the default feature mean with a learned pooling
(e.g. the message-passing PointClusterGraphPool that lands with the
encoders in PR NVIDIA#3). Behavior preserved including the non-persistent
prototype_coords buffer and the per-sample loop used by the
prototype strategy.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Build the fixed k-means anchor set used by the
data_prototype_cluster tokenizer strategy. The build pass walks a
training dataset, tokenizes each sample to obtain candidate token
coordinates, optionally subsamples, runs chunked Lloyd-iteration
k-means with empty-cluster FPS refill, sorts the centers
lexicographically, and serializes them with a JSON metadata blob.
Two load functions (target / context - identical file layout) and
two ensure_* helpers (load-if-exists else build) round out the
public surface. Behavior preserved; the seed argument governs
k-means initialization and candidate subsampling but not the
tokenizer pass, which intentionally uses random sampling.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Export the 23 public symbols from the six source modules at the
package level: TokenSet/EncoderOutput dataclasses,
FourierPositionalEncoding, ResidualMLP and the two local attention
blocks, PointCloudTokenizer, the ten batching/mask/kNN helpers,
and the six prototype anchor build/load functions. Module
docstring tightened to reflect the actual contents (encoders /
decoder / predictor land in physicsnemo.experimental.models.aerojepa
in a later PR). The package-private _gpu_knn helpers remain
accessible via their submodule path but are intentionally not
re-exported.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Re-export the five AeroJEPA nn.Module layer classes
(FourierPositionalEncoding, ResidualMLP, LocalPointTransformerBlock,
LocalTokenCrossAttentionBlock, PointCloudTokenizer) at the
experimental.nn parent namespace, alongside the existing FLARE
and DiffusionUNet3D family. Data types (TokenSet, EncoderOutput),
batching/mask helpers, and prototype-anchor builders stay scoped
to the aerojepa subpackage to keep the parent namespace focused
on actual layers.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Tests for TokenSet and EncoderOutput covering construction (both
batched and unbatched), the is_batched / token_dim properties, the
with_updates immutability + selective-replacement contract, and the
independence of the default aux dict across instances. Uses the
shared device fixture so the CUDA path runs when available.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…locks

Six new test files covering positional_encoding, attention_blocks,
point_tokenizer, token_utils, _gpu_knn, and prototype_anchors. 85 new
tests covering: constructor validation paths, forward output shapes,
edge cases (N<=1 LPT fallback, empty cross-attention, empty/single-
point kNN, missing voxel_size, non-persistent prototype_coords
buffer), identity-at-init of AdaLN-Zero conditioning MLPs, the
AE_KNN_BACKEND env override, and build/load round-trips with a tiny
fake dataset. All tests use the shared device fixture so CUDA runs
when available; CPU run is 18 s wall.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Documents the new physicsnemo.experimental.nn.aerojepa subpackage
contributed across the preceding 12 commits on this branch: token
dataclasses, Fourier positional encoding, ResidualMLP, the two local
attention blocks, PointCloudTokenizer, token batching/mask/kNN
helpers, and prototype anchor utilities, plus the parent-namespace
re-export of the five layer classes.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Create an empty subpackage for JEPA-style losses and regularizers
(SIGReg, TokenLatentSIGReg, the padding-aware masking helpers, and
the reconstruction loss family) that land in subsequent commits.
Establishes the SPDX license header, module docstring, and
``__all__`` placeholder so that follow-up commits only need to
register new public symbols.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Two utilities used by SIGReg / TokenLatentSIGReg to flatten padded
batched token features and reshape them into the (T, B, D) layout
SIGReg expects. flatten_valid_token_features is a passthrough on
rank-2 inputs and uses boolean masking on rank-3 inputs;
reshape_token_features_for_sigreg adds the leading T=1 axis and
emits a zero-element (1, 0, D) placeholder when the mask removes
every row.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
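
The shape contract described in the commit above can be sketched as follows. The function names `flatten_valid` and `reshape_for_sigreg` are shortened stand-ins chosen for this example, and the handling of a missing mask on rank-3 input is an assumption.

```python
import torch


def flatten_valid(features, mask=None):
    """(N, D) passes through; (B, N, D) keeps only rows where mask is True."""
    if features.ndim == 2:
        return features
    if features.ndim != 3:
        raise ValueError(f"expected rank 2 or 3, got rank {features.ndim}")
    if mask is None:  # assumption: no mask means every row is valid
        return features.reshape(-1, features.shape[-1])
    return features[mask.bool()]  # boolean masking -> (n_valid, D)


def reshape_for_sigreg(features, mask=None):
    """Add the leading T=1 axis: (1, n_valid, D), or (1, 0, D) if nothing is valid."""
    return flatten_valid(features, mask).unsqueeze(0)


feats = torch.randn(2, 3, 4)  # (B, N, D) padded batch
mask = torch.tensor([[True, True, False], [True, False, False]])
valid = reshape_for_sigreg(feats, mask)                          # 3 valid rows
empty = reshape_for_sigreg(feats, torch.zeros(2, 3, dtype=torch.bool))
```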
SIGReg pushes a learned latent toward N(0, I) by comparing the
empirical Fourier characteristic function of random projections
against the reference Gaussian one on a uniform knot grid (the
LeWorldModel construction). Three non-learnable buffers cache the
knot positions, the reference window, and the trapezoidal +
window-weighted integration weights. TokenLatentSIGReg is a thin
wrapper that accepts (B, N, D) or (N, D) features plus an optional
mask, drops padded rows via the masking helpers, and short-circuits
to a zero scalar when the mask removes every row.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
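
A minimal sketch of the characteristic-function comparison described in the commit above, assuming unit-norm random projections, a Gaussian window, and a trapezoid rule on a uniform knot grid. The function name, defaults, and scaling by the sample count are assumptions for illustration; the PR's SIGReg module may differ in all of them.

```python
import torch


def sigreg_sketch(z, num_projections=64, num_knots=17, t_max=3.0, generator=None):
    """Score how far the rows of z (N, D) are from N(0, I); lower is closer."""
    n, d = z.shape
    # Random unit directions: a 1-D "sketch" of the D-dimensional distribution.
    directions = torch.randn(d, num_projections, generator=generator, dtype=z.dtype)
    directions = directions / directions.norm(dim=0, keepdim=True)
    proj = z @ directions  # (N, P)

    # Uniform knot grid, reference Gaussian CF, trapezoid-times-window weights.
    t = torch.linspace(0.0, t_max, num_knots, dtype=z.dtype)
    window = torch.exp(-0.5 * t**2)  # also the CF of N(0, 1)
    dt = t_max / (num_knots - 1)
    trapz = torch.full_like(t, dt)
    trapz[0] = trapz[-1] = 0.5 * dt
    weights = trapz * window

    # Empirical characteristic function of each projection at each knot.
    angles = proj.unsqueeze(-1) * t  # (N, P, K)
    real_err = angles.cos().mean(dim=0) - window  # (P, K)
    imag_err = angles.sin().mean(dim=0)  # the Gaussian CF is purely real
    stat = ((real_err**2 + imag_err**2) @ weights) * n  # (P,)
    return stat.mean()


gen = torch.Generator().manual_seed(0)
gaussian = torch.randn(2048, 8, generator=gen)
score_gaussian = sigreg_sketch(gaussian, generator=gen)
score_stretched = sigreg_sketch(3.0 * gaussian, generator=gen)  # variance 9
```

The stretched sample should score much higher than the standard-normal one, which is the behavior a regularizer of this kind relies on.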
Four loss families exposed as functional and nn.Module variants:
mse_loss / MSELoss (channel-weighted MSE with mask + point weights),
relative_l2_loss / RelativeL2Loss (per-channel relative L2 averaged
over channels), relative_mse_loss / RelativeMSELoss (relative MSE
with selectable pointwise vs channel_max normalization), and the
relative_l2_mse_loss / RelativeL2MSELoss hybrid that linearly
combines the L2 and MSE terms. Channel weights are stored as a
persistent float32 buffer on the Module variants when supplied, and
as a non-persistent None buffer otherwise.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Move the JEPA losses subpackage from
physicsnemo.experimental.metrics.jepa to .metrics.aerojepa so it
mirrors the nn.aerojepa naming. Populate the package __init__ with
the 12 public re-exports from masking, sigreg, and reconstruction
(flatten/reshape token helpers, SIGReg/TokenLatentSIGReg, and the
four reconstruction loss families both functional and as nn.Module).

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Three test files mirroring the source modules: test_masking,
test_sigreg, test_reconstruction. 37 tests covering constructor
validation, forward shape, edge cases (rank-1, empty batch, all-False
mask), the SIGReg buffer layout, state_dict persistence of
channel_weights on the reconstruction Module variants, both modes of
relative_mse_loss, and the hybrid degenerating to either of its two
sub-losses when the corresponding weight is zero. CPU run is 4 s wall;
device fixture picks up CUDA automatically.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Documents the new physicsnemo.experimental.metrics.aerojepa
subpackage contributed across the preceding 6 commits on this
branch: SIGReg / TokenLatentSIGReg regularizers, masking helpers,
and the four reconstruction loss families.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Create an empty subpackage for the top-level AeroJEPA model and its
model-specific subcomponents (context/target/point encoders, decoder,
predictor, trunk) that land in subsequent commits. Module docstring
points readers to experimental.nn.aerojepa for the reusable building
blocks the model is composed from.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Abstract base classes BaseContextEncoder and BaseTargetEncoder
(plus the encoders subpackage init) define the contract concrete
encoders must satisfy: a required forward returning an EncoderOutput
and an optional forward_batched gated by a supports_batched_forward
class flag. The context encoder's forward args are named context_pos
/ context_feat (these bundle the boundary and any volumetric samples
in whole-domain models; the SDF channel in context_feat distinguishes
the two halves at inference). The target encoder keeps the surface
/ volume split because training-time subsamplings for the two are
intentionally decoupled.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Context tokens are produced from geometry alone - operating
conditions enter the model downstream at the predictor head, not
at the context branch. Remove gen_params from BaseContextEncoder
forward and forward_batched signatures. Class docstring spells
out the intent. BaseTargetEncoder is untouched.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
PointTransformer (point.py) is a point-cloud encoder building block:
tokenizes the input via PointCloudTokenizer, embeds tokens with a
Fourier positional encoding plus per-feature linear projection,
optionally adds a conditioning vector, runs a stack of
LocalPointTransformerBlock layers with configurable dilation, and
emits an EncoderOutput. Two entry points - encode_single for
unbatched inputs and forward_batched for padded batches with
per-batch coordinate offsetting so the inner k-NN does not mix
tokens across batch items. The same file carries the
build_geometry_features helper (assembles per-point features from
positions and optional SDF / normals / n-dot channels) and the
message-passing PointClusterGraphPool used when
tokenizer_cluster_pooling='graph'.

ContextTransformer (context.py) is the concrete BaseContextEncoder.
Takes context_pos and context_feat - no gen_params, since operating
conditions enter the model downstream at the predictor head.
Internally wraps PointTransformer with conditioning disabled.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Mirror the context-side change: target encoders take their inputs
straight, with no gen_params threaded through. Operating conditions
enter the model only at the predictor head.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
JEPA target encoders are self-attention only. Remove context_tokens
from forward and forward_batched.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Concrete BaseTargetEncoder that wraps an inner PointTransformer.
Forward concatenates surface and volume into one bundled point set;
forward_batched weaves variable-length surface and volume halves per
batch via counts_to_mask. Self-attention only.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Implicit field decoder driven by cross-attention to target tokens.
Per-query embedding is a Fourier positional encoding plus optional
SDF channel and optional cond vector; cross-attention to the target
token set refines it, a trunk MLP and head produce the output.
Several optional behaviors wire in: wall-velocity gate, pressure
split head (MLP or SIREN), final SIREN refinement, extra SDF
features. Both forward (single) and forward_batched (padded)
process queries in chunks of query_chunk_size and return
(pred, query_embeddings). SineLayer and SirenHead are the small
SIREN building blocks used by the optional heads.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
The JEPA predictor head. Maps a target-token coordinate set to
predicted target-token features, given context tokens and a
conditioning vector. Operating conditions enter the model here
(via the cond argument), projected once and threaded into every
self- and cross-attention block. Accepts both unbatched (rank-2
context features) and padded batched (rank-3) inputs;
target_positions and cond are broadcast across the batch when
their leading dim is 1. The forward signature uses target_positions
as the parameter name (not target_coords) for consistency with
the rest of the model API.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Owns context encoder, target encoder, and decoder, and wires them
together. encode_context runs both encoders and emits a dict with
context tokens, target tokens, and the decoder-side cond_global.
decode_queries decodes a target token set at supplied query
positions, optionally producing a per-query mask logit when the
mask head is enabled. forward_single and forward_batch are
convenience wrappers chaining the two phases for unbatched and
padded batched inputs respectively. Public args use context_pos /
context_feat naming; gen_params is used to build cond_global but
is not threaded into the encoders.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
self.update_mlps = nn.ModuleList()
for _ in range(self.num_layers):
self.msg_mlps.append(
nn.Sequential(

Collaborator

Please use physicsnemo's Mlp.

)
)
self.gate_mlps.append(
nn.Sequential(

Collaborator

Please use physicsnemo's Mlp.

)
)
self.update_mlps.append(
nn.Sequential(

Collaborator

Please use physicsnemo's Mlp.

nn.GELU(),
nn.Linear(self.dim, self.dim),
)
self.attn_proj = nn.Sequential(

@mnabian mnabian Jun 5, 2026

Collaborator

Please use physicsnemo's Mlp: Mlp(dim, dim, num_heads, act_layer=nn.GELU)

self.q_proj = nn.Linear(self.dim, self.dim)
self.k_proj = nn.Linear(self.dim, self.dim)
self.v_proj = nn.Linear(self.dim, self.dim)
self.pos_proj = nn.Sequential(

Collaborator

Please use physicsnemo's Mlp: Mlp(3, dim, dim, act_layer=nn.GELU)

nn.GELU(),
nn.Linear(self.dim, self.dim),
)
self.attn_proj = nn.Sequential(

Collaborator

Please use physicsnemo's Mlp: Mlp(dim, dim, num_heads, act_layer=nn.GELU)

self.q_proj = nn.Linear(self.dim, self.dim)
self.k_proj = nn.Linear(self.dim, self.dim)
self.v_proj = nn.Linear(self.dim, self.dim)
self.pos_proj = nn.Sequential(

Collaborator

Please use physicsnemo's Mlp: Mlp(3, dim, dim, act_layer=nn.GELU)

)


class SineLayer(nn.Module):

Collaborator

Drop SineLayer, use physicsnemo.nn.SirenLayer; keep SirenHead as a thin local composition but build it from the library's SirenLayer instead of the local one.

]
)
self.trunk = nn.Sequential(
nn.LayerNorm(int(token_dim)),

@mnabian mnabian Jun 5, 2026

Collaborator

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
LayerNorm(int(token_dim))

):
super().__init__()
hidden = max(1, int(mlp_ratio)) * int(dim)
self.norm = nn.LayerNorm(int(dim))

@mnabian mnabian Jun 5, 2026

Collaborator

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm = LayerNorm(int(dim))

self.neighbor_k = int(neighbor_k)
self.dilation = int(max(1, dilation))
self.knn_chunk_size = int(knn_chunk_size)
self.norm = nn.LayerNorm(self.dim)

@mnabian mnabian Jun 5, 2026

Collaborator

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm = LayerNorm(self.dim)

self.head_dim = self.dim // self.num_heads
self.neighbor_k = int(neighbor_k)
self.knn_chunk_size = int(knn_chunk_size)
self.norm_q = nn.LayerNorm(self.dim)

@mnabian mnabian Jun 5, 2026

Collaborator

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm_q = LayerNorm(self.dim)

self.knn_chunk_size = int(knn_chunk_size)
self.norm_q = nn.LayerNorm(self.dim)
self.adaln_zero = bool(adaln_zero)
self.norm_kv = nn.LayerNorm(self.dim)

@mnabian mnabian Jun 5, 2026

Collaborator

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm_kv = LayerNorm(self.dim)

for _ in range(self.depth)
]
)
self.out_norm = nn.LayerNorm(self.hidden_dim)

Collaborator

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.out_norm = LayerNorm(self.hidden_dim)

nn.Linear(self.hidden_dim, self.point_feature_dim),
)
)
self.out_norm = nn.LayerNorm(self.point_feature_dim)

Collaborator

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.out_norm = LayerNorm(self.point_feature_dim)

for i in range(int(num_layers))
]
)
self.out_norm = nn.LayerNorm(int(token_dim))

Collaborator

Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).

from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.out_norm = LayerNorm(int(token_dim))


cluster_size = (
max_tokens if self.cluster_size is None else int(self.cluster_size)
)

@mnabian mnabian Jun 6, 2026

Collaborator

cluster_size=None for the neighborhood-cluster strategies silently defaults to max_tokens (e.g. 512 neighbors/token) — almost certainly unintended, and inconsistent with the data_prototype_cluster path which defaults to 16. Recommend failing fast: require cluster_size for random_cluster/fps_cluster/voxel_fps_cluster, and drop the max_tokens fallback.

suggestion: In __init__, after the strategy/prototype validation:

if (
    self.strategy in {"random_cluster", "fps_cluster", "voxel_fps_cluster"}
    and self.cluster_size is None
):
    raise ValueError(
        f"cluster_size must be provided for tokenizer_strategy='{self.strategy}'."
    )

Then in tokenize_with_clusters, replace the fallback (lines 423–425):

cluster_size = (
    max_tokens if self.cluster_size is None else int(self.cluster_size)
)

with:

cluster_size = int(self.cluster_size)

if self.prototype_knn_k is not None
else self.cluster_size
)
k = max(1, min(int(k if k is not None else 16), n_points))

Collaborator

With zero-point inputs, the data_prototype_cluster path calls k-NN on an empty tensor before it reaches the early-return guard.
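
One way to order the guard, as a minimal sketch: `prototype_neighbors` is a hypothetical stand-in for the prototype path (not the tokenizer's actual method), and the empty `(P, 0)` return shape is an assumption. The `k` clamp mirrors the line quoted in the hunk above.

```python
import torch


def prototype_neighbors(points, prototypes, k=None):
    """Hypothetical helper: for each prototype, indices of its k nearest points."""
    n_points = points.shape[0]
    if n_points == 0:
        # Guard first: the clamp below yields k=1 even when there are no
        # points, and topk(k=1) over an empty distance row would raise.
        return torch.empty((prototypes.shape[0], 0), dtype=torch.long)
    k = max(1, min(int(k if k is not None else 16), n_points))
    dist = torch.cdist(prototypes, points)  # (P, N)
    return dist.topk(k, dim=-1, largest=False).indices


prototypes = torch.randn(4, 3)
no_points = prototype_neighbors(torch.empty(0, 3), prototypes)
few_points = prototype_neighbors(torch.randn(5, 3), prototypes)  # k clamps to 5
```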

context_feat=context_feat,
)
cond_global = gen_params.unsqueeze(0)
if self.trunk.include_geometry_global_in_decoder_cond:

@mnabian mnabian Jun 6, 2026

Collaborator

encode_geometry has a non-trivial branch when include_geometry_global_in_decoder_cond=True (lines 245-252): it pulls the context global token and concatenates it onto cond_global, widening the decoder conditioning. Every test builds the trunk with this flag False, so the concat branch and the resulting wider cond_global are never exercised. Add a model built with the flag True and assert cond_global length grows by the context-global token dim.

Update: Based on the newer comment made previously, you can apply this comment directly to trunk._build_cond_global_single.
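
The requested assertion could look roughly like the sketch below. `build_cond_global` is a hypothetical stand-in for the trunk's conditioning assembly, written only from the hunk above (`gen_params.unsqueeze(0)` plus an optional concat); a real test would build the model with the flag set to True and inspect its actual `cond_global`.

```python
import torch


def build_cond_global(gen_params, context_global, include_geometry_global):
    """Hypothetical stand-in for the trunk's decoder-conditioning assembly."""
    cond_global = gen_params.unsqueeze(0)  # (1, n_cond)
    if include_geometry_global:
        # Widen the conditioning with the context encoder's global token.
        cond_global = torch.cat(
            [cond_global, context_global.reshape(1, -1)], dim=-1
        )
    return cond_global


def test_cond_global_grows_by_context_global_dim():
    gen_params = torch.randn(2)
    context_global = torch.randn(32)
    narrow = build_cond_global(gen_params, context_global, False)
    wide = build_cond_global(gen_params, context_global, True)
    assert narrow.shape == (1, 2)
    # The flag should add exactly the context-global token width.
    assert wide.shape[-1] - narrow.shape[-1] == context_global.numel()
    assert torch.equal(wide[:, :2], narrow)
```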

query_sdf: Float[torch.Tensor, "Nq 1"] | None = None,
conditions: Float[torch.Tensor, "B Cond"] | None = None,
) -> Float[torch.Tensor, "Nq C"]:
if not torch.compiler.is_compiling():

Collaborator

AeroJEPA.forward raises ValueError when context_pos/context_feat are not rank 2 (line 474) or when their point counts disagree (line 479), but no test triggers either guard. Add two tests asserting pytest.raises(ValueError, match="rank 2") for a rank-3 context_pos and match="agree on the point count" for mismatched N between pos and feat.
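
A sketch of the two requested tests. `check_context_inputs` is a hypothetical stand-in for the guards in `AeroJEPA.forward`, and its full error messages are assumptions; only the `match` substrings come from the comment above. In the real test suite the calls would go through the model itself.

```python
import pytest
import torch


def check_context_inputs(context_pos, context_feat):
    """Hypothetical stand-in for the input guards in AeroJEPA.forward."""
    if context_pos.ndim != 2 or context_feat.ndim != 2:
        raise ValueError("context_pos and context_feat must be rank 2")
    if context_pos.shape[0] != context_feat.shape[0]:
        raise ValueError(
            "context_pos and context_feat must agree on the point count"
        )


def test_rank3_context_pos_raises():
    with pytest.raises(ValueError, match="rank 2"):
        check_context_inputs(torch.zeros(1, 8, 3), torch.zeros(8, 4))


def test_point_count_mismatch_raises():
    with pytest.raises(ValueError, match="agree on the point count"):
        check_context_inputs(torch.zeros(8, 3), torch.zeros(7, 4))
```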

query_pos=query_pos, query_sdf=query_sdf, cond=cond
)
)
for block in self.cross_blocks:

Collaborator

Hot path: QueryTokenDecoder._decode_chunk, run for every query chunk over the full surface/volume field (millions of query points at inference via decode_field_chunked). Each LocalTokenCrossAttentionBlock.forward independently calls chunked_knn_indices(query_coords=query_pos, key_coords=token_coords, ...) (decoder.py:458, attention_blocks.py), but within a chunk query_pos and token_coords are identical across all cross_attention_layers blocks and the cross-blocks use no dilation, so the same k-NN graph is recomputed cross_attention_layers times per chunk. Compute the neighbor index once per chunk and pass it into the blocks (e.g. an optional precomputed-idx argument), reusing it across layers. Same redundancy pattern applies to PointTransformer.encode_single (point.py:606) when the dilation schedule is constant.
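
The reuse pattern can be sketched with a toy block. `ToyCrossBlock`, `knn_indices`, `decode_chunk`, and the `knn_idx` keyword are all hypothetical names for this example, and the block uses mean pooling over neighbors rather than real attention; the point is only that one neighbor search per chunk can serve every layer.

```python
import torch
from torch import nn


def knn_indices(query_pos, token_pos, k):
    """(Nq, k) indices of the k nearest tokens for each query point."""
    k = min(int(k), token_pos.shape[0])
    return torch.cdist(query_pos, token_pos).topk(k, dim=-1, largest=False).indices


class ToyCrossBlock(nn.Module):
    """Toy stand-in for a local cross-attention block (mean pooling only)."""

    def __init__(self, dim, neighbor_k):
        super().__init__()
        self.neighbor_k = neighbor_k
        self.proj = nn.Linear(dim, dim)

    def forward(self, query_feat, query_pos, token_feat, token_pos, knn_idx=None):
        if knn_idx is None:  # fallback: build the graph locally
            knn_idx = knn_indices(query_pos, token_pos, self.neighbor_k)
        neighbors = token_feat[knn_idx]  # (Nq, k, D)
        return query_feat + self.proj(neighbors.mean(dim=1))


def decode_chunk(blocks, query_feat, query_pos, token_feat, token_pos, neighbor_k):
    # Positions are fixed across layers, so one k-NN search serves all blocks.
    knn_idx = knn_indices(query_pos, token_pos, neighbor_k)
    for block in blocks:
        query_feat = block(
            query_feat, query_pos, token_feat, token_pos, knn_idx=knn_idx
        )
    return query_feat
```

Keeping `knn_idx` optional preserves the existing call sites while letting the decoder opt in to the shared index.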

err = err * cw
weight_sum = float(channel_weights.to(dtype=err.dtype).sum().item())

point_weight_t = None

Collaborator

The point_weights argument is documented and implemented in all four losses (mse_loss 106-111/122-124, relative_l2_loss 230-235, relative_mse_loss 414-432, plus the hybrid), but no test ever passes point_weights. None of the per-point weighting / weighted-denominator branches are exercised. Add at least one test asserting that uniform point_weights=ones matches the unweighted result and that a zeroed weight on a row drops that row's contribution.
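
A sketch of the two requested checks against a small reference implementation. `weighted_mse` is a hypothetical stand-in that assumes the loss normalizes by the sum of the point weights; in the actual test file the assertions would call `mse_loss` and the other reconstruction losses instead.

```python
import torch


def weighted_mse(pred, target, point_weights=None):
    """Hypothetical reference: per-point weighted MSE over (N, C) tensors."""
    err = (pred - target) ** 2  # (N, C)
    if point_weights is None:
        return err.mean()
    w = point_weights.to(err.dtype).unsqueeze(-1)  # (N, 1)
    # Weighted denominator: sum of weights times the channel count.
    denom = (w.sum() * err.shape[-1]).clamp_min(1e-12)
    return (err * w).sum() / denom


def test_uniform_point_weights_match_unweighted():
    torch.manual_seed(0)
    pred, target = torch.randn(8, 3), torch.randn(8, 3)
    ones = torch.ones(8)
    assert torch.allclose(
        weighted_mse(pred, target, ones), weighted_mse(pred, target)
    )


def test_zero_point_weight_drops_row():
    torch.manual_seed(0)
    pred, target = torch.randn(8, 3), torch.randn(8, 3)
    target[0] = 1.0e6  # would dominate the loss if it were counted
    weights = torch.ones(8)
    weights[0] = 0.0
    expected = weighted_mse(pred[1:], target[1:])
    assert torch.allclose(weighted_mse(pred, target, weights), expected)
```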

point_weight_t = point_weight_t.unsqueeze(-1)
err = err * point_weight_t

if mask is not None:

Collaborator

test_mse_mask_drops_invalid_rows (test/experimental/models/aerojepa/losses/test_reconstruction.py:77-84) only ever passes an all-True mask (torch.ones(8, dtype=torch.bool)) and asserts the result equals the no-mask path. With an all-True mask, err * mask_t is a no-op and denom_weights.sum() equals the full row count, so the masked branch (reconstruction.py:113-120, including the denom_weights.sum().clamp_min(1.0) denominator) is only covered vacuously. The test's own docstring claims "masked-out rows contribute neither to the numerator nor the denominator," yet no row is ever masked out — a regression that stopped excluding masked rows would still pass. Add a test with a partial mask: set some rows to mask=False, fill those target rows with large values, and assert the loss equals mse_loss computed on only the kept rows.
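A sketch of the partial-mask test. `masked_mse` is a hypothetical stand-in that follows the `err * mask_t` numerator and `clamp_min(1.0)` denominator described above; the real test would import `mse_loss` from the PR instead.

```python
import torch


def masked_mse(pred, target, mask=None):
    """Hypothetical stand-in for mse_loss with a per-row boolean mask."""
    err = (pred - target).pow(2)  # (N, C)
    if mask is None:
        return err.mean()
    mask_t = mask.to(err.dtype).unsqueeze(-1)  # (N, 1)
    denom = mask_t.sum().clamp_min(1.0) * err.shape[-1]
    return (err * mask_t).sum() / denom


def test_mse_partial_mask_excludes_masked_rows():
    torch.manual_seed(0)
    pred, target = torch.randn(8, 3), torch.randn(8, 3)
    mask = torch.tensor([True, True, False, True, False, True, True, False])
    poisoned = target.clone()
    poisoned[~mask] = 1e6  # large but finite: inf * 0 would give NaN
    expected = masked_mse(pred[mask], target[mask])  # loss on kept rows only
    torch.testing.assert_close(masked_mse(pred, poisoned, mask), expected)
```

If masked rows leaked into either the numerator or the denominator, the poisoned values would push the result far from `expected`, so this fails on the regression the all-True test misses.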

device: torch.device,
) -> None:
"""Load model weights, applying EMA shadow when present and requested."""
payload = torch.load(ckpt_path, map_location=device, weights_only=False)

Collaborator

payload = torch.load(ckpt_path, map_location=device, weights_only=False) deserializes the checkpoint with full pickle support. Loading a .pt from an untrusted source would execute arbitrary code via a malicious pickle __reduce__. Impact is limited here: the loaded payload is consumed only as tensor state_dicts (payload["model"], payload["ema_shadow"]) plus a scalar ema_decay, so weights_only=True would work and is the safer default.
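For illustration, a payload of that shape round-trips with `weights_only=True`. The key names follow the comment above; the `nn.Linear` model and the in-memory buffer are stand-ins for the real checkpoint, and the actual payload may hold additional entries that would need checking.

```python
import io

import torch
from torch import nn

model = nn.Linear(4, 2)

# Same layout as described above: two tensor state_dicts plus a scalar.
payload = {
    "model": model.state_dict(),
    "ema_shadow": {k: v.clone() for k, v in model.state_dict().items()},
    "ema_decay": 0.999,
}

buffer = io.BytesIO()  # in-memory stand-in for ckpt_path
torch.save(payload, buffer)
buffer.seek(0)

# Restricted unpickler: tensors and plain containers load, arbitrary objects do not.
loaded = torch.load(buffer, map_location="cpu", weights_only=True)

restored = nn.Linear(4, 2)
restored.load_state_dict(loaded["model"])
```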

context = {"target_tokens": target_tokens, "cond_global": cond_global}
preds = []
n = int(query_pos.shape[0])
with torch.autocast(

Collaborator

decode_field_chunked is only tested with precision="fp32" (test_aerojepa.py:179), so the autocast-enabled path (lines 393-401, where enabled=True for fp16/bf16) is never run. On a CUDA device this is the path that actually matters for the memory-bounded use case. Add a parametrized test over precision in {"fp16","bf16"} asserting the returned tensor is finite, CPU, and shape (Nq, C).

"""
target_encoder = self.trunk.target_encoder
base_encoder = getattr(target_encoder, "encoder", None)
if base_encoder is None or not hasattr(base_encoder, "_tokenize_single"):

Collaborator

build_target_token_coords raises ValueError when the target encoder lacks the _tokenize_single path (lines 449-452), but only the happy path is tested (test_aerojepa.py:137). The defensive branch that protects callers swapping in a non-transformer target encoder is untested. A small test passing a stub trunk whose target_encoder has no encoder._tokenize_single and asserting the ValueError would cover it.
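A possible shape for that test. The free function below is a hypothetical stand-in that reproduces only the guard quoted in the diff context (the real method lives on the model and reads `self.trunk`); its error message and return value are assumed.

```python
from types import SimpleNamespace

import pytest


def build_target_token_coords(trunk):
    """Hypothetical stand-in reproducing only the guard quoted above."""
    target_encoder = trunk.target_encoder
    base_encoder = getattr(target_encoder, "encoder", None)
    if base_encoder is None or not hasattr(base_encoder, "_tokenize_single"):
        raise ValueError("target encoder does not expose encoder._tokenize_single")
    return base_encoder._tokenize_single()


def test_build_target_token_coords_rejects_non_transformer_encoder():
    # Stub whose target_encoder wraps an encoder without _tokenize_single.
    stub_trunk = SimpleNamespace(target_encoder=SimpleNamespace(encoder=object()))
    with pytest.raises(ValueError):
        build_target_token_coords(stub_trunk)

    # A target encoder with no `encoder` attribute hits the same branch.
    bare_trunk = SimpleNamespace(target_encoder=SimpleNamespace())
    with pytest.raises(ValueError):
        build_target_token_coords(bare_trunk)
```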

return selected


def _assign_points(

@mnabian, Jun 6, 2026

Collaborator

_assign_points is a k=1 nearest-center query. Replace the manual cdist+min+chunk loop with physicsnemo.nn.functional.knn(centers, points, k=1), which chunks internally and uses the cuML/scipy backends. The returned distance is unused by _run_chunked_kmeans (assign, _ = …), so the dist return value and the chunk_size argument can be dropped too.

from physicsnemo.nn.functional import knn

def _assign_points(points: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    """Assign each point to its nearest center (k=1 kNN)."""
    idx, _ = knn(centers, points, k=1)   # neighbors of `points` among `centers`
    return idx[:, 0]

and the call site in _run_chunked_kmeans:

assign = _assign_points(points, centers)

(chunk_size is no longer threaded into _assign_points.)

Comment thread CHANGELOG.md

- Adds radiation transport example (`examples/nuclear_engineering/radiation_transport`)
- Adds agent skills structure, and initial skill for 'discoverability'.
- Adds top-level AeroJEPA model under

Collaborator

The AeroJEPA CHANGELOG entry is far longer and more granular than any other entry. Consider condensing it to a single concise bullet (optionally split into model vs. recipe) to match the house style.

token_features=token_features,
gen_params=gen_params,
)
for block in self.blocks:

@mnabian, Jun 8, 2026

Collaborator

The kNN distance search is recomputed in every block although coordinates are static within each stack. Compute the sorted neighbor list once (at k = neighbor_k × max(dilation)), then have each block apply its own dilation stride, which is a pure indexing op. For a stack of N blocks this removes N−1 redundant distance/sort passes without changing behavior.
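A small sketch of the idea, assuming that dilation `d` means "take every `d`-th neighbor among the `neighbor_k * d` nearest" (the usual dilated-kNN convention; the PR's blocks may define it differently). `sorted_neighbors` is a brute-force stand-in for the PR's chunked search, and both helper names are hypothetical.

```python
import torch


def sorted_neighbors(coords: torch.Tensor, k_total: int) -> torch.Tensor:
    """Indices (N, k_total) of the nearest neighbors, sorted by increasing distance."""
    dist = torch.cdist(coords, coords)  # (N, N)
    return dist.topk(k_total, dim=-1, largest=False, sorted=True).indices


def dilated_from_shared(shared_idx: torch.Tensor, neighbor_k: int, dilation: int) -> torch.Tensor:
    """Pick every `dilation`-th neighbor from the shared list: pure indexing."""
    return shared_idx[:, : neighbor_k * dilation : dilation]


torch.manual_seed(0)
coords = torch.randn(64, 3)
neighbor_k, dilations = 4, (1, 2, 4)

# One distance/sort pass for the whole stack ...
shared = sorted_neighbors(coords, neighbor_k * max(dilations))
# ... then each block only slices its own neighborhood out of it.
per_block = [dilated_from_shared(shared, neighbor_k, d) for d in dilations]
```

Because the shared list is sorted, its first `neighbor_k * d` columns coincide with what a per-block search at that `k` would return (barring exact distance ties), so the slice is equivalent to the recomputed result.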

gen_embed.unsqueeze(1).expand(-1, int(token_positions.shape[1]), -1),
token_mask,
)
for block in self.blocks:

Collaborator

Same static-coords kNN recomputation here: the neighbor search can be computed once and reused across these blocks.

query_pos=query_pos, query_sdf=query_sdf, cond=cond
)
)
for block in self.cross_blocks:

Collaborator

Same static-coords kNN recomputation here: the neighbor search can be computed once and reused across these blocks.

cond=cond_chunk,
)
)
for block in self.cross_blocks:

Collaborator

Same static-coords kNN recomputation here: the neighbor search can be computed once and reused across these blocks.

context_mask,
)

for self_block, cross_block in zip(

Collaborator

Same static-coords kNN recomputation here: the neighbor search can be computed once and reused across these blocks.

torch.Size([30, 4])
"""

def __init__(

@mnabian, Jun 8, 2026

Collaborator

AeroJEPA subclasses physicsnemo.core.Module, so its __init__ args must be JSON-serializable. The only exception is args that are themselves physicsnemo.Module instances. But trunk and predictor here are plain torch.nn.Module instances, so saving fails:

model = AeroJEPA(trunk=trunk, predictor=predictor)
model.save("m.mdlus")  # raises during json.dumps

Suggested fix: take JSON-serializable config in __init__ and build trunk/predictor internally (the FNO/DoMINO pattern). Alternatively, convert the full submodule tree to physicsnemo.Module.
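To illustrate the first option, here is a toy version of the config-in, submodules-built-internally pattern. Everything in it is hypothetical: it uses plain `torch.nn.Module` to stay self-contained, and `init_args` only mimics the argument capture that the physicsnemo base class would perform when saving.

```python
import json

import torch
from torch import nn


class ToyTrunk(nn.Module):
    def __init__(self, in_dim: int, token_dim: int) -> None:
        super().__init__()
        self.encoder = nn.Linear(in_dim, token_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)


class ToyPredictor(nn.Module):
    def __init__(self, token_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(token_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, token_dim)
        )

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.net(tokens)


class ConfigBuiltModel(nn.Module):
    """Takes only JSON-serializable args and builds its submodules itself."""

    def __init__(self, in_dim: int = 3, token_dim: int = 16, hidden_dim: int = 32) -> None:
        super().__init__()
        # Stand-in for the constructor-argument capture done by the real base class.
        self.init_args = {"in_dim": in_dim, "token_dim": token_dim, "hidden_dim": hidden_dim}
        self.trunk = ToyTrunk(in_dim, token_dim)
        self.predictor = ToyPredictor(token_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.predictor(self.trunk(x))


model = ConfigBuiltModel()
serialized = json.dumps(model.init_args)  # succeeds: only ints in the args
rebuilt = ConfigBuiltModel(**json.loads(serialized))
```

Composability can be kept by exposing the component classes separately, as the PR already does, while the top-level model stays reconstructible from its config alone.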

token_dim : int
Feature dimension of context and predicted target tokens.
cond_dim : int
Dimension of the conditioning vector. ``0`` disables

Collaborator

The blocks are actually always constructed with conditioning_dim=self.hidden_dim and run AdaLN on the zero cond_embed — so they apply a learned constant modulation, not nothing. Either gate conditioning_dim on cond_dim>0 to truly disable, or fix the docstring.
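A sketch of the gating option. `AdaLNBlock` and `build_block` are hypothetical toys, not the PR's blocks; the sketch only shows the modulation layer being skipped entirely when `cond_dim` is `0`, rather than being fed a zero embedding.

```python
from typing import Optional

import torch
from torch import nn


class AdaLNBlock(nn.Module):
    """Toy AdaLN residual block (hypothetical, not the PR's block)."""

    def __init__(self, hidden_dim: int, conditioning_dim: Optional[int] = None) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        # No modulation layer at all when conditioning is disabled.
        self.modulation = (
            nn.Linear(conditioning_dim, 2 * hidden_dim) if conditioning_dim else None
        )
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        h = self.norm(x)
        if self.modulation is not None:
            scale, shift = self.modulation(cond).chunk(2, dim=-1)
            h = h * (1 + scale) + shift
        return x + self.mlp(h)


def build_block(hidden_dim: int, cond_dim: int) -> AdaLNBlock:
    """Gate the conditioning path on cond_dim > 0."""
    return AdaLNBlock(hidden_dim, hidden_dim if cond_dim > 0 else None)
```

With a zero `cond_embed`, the modulation layer's output is just its bias, which is the learned constant modulation mentioned above; gating removes those parameters instead.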

context_pos=context_pos,
context_feat=context_feat,
)
cond_global = gen_params.unsqueeze(0)

@mnabian, Jun 8, 2026

Collaborator

This reimplements trunk._build_cond_global_single. Call the trunk helper instead.

context_feat: Float[torch.Tensor, "N D_feat"],
gen_params: Float[torch.Tensor, "G"],
query_pos: Float[torch.Tensor, "Nq D_pos"],
query_sdf: Float[torch.Tensor, "Nq 1"] | None = None,

Collaborator

query_sdf=None is not validated against decoder.use_sdf. Add an early guard so that a missing SDF fails with a clear error instead of surfacing later inside the decoder.
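One way the guard could look, assuming the intended behavior is to reject a missing SDF when the decoder was built with `use_sdf=True`. The helper name and the message are hypothetical; in the PR this check would sit at the top of the method whose signature is quoted above.

```python
from typing import Optional

import torch


def check_query_sdf(use_sdf: bool, query_sdf: Optional[torch.Tensor]) -> None:
    """Fail fast when the decoder expects an SDF channel but none was passed."""
    if use_sdf and query_sdf is None:
        raise ValueError(
            "decoder.use_sdf is True but query_sdf is None; pass query_sdf of shape (Nq, 1)"
        )


# Usage: nothing happens for valid combinations.
check_query_sdf(False, None)
check_query_sdf(True, torch.zeros(4, 1))
```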

4 participants