Add AeroJEPA model + SuperWing tutorial recipe (experimental)#1690
Add AeroJEPA model + SuperWing tutorial recipe (experimental)#1690fgiral000 wants to merge 52 commits into
Conversation
Create an empty subpackage for the AeroJEPA reusable building blocks (attention blocks, geometry tokenizer, context/target encoders, decoder, predictor) that land in subsequent commits. Establishes the SPDX license header, module docstring, and ``__all__`` placeholder so that follow-up commits only need to register new public symbols. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
TokenSet bundles token features with their geometric coordinates and optional mask, global token, and auxiliary side data; EncoderOutput is a thin wrapper used by context and target encoders to surface a global summary alongside the per-token output. Includes raw-string docstrings with Parameters/Examples sections (three executable doctests), modern union syntax, and the SPDX header. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A deterministic log-frequency sinusoidal positional encoding used to lift continuous query coordinates into a high-dimensional feature space before the implicit decoder consumes them. Distinct from physicsnemo.nn.FourierEmbedding (random Gaussian frequencies on scalar timesteps); this variant uses fixed log-powers of pi on multi-dim coordinates with the standard sin/cos band layout. Includes an out_dim property, jaxtyping on forward, and an executable doctest. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A private _gpu_knn module bundling chunked torch.cdist plus topk for building homogeneous (gpu_knn_self) and bipartite (gpu_knn_bipartite) k-NN graphs and inverse-distance interpolation (gpu_knn_interpolate). Pure PyTorch, no warp or custom CUDA — works on CPU too, just slower. The leading underscore on the filename makes the module package- private; callers live inside the aerojepa subpackage only. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A token_utils module with the helpers used by the AeroJEPA tokenizer, encoders and attention blocks: gather_rows, counts_to_mask, flatten_padded_batch / unflatten_to_padded, compute_batch_offset_step, flatten_batched_coords, chunked_knn_indices (CPU/GPU dispatcher with the AE_KNN_BACKEND env override), masked_mean, trim_batched_tokens, and pad_token_sets. Behavior preserved; types modernized and the TokenSet import is package-relative. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A reusable trio of attention building blocks: ResidualMLP (pre-norm residual MLP with optional AdaLN / AdaLN-Zero conditioning), LocalPointTransformerBlock (local self-attention over a per-point k-NN graph with learned relative-position bias), and LocalTokenCrossAttentionBlock (cross-attention from queries to a per-query k-NN of context tokens, with a 5-way conditioning MLP that modulates query and key/value sides independently). Behavior preserved: zero-init conditioning MLPs give an identity transform at construction time, and the N<=1 / empty-input fallbacks short-circuit the same way they do upstream. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A tokenizer module that reduces a raw point set to a bounded token budget before attention. Seven strategies: identity, random, FPS, random/FPS/voxel-FPS cluster pooling, and prototype-anchored clustering. The cluster strategies return the kNN indices that link each token center back to the source points, allowing a downstream encoder to replace the default feature mean with a learned pooling (e.g. the message-passing PointClusterGraphPool that lands with the encoders in PR NVIDIA#3). Behavior preserved including the non-persistent prototype_coords buffer and the per-sample loop used by the prototype strategy. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Build the fixed k-means anchor set used by the data_prototype_cluster tokenizer strategy. The build pass walks a training dataset, tokenizes each sample to obtain candidate token coordinates, optionally subsamples, runs chunked Lloyd-iteration k-means with empty-cluster FPS refill, sorts the centers lexicographically, and serializes them with a JSON metadata blob. Two load functions (target / context - identical file layout) and two ensure_* helpers (load-if-exists else build) round out the public surface. Behavior preserved; the seed argument governs k-means initialization and candidate subsampling but not the tokenizer pass, which intentionally uses random sampling. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Export the 23 public symbols from the six source modules at the package level: TokenSet/EncoderOutput dataclasses, FourierPositionalEncoding, ResidualMLP and the two local attention blocks, PointCloudTokenizer, the ten batching/mask/kNN helpers, and the six prototype anchor build/load functions. Module docstring tightened to reflect the actual contents (encoders / decoder / predictor land in physicsnemo.experimental.models.aerojepa in a later PR). The package-private _gpu_knn helpers remain accessible via their submodule path but are intentionally not re-exported. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Re-export the five AeroJEPA nn.Module layer classes (FourierPositionalEncoding, ResidualMLP, LocalPointTransformerBlock, LocalTokenCrossAttentionBlock, PointCloudTokenizer) at the experimental.nn parent namespace, alongside the existing FLARE and DiffusionUNet3D family. Data types (TokenSet, EncoderOutput), batching/mask helpers, and prototype-anchor builders stay scoped to the aerojepa subpackage to keep the parent namespace focused on actual layers. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Tests for TokenSet and EncoderOutput covering construction (both batched and unbatched), the is_batched / token_dim properties, the with_updates immutability + selective-replacement contract, and the independence of the default aux dict across instances. Uses the shared device fixture so the CUDA path runs when available. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…locks Six new test files covering positional_encoding, attention_blocks, point_tokenizer, token_utils, _gpu_knn, and prototype_anchors. 85 new tests covering: constructor validation paths, forward output shapes, edge cases (N<=1 LPT fallback, empty cross-attention, empty/single- point kNN, missing voxel_size, non-persistent prototype_coords buffer), identity-at-init of AdaLN-Zero conditioning MLPs, the AE_KNN_BACKEND env override, and build/load round-trips with a tiny fake dataset. All tests use the shared device fixture so CUDA runs when available; CPU run is 18 s wall. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Documents the new physicsnemo.experimental.nn.aerojepa subpackage contributed across the preceding 12 commits on this branch: token dataclasses, Fourier positional encoding, ResidualMLP, the two local attention blocks, PointCloudTokenizer, token batching/mask/kNN helpers, and prototype anchor utilities, plus the parent-namespace re-export of the five layer classes. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Create an empty subpackage for JEPA-style losses and regularizers (SIGReg, TokenLatentSIGReg, the padding-aware masking helpers, and the reconstruction loss family) that land in subsequent commits. Establishes the SPDX license header, module docstring, and ``__all__`` placeholder so that follow-up commits only need to register new public symbols. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Two utilities used by SIGReg / TokenLatentSIGReg to flatten padded batched token features and reshape them into the (T, B, D) layout SIGReg expects. flatten_valid_token_features is a passthrough on rank-2 inputs and uses boolean masking on rank-3 inputs; reshape_token_features_for_sigreg adds the leading T=1 axis and emits a zero-element (1, 0, D) placeholder when the mask removes every row. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
SIGReg pushes a learned latent toward N(0, I) by comparing the empirical Fourier characteristic function of random projections against the reference Gaussian one on a uniform knot grid (the LeWorldModel construction). Three non-learnable buffers cache the knot positions, the reference window, and the trapezoidal + window-weighted integration weights. TokenLatentSIGReg is a thin wrapper that accepts (B, N, D) or (N, D) features plus an optional mask, drops padded rows via the masking helpers, and short-circuits to a zero scalar when the mask removes every row. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Four loss families exposed as functional and nn.Module variants: mse_loss / MSELoss (channel-weighted MSE with mask + point weights), relative_l2_loss / RelativeL2Loss (per-channel relative L2 averaged over channels), relative_mse_loss / RelativeMSELoss (relative MSE with selectable pointwise vs channel_max normalization), and the relative_l2_mse_loss / RelativeL2MSELoss hybrid that linearly combines the L2 and MSE terms. Channel weights are stored as a persistent float32 buffer on the Module variants when supplied, and as a non-persistent None buffer otherwise. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Move the JEPA losses subpackage from physicsnemo.experimental.metrics.jepa to .metrics.aerojepa so it mirrors the nn.aerojepa naming. Populate the package __init__ with the 12 public re-exports from masking, sigreg, and reconstruction (flatten/reshape token helpers, SIGReg/TokenLatentSIGReg, and the four reconstruction loss families both functional and as nn.Module). Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Three test files mirroring the source modules: test_masking, test_sigreg, test_reconstruction. 37 tests covering constructor validation, forward shape, edge cases (rank-1, empty batch, all-False mask), the SIGReg buffer layout, state_dict persistence of channel_weights on the reconstruction Module variants, both modes of relative_mse_loss, and the hybrid degenerating to either of its two sub-losses when the corresponding weight is zero. CPU run is 4 s wall; device fixture picks up CUDA automatically. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Documents the new physicsnemo.experimental.metrics.aerojepa subpackage contributed across the preceding 6 commits on this branch: SIGReg / TokenLatentSIGReg regularizers, masking helpers, and the four reconstruction loss families. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Create an empty subpackage for the top-level AeroJEPA model and its model-specific subcomponents (context/target/point encoders, decoder, predictor, trunk) that land in subsequent commits. Module docstring points readers to experimental.nn.aerojepa for the reusable building blocks the model is composed from. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Abstract base classes BaseContextEncoder and BaseTargetEncoder (plus the encoders subpackage init) define the contract concrete encoders must satisfy: a required forward returning an EncoderOutput and an optional forward_batched gated by a supports_batched_forward class flag. The context encoder's forward args are named context_pos / context_feat (these bundle the boundary and any volumetric samples in whole-domain models; the SDF channel in context_feat distinguishes the two halves at inference). The target encoder keeps the surface / volume split because training-time subsamplings for the two are intentionally decoupled. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Context tokens are produced from geometry alone - operating conditions enter the model downstream at the predictor head, not at the context branch. Remove gen_params from BaseContextEncoder forward and forward_batched signatures. Class docstring spells out the intent. BaseTargetEncoder is untouched. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
PointTransformer (point.py) is a point-cloud encoder building block: tokenizes the input via PointCloudTokenizer, embeds tokens with a Fourier positional encoding plus per-feature linear projection, optionally adds a conditioning vector, runs a stack of LocalPointTransformerBlock layers with configurable dilation, and emits an EncoderOutput. Two entry points - encode_single for unbatched inputs and forward_batched for padded batches with per-batch coordinate offsetting so the inner k-NN does not mix tokens across batch items. The same file carries the build_geometry_features helper (assembles per-point features from positions and optional SDF / normals / n-dot channels) and the message-passing PointClusterGraphPool used when tokenizer_cluster_pooling='graph'. ContextTransformer (context.py) is the concrete BaseContextEncoder. Takes context_pos and context_feat - no gen_params, since operating conditions enter the model downstream at the predictor head. Internally wraps PointTransformer with conditioning disabled. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Mirror the context-side change: target encoders take their inputs straight, with no gen_params threaded through. Operating conditions enter the model only at the predictor head. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
JEPA target encoders are self-attention only. Remove context_tokens from forward and forward_batched. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Concrete BaseTargetEncoder that wraps an inner PointTransformer. Forward concatenates surface and volume into one bundled point set; forward_batched weaves variable-length surface and volume halves per batch via counts_to_mask. Self-attention only. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Implicit field decoder driven by cross-attention to target tokens. Per-query embedding is a Fourier positional encoding plus optional SDF channel and optional cond vector; cross-attention to the target token set refines it, a trunk MLP and head produce the output. Several optional behaviors wire in: wall-velocity gate, pressure split head (MLP or SIREN), final SIREN refinement, extra SDF features. Both forward (single) and forward_batched (padded) process queries in chunks of query_chunk_size and return (pred, query_embeddings). SineLayer and SirenHead are the small SIREN building blocks used by the optional heads. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
The JEPA predictor head. Maps a target-token coordinate set to predicted target-token features, given context tokens and a conditioning vector. Operating conditions enter the model here (via the cond argument), projected once and threaded into every self- and cross-attention block. Accepts both unbatched (rank-2 context features) and padded batched (rank-3) inputs; target_positions and cond are broadcast across the batch when their leading dim is 1. The forward signature uses target_positions as the parameter name (not target_coords) for consistency with the rest of the model API. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Owns context encoder, target encoder, and decoder, and wires them together. encode_context runs both encoders and emits a dict with context tokens, target tokens, and the decoder-side cond_global. decode_queries decodes a target token set at supplied query positions, optionally producing a per-query mask logit when the mask head is enabled. forward_single and forward_batch are convenience wrappers chaining the two phases for unbatched and padded batched inputs respectively. Public args use context_pos / context_feat naming; gen_params is used to build cond_global but is not threaded into the encoders. Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
| self.update_mlps = nn.ModuleList() | ||
| for _ in range(self.num_layers): | ||
| self.msg_mlps.append( | ||
| nn.Sequential( |
There was a problem hiding this comment.
Please use physicsnemo's Mlp.
| ) | ||
| ) | ||
| self.gate_mlps.append( | ||
| nn.Sequential( |
There was a problem hiding this comment.
Please use physicsnemo's Mlp.
| ) | ||
| ) | ||
| self.update_mlps.append( | ||
| nn.Sequential( |
There was a problem hiding this comment.
Please use physicsnemo's Mlp.
| nn.GELU(), | ||
| nn.Linear(self.dim, self.dim), | ||
| ) | ||
| self.attn_proj = nn.Sequential( |
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(dim, dim, num_heads, act_layer=nn.GELU)
| self.q_proj = nn.Linear(self.dim, self.dim) | ||
| self.k_proj = nn.Linear(self.dim, self.dim) | ||
| self.v_proj = nn.Linear(self.dim, self.dim) | ||
| self.pos_proj = nn.Sequential( |
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(3, dim, dim, act_layer=nn.GELU)
| nn.GELU(), | ||
| nn.Linear(self.dim, self.dim), | ||
| ) | ||
| self.attn_proj = nn.Sequential( |
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(dim, dim, num_heads, act_layer=nn.GELU)
| self.q_proj = nn.Linear(self.dim, self.dim) | ||
| self.k_proj = nn.Linear(self.dim, self.dim) | ||
| self.v_proj = nn.Linear(self.dim, self.dim) | ||
| self.pos_proj = nn.Sequential( |
There was a problem hiding this comment.
Please use physicsnemo's Mlp: Mlp(3, dim, dim, act_layer=nn.GELU)
| ) | ||
|
|
||
|
|
||
| class SineLayer(nn.Module): |
There was a problem hiding this comment.
Drop SineLayer, use physicsnemo.nn.SirenLayer; keep SirenHead as a thin local composition but build it from the library's SirenLayer instead of the local one.
| ] | ||
| ) | ||
| self.trunk = nn.Sequential( | ||
| nn.LayerNorm(int(token_dim)), |
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
LayerNorm(int(token_dim))
| ): | ||
| super().__init__() | ||
| hidden = max(1, int(mlp_ratio)) * int(dim) | ||
| self.norm = nn.LayerNorm(int(dim)) |
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm = LayerNorm(int(dim))
| self.neighbor_k = int(neighbor_k) | ||
| self.dilation = int(max(1, dilation)) | ||
| self.knn_chunk_size = int(knn_chunk_size) | ||
| self.norm = nn.LayerNorm(self.dim) |
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm = LayerNorm(self.dim)
| self.head_dim = self.dim // self.num_heads | ||
| self.neighbor_k = int(neighbor_k) | ||
| self.knn_chunk_size = int(knn_chunk_size) | ||
| self.norm_q = nn.LayerNorm(self.dim) |
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm_q = LayerNorm(self.dim)
| self.knn_chunk_size = int(knn_chunk_size) | ||
| self.norm_q = nn.LayerNorm(self.dim) | ||
| self.adaln_zero = bool(adaln_zero) | ||
| self.norm_kv = nn.LayerNorm(self.dim) |
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.norm_kv = LayerNorm(self.dim)
| for _ in range(self.depth) | ||
| ] | ||
| ) | ||
| self.out_norm = nn.LayerNorm(self.hidden_dim) |
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.out_norm = LayerNorm(self.hidden_dim)
| nn.Linear(self.hidden_dim, self.point_feature_dim), | ||
| ) | ||
| ) | ||
| self.out_norm = nn.LayerNorm(self.point_feature_dim) |
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.out_norm = LayerNorm(self.point_feature_dim)
| for i in range(int(num_layers)) | ||
| ] | ||
| ) | ||
| self.out_norm = nn.LayerNorm(int(token_dim)) |
There was a problem hiding this comment.
Use PhysicsNeMo's TE-aware LayerNorm instead of torch.nn.LayerNorm to pick up Transformer Engine's high-performance LayerNorm (notably the faster backward pass).
from physicsnemo.nn.module.layer_norm import LayerNorm
...
self.out_norm = LayerNorm(int(token_dim))
|
|
||
| cluster_size = ( | ||
| max_tokens if self.cluster_size is None else int(self.cluster_size) | ||
| ) |
There was a problem hiding this comment.
cluster_size=None for the neighborhood-cluster strategies silently defaults to max_tokens (e.g. 512 neighbors/token) — almost certainly unintended, and inconsistent with the data_prototype_cluster path which defaults to 16. Recommend failing fast: require cluster_size for random_cluster/fps_cluster/voxel_fps_cluster, and drop the max_tokens fallback.
suggestion: In __init__, after the strategy/prototype validation:
if (
self.strategy in {"random_cluster", "fps_cluster", "voxel_fps_cluster"}
and self.cluster_size is None
):
raise ValueError(
f"cluster_size must be provided for tokenizer_strategy='{self.strategy}'."
)
Then in tokenize_with_clusters, replace the fallback (lines 423–425):
cluster_size = (
max_tokens if self.cluster_size is None else int(self.cluster_size)
)
with:
cluster_size = int(self.cluster_size)
| if self.prototype_knn_k is not None | ||
| else self.cluster_size | ||
| ) | ||
| k = max(1, min(int(k if k is not None else 16), n_points)) |
There was a problem hiding this comment.
data_prototype_cluster with zero-point inputs calls KNN on empty tensor before early-return guard.
| context_feat=context_feat, | ||
| ) | ||
| cond_global = gen_params.unsqueeze(0) | ||
| if self.trunk.include_geometry_global_in_decoder_cond: |
There was a problem hiding this comment.
encode_geometry has a non-trivial branch when include_geometry_global_in_decoder_cond=True (lines 245-252): it pulls the context global token and concatenates it onto cond_global, widening the decoder conditioning. Every test builds the trunk with this flag False, so the concat branch and the resulting wider cond_global are never exercised. Add a model built with the flag True and assert cond_global length grows by the context-global token dim.
Update: Based on the newer previous comment, you can ally this comment directly to trunk._build_cond_global_single.
| query_sdf: Float[torch.Tensor, "Nq 1"] | None = None, | ||
| conditions: Float[torch.Tensor, "B Cond"] | None = None, | ||
| ) -> Float[torch.Tensor, "Nq C"]: | ||
| if not torch.compiler.is_compiling(): |
There was a problem hiding this comment.
AeroJEPA.forward raises ValueError when context_pos/context_feat are not rank 2 (line 474) or when their point counts disagree (line 479), but no test triggers either guard. Add two tests asserting pytest.raises(ValueError, match="rank 2") for a rank-3 context_pos and match="agree on the point count" for mismatched N between pos and feat.
| query_pos=query_pos, query_sdf=query_sdf, cond=cond | ||
| ) | ||
| ) | ||
| for block in self.cross_blocks: |
There was a problem hiding this comment.
Hot path: QueryTokenDecoder._decode_chunk, run for every query chunk over the full surface/volume field (millions of query points at inference via decode_field_chunked). Each LocalTokenCrossAttentionBlock.forward independently calls chunked_knn_indices(query_coords=query_pos, key_coords=token_coords, ...) (decoder.py:458, attention_blocks.py), but within a chunk query_pos and token_coords are identical across all cross_attention_layers blocks and the cross-blocks use no dilation, so the same k-NN graph is recomputed cross_attention_layers times per chunk. Compute the neighbor index once per chunk and pass it into the blocks (e.g. an optional precomputed-idx argument), reusing it across layers. Same redundancy pattern applies to PointTransformer.encode_single (point.py:606) when the dilation schedule is constant.
| err = err * cw | ||
| weight_sum = float(channel_weights.to(dtype=err.dtype).sum().item()) | ||
|
|
||
| point_weight_t = None |
There was a problem hiding this comment.
The point_weights argument is documented and implemented in all four losses (mse_loss 106-111/122-124, relative_l2_loss 230-235, relative_mse_loss 414-432, plus the hybrid), but no test ever passes point_weights. None of the per-point weighting / weighted-denominator branches are exercised. Add at least one test asserting that uniform point_weights=ones matches the unweighted result and that a zeroed weight on a row drops that row's contribution.
| point_weight_t = point_weight_t.unsqueeze(-1) | ||
| err = err * point_weight_t | ||
|
|
||
| if mask is not None: |
There was a problem hiding this comment.
test_mse_mask_drops_invalid_rows (test/experimental/models/aerojepa/losses/test_reconstruction.py:77-84) only ever passes an all-True mask (torch.ones(8, dtype=torch.bool)) and asserts the result equals the no-mask path. With an all-True mask, err * mask_t is a no-op and denom_weights.sum() equals the full row count, so the masked branch (reconstruction.py:113-120, including the denom_weights.sum().clamp_min(1.0) denominator) is only covered vacuously. The test's own docstring claims "masked-out rows contribute neither to the numerator nor the denominator," yet no row is ever masked out — a regression that stopped excluding masked rows would still pass. Add a test with a partial mask: set some rows to mask=False, fill those target rows with large values, and assert the loss equals mse_loss computed on only the kept rows.
| device: torch.device, | ||
| ) -> None: | ||
| """Load model weights, applying EMA shadow when present and requested.""" | ||
| payload = torch.load(ckpt_path, map_location=device, weights_only=False) |
There was a problem hiding this comment.
payload = torch.load(ckpt_path, map_location=device, weights_only=False) deserializes the checkpoint with full pickle support. Loading a .pt from an untrusted source would execute arbitrary code via a malicious pickle __reduce__. Impact is limited here: the loaded payload is consumed only as tensor state_dicts (payload["model"], payload["ema_shadow"]) plus a scalar ema_decay, so weights_only=True would work and is the safer default.
| context = {"target_tokens": target_tokens, "cond_global": cond_global} | ||
| preds = [] | ||
| n = int(query_pos.shape[0]) | ||
| with torch.autocast( |
There was a problem hiding this comment.
decode_field_chunked is only tested with precision="fp32" (test_aerojepa.py:179), so the autocast-enabled path (lines 393-401, where enabled=True for fp16/bf16) is never run. On a CUDA device this is the path that actually matters for the memory-bounded use case. Add a parametrized test over precision in {"fp16","bf16"} asserting the returned tensor is finite, CPU, and shape (Nq, C).
| """ | ||
| target_encoder = self.trunk.target_encoder | ||
| base_encoder = getattr(target_encoder, "encoder", None) | ||
| if base_encoder is None or not hasattr(base_encoder, "_tokenize_single"): |
There was a problem hiding this comment.
build_target_token_coords raises ValueError when the target encoder lacks the _tokenize_single path (lines 449-452), but only the happy path is tested (test_aerojepa.py:137). The defensive branch that protects callers swapping in a non-transformer target encoder is untested. A small test passing a stub trunk whose target_encoder has no encoder._tokenize_single and asserting the ValueError would cover it.
| return selected | ||
|
|
||
|
|
||
| def _assign_points( |
There was a problem hiding this comment.
_assign_points is a k=1 nearest-center query — replace the manual cdist+min+chunk loop with nn.functional.knn(centers, points, k=1), which chunks internally and uses the cuML/scipy backends. The returned distance is unused by _run_chunked_kmeans (assign, _ = …), so the dist and chunk_size args can be dropped too.
from physicsnemo.nn.functional import knn
def _assign_points(points: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
"""Assign each point to its nearest center (k=1 kNN)."""
idx, _ = knn(centers, points, k=1) # neighbors of `points` among `centers`
return idx[:, 0]
and the call site in _run_chunked_kmeans:
assign = _assign_points(points, centers)
(chunk_size is no longer threaded into _assign_points.)
|
|
||
| - Adds radiation transport example (`examples/nuclear_engineering/radiation_transport`) | ||
| - Adds agent skills structure, and initial skill for 'discoverability'. | ||
| - Adds top-level AeroJEPA model under |
There was a problem hiding this comment.
The AeroJEPA CHANGELOG is far longer and more granular than any other entry. Consider condensing to a single concise bullet (optionally split model vs. recipe) matching the house style.
| token_features=token_features, | ||
| gen_params=gen_params, | ||
| ) | ||
| for block in self.blocks: |
There was a problem hiding this comment.
The kNN distance search is recomputed in every block although coordinates are static within each stack. Compute the sorted neighbor list once (at k = neighbor_k × max(dilation)), then have each block apply its own dilation stride - a pure indexing op. This removes N−1 redundant distance/sort passes per stack without changing behavior.
| gen_embed.unsqueeze(1).expand(-1, int(token_positions.shape[1]), -1), | ||
| token_mask, | ||
| ) | ||
| for block in self.blocks: |
There was a problem hiding this comment.
same static-coords kNN recomputation here
| query_pos=query_pos, query_sdf=query_sdf, cond=cond | ||
| ) | ||
| ) | ||
| for block in self.cross_blocks: |
There was a problem hiding this comment.
same static-coords kNN recomputation here
| cond=cond_chunk, | ||
| ) | ||
| ) | ||
| for block in self.cross_blocks: |
There was a problem hiding this comment.
same static-coords kNN recomputation here
| context_mask, | ||
| ) | ||
|
|
||
| for self_block, cross_block in zip( |
There was a problem hiding this comment.
same static-coords kNN recomputation here
| torch.Size([30, 4]) | ||
| """ | ||
|
|
||
| def __init__( |
There was a problem hiding this comment.
AeroJEPA subclasses physicsnemo.core.Module, so its __init__ args must be JSON-serializable. The only exception is args that are themselves physicsnemo.Module instances. But trunk and predictor here are plain torch.nn.Module:
model = AeroJEPA(trunk=trunk, predictor=predictor)
model.save("m.mdlus") # raises during json.dumpsSuggested fix: You can take JSON-serializable config in __init__ and build trunk/predictor internally — the FNO/DoMINO pattern. Alternatively, convert the full submodule tree to physicsnemo.Module.
| token_dim : int | ||
| Feature dimension of context and predicted target tokens. | ||
| cond_dim : int | ||
| Dimension of the conditioning vector. ``0`` disables |
There was a problem hiding this comment.
The blocks are actually always constructed with conditioning_dim=self.hidden_dim and run AdaLN on the zero cond_embed — so they apply a learned constant modulation, not nothing. Either gate conditioning_dim on cond_dim>0 to truly disable, or fix the docstring.
| context_pos=context_pos, | ||
| context_feat=context_feat, | ||
| ) | ||
| cond_global = gen_params.unsqueeze(0) |
There was a problem hiding this comment.
This is a reimplementation of the trunk._build_cond_global_single. Call the trunk helper instead.
| context_feat: Float[torch.Tensor, "N D_feat"], | ||
| gen_params: Float[torch.Tensor, "G"], | ||
| query_pos: Float[torch.Tensor, "Nq D_pos"], | ||
| query_sdf: Float[torch.Tensor, "Nq 1"] | None = None, |
There was a problem hiding this comment.
query_sdf=None not validated against decoder.use_sdf. Add an early guard.
PhysicsNeMo Pull Request
Description
Adds the AeroJEPA model and a SuperWing tutorial recipe under
physicsnemo.experimentalandexamples/cfd/external_aerodynamics/.AeroJEPA is a Joint-Embedding Predictive Architecture for 3D
aerodynamic surrogate modeling: instead of mapping geometry directly to
a flow field, it predicts a latent representation of the flow from a
latent representation of the geometry and operating conditions, and
reconstructs the field through a continuous implicit decoder when
needed (Giral et al., arXiv:2605.05586).
What this PR delivers:
physicsnemo.experimental.models.aerojepa.AeroJEPAcomposes a context encoder, a target encoder, a query-tokenfield decoder (collectively
AeroJEPATrunk), and a JEPA predictorhead (
PrototypeTokenJEPAHead) into a singlephysicsnemo.core.module.Module. The training path takes contextpositions/features, independent target encoder surface/volume inputs,
and operating conditions; the predictor predicts target tokens, and
the decoder evaluates the field at user-supplied query points.
predictis a no-grad inference wrapper;decode_field_chunkedsupports memory-bounded evaluation over very large query sets.
Concrete encoders (
ContextTransformer,TargetTransformer,PointTransformer), theQueryTokenDecoder, and the encoder ABCsare all exposed as composable components.
physicsnemo.experimental.models.aerojepa.layers.TokenSetandEncoderOutputtoken dataclasses, a deterministicFourierPositionalEncoding,ResidualMLP, theLocalPointTransformerBlock/LocalTokenCrossAttentionBlockattention blocks (with optional AdaLN / AdaLN-Zero conditioning), the
PointCloudTokenizer(seven center-selection strategies with k-NNcluster pooling), token batching / mask / k-NN helpers, and prototype
anchor build / load utilities.
TokenSetandEncoderOutputarere-exported from the model package for convenience.
physicsnemo.experimental.models.aerojepa.losses.SIGRegandTokenLatentSIGReg(a sketch isotropic-Gaussianregularizer for latent-token distributions, with a padding-aware
wrapper), the
flatten_valid_token_features/reshape_token_features_for_sigregmasking helpers, and thereconstruction loss family (
MSELoss/RelativeL2Loss/RelativeMSELoss/RelativeL2MSELoss, each with functional andnn.Moduleforms, optional per-channel weights stored as apersistent buffer, optional per-point weights, and an optional
validity mask).
examples/cfd/external_aerodynamics/aerojepa. End-to-end Hydra-drivenworkflow on the public SuperWing dataset (Yang et al.,
arXiv:2512.14397): dataset download via the Hugging Face Hub
(
yunplus/SuperWing), automatic split-by-geometry manifest andper-channel normalization stats, JEPA training (reconstruction +
latent + SIGReg with linear warmups; AdamW +
warmup-cosine; optional EMA), checkpointed inference with chunked
decoding, three-panel
GT | Pred | |Error|field plots for the threesurface channels (
Cp,Cf_tau,Cf_z), per-channel relative-L2 /RMSE / MAE metrics on the test split, and a pressure-only CL/CD
post-processor that integrates the surface field and emits a per-case
CSV plus a parity scatter.
Checklist
Tests
test/experimental/models/aerojepa/(constructor + attribute checks, non-regression shape checks on the
encoders, decoder, predictor, trunk, top-level model, layers, and
losses).
pytest test/experimental/models/aerojepa/ -qpasseslocally on CPU (~20 s).
train.py -> inference.py -> superwing_metrics -> superwing_forces.Training losses decrease monotonically; inference produces field
plots, per-case field-error metrics, and a force-coefficient parity
scatter.
Dependencies
No new core dependencies. The example recipe adds optional
example-side dependencies in
examples/cfd/external_aerodynamics/aerojepa/requirements.txt(Hugging Face Hub for the dataset download, plotting and
post-processing utilities). Pre-commit hooks,
ruff,interrogate,markdownlint, and the SPDX license check pass on every file in thePR.