Skip to content

Commit 651c626

Browse files
ktangsaliKaustubh Tangsaliroot
authored
Add latent novelty query strategy using OODGuard (#1678)
* add latent novelty query strategy using OODGuard * linting, changelong addition and readme * improve viz * Address review comments --------- Co-authored-by: Kaustubh Tangsali <ktangsali@oci-nrt-cs-001-vscode-01.cm.cluster> Co-authored-by: root <root@pool0-03775.cm.cluster>
1 parent b3fe4f4 commit 651c626

12 files changed

Lines changed: 596 additions & 43 deletions

File tree

CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3232
protocols and `physicsnemo.experimental.uq.VariationalGPHead`, with a
3333
layered structure (generic AL driver / GP-UQ recipe / aero adapter)
3434
designed for reuse on other UQ-based regression problems.
35+
- Adds `LatentNoveltyQueryStrategy` to the active-learning aero recipe,
36+
a third acquisition strategy that ranks unlabeled samples by their
37+
average kNN cosine distance in the encoder's learned geometry latent
38+
— reusing the same `OODGuard`
39+
(`physicsnemo.experimental.guardrails.embedded`) that flags
40+
out-of-distribution inputs at inference time. The guard is calibrated
41+
on the currently labeled set each round; round 1 falls back to
42+
class-balanced random because the calibration buffer is empty. New
43+
public `OODGuard.score_geometry()` method exposes the raw per-sample
44+
geometry-latent kNN distance as a continuous score for downstream
45+
consumers (e.g. AL acquisition) without the boolean thresholding /
46+
warning emission of `OODGuard.check()`.
3547

3648
### Changed
3749

71.9 KB
Loading
31.1 KB
Loading

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ The several examples inside PhysicsNeMo can be classified based on their domains
110110
### Active Learning
111111

112112
1. [Classify the famous two-moons data distribution using Active learning](./active_learning/moons/)
113+
2. [Active Learning for Surface-CFD Aerodynamic Surrogates](./cfd/external_aerodynamics/active_learning_aero/)
113114

114115
## Additional examples
115116

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
<!-- markdownlint-disable -->
2+
# Active Learning for Surface-CFD Aerodynamic Surrogates
3+
4+
This example lives under the CFD examples tree, alongside the surface-CFD
5+
backbone it builds on:
6+
[**`physicsnemo/examples/cfd/external_aerodynamics/active_learning_aero`**](../../cfd/external_aerodynamics/active_learning_aero/README.md).
7+
8+
It demonstrates end-to-end active learning on the
9+
[ShiftSUV](https://huggingface.co/datasets/luminary-shift/SUV) surface-CFD
10+
dataset using an uncertainty-aware GeoTransolver + Variational GP head, with
11+
three plug-and-play acquisition strategies (UQ-driven, class-balanced random,
12+
and latent-novelty). The AL loop itself is problem-agnostic — only the
13+
physics/metrology hooks are CFD-specific — so the same recipe can drive any
14+
uncertainty-quantified regression task.
15+
16+
See the [full README](../../cfd/external_aerodynamics/active_learning_aero/README.md)
17+
for the recipe overview, configuration, results, and adapting-to-a-new-problem
18+
guide.

examples/cfd/external_aerodynamics/active_learning_aero/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ src/manifests/*.json
1515
# Local backup snapshots (e.g. manifest_class_*.json.bak_*).
1616
*.bak
1717
*.bak_*
18+
slurm_logs/

examples/cfd/external_aerodynamics/active_learning_aero/README.md

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ runs/geotransolver/surface/al_shiftsuv_uq/joint_uq/
222222

223223
## Acquisition strategies
224224

225-
Three strategies ship with this example, all implementing the
225+
Four strategies ship with this example, all implementing the
226226
`physicsnemo.active_learning.QueryStrategy` protocol. Select via
227227
`++acquisition=…`:
228228

@@ -231,6 +231,7 @@ Three strategies ship with this example, all implementing the
231231
| **JointUQQueryStrategy** | `joint_uq` | The recommended UQ-driven default. Per-sample score is `max(|disagreement|, 2 · GP_std)`, where *disagreement* is `|Cd_GP − Cd_field|` — i.e. the gap between the GP-head Cd prediction and the field-integrated Cd recovered from the same forward pass. |
232232
| **RandomQueryStrategy** | `random` | Pure random baseline (uniform over the pool). Use as a UQ-vs-random sanity check. |
233233
| **ClassBalancedRandomQueryStrategy** | `class_balanced_random` | Random *within each class*, with the per-round budget split proportionally. Use as a strong baseline that controls for class imbalance in the pool. |
234+
| **LatentNoveltyQueryStrategy** | `latent_novelty` | Encoder-only novelty signal: each round we calibrate the embedded `OODGuard` (from `physicsnemo.experimental.guardrails.embedded`) on the currently labeled set, then rank unlabeled samples by their average kNN cosine distance in the learned geometry-latent space. Reuses the same guardrail that flags OOD inputs at inference time as the acquisition signal. The first round falls back to class-balanced random because the calibration buffer is empty. |
234235

235236
Adding a new strategy is a matter of subclassing `QueryStrategy` from
236237
`physicsnemo.active_learning.protocols`, implementing
@@ -246,9 +247,10 @@ AL setup involves an oracle that synthesizes labels on demand.
246247

247248
The plot below summarizes a experiment on the ShiftSUV
248249
out-of-distribution dataset (1727 total samples; 181 held out for test,
249-
leaving 1546 in the trainable pool). UQ and class-balanced random both
250-
close the gap between the pretrained DrivAerStar Fastback-only model and a
251-
ShiftSUV full-data ceiling that sees every trainable sample (n = 1546). Pressure
250+
leaving 1546 in the trainable pool). All three acquisition strategies —
251+
joint UQ, class-balanced random, and latent novelty — close the gap
252+
between the pretrained DrivAerStar Fastback-only model and a ShiftSUV
253+
full-data ceiling that sees every trainable sample (n = 1546). Pressure
252254
and wall-shear-stress (WSS) RMS errors are reported in physical units
253255
after un-standardization.
254256

@@ -276,11 +278,11 @@ and wall-shear-stress magnitude. The violins below show how the
276278
**distribution of per-sample correlations** tightens across rounds:
277279
the median moves toward 1.0, the 5th-percentile dashed line catches
278280
up (worst-case samples improve faster than the best-case ones), and
279-
the lower tail of the violin shrinks. Both UQ and class-balanced
280-
random reach median ρ > 0.97 by round 16 (n=160 labels) — well before
281-
the labels-needed thresholds in the table below — meaning the spatial
282-
patterns are already correct long before the absolute RMS hits its
283-
asymptote.
281+
the lower tail of the violin shrinks. All three strategies — UQ,
282+
class-balanced random, and latent novelty — reach median ρ > 0.97 by
283+
round 16 (n=160 labels) — well before the labels-needed thresholds in
284+
the table below — meaning the spatial patterns are already correct
285+
long before the absolute RMS hits its asymptote.
284286

285287
![Per-sample Spearman correlations across AL rounds](../../../../docs/img/al_aero_shiftsuv_correlations.png)
286288

@@ -290,22 +292,26 @@ Numbers from the rightmost panel of the summary plot —
290292
**labels needed to land within X% of the full-data RMS asymptote**
291293
(n_pool = 1546):
292294

293-
| Within X% of ceiling | Joint-UQ labels | Class-bal random labels | Fraction of pool |
294-
|----------------------|-----------------|-------------------------|------------------|
295-
| 100% | 220 | 210 | ~14% |
296-
| 50% | 410 | 390 | ~26% |
297-
| 25% | 650 | 630 | ~41% |
298-
| 10% | 910 | 930 | ~60% |
299-
| 5% | 1040 | 1060 | ~67% |
300-
301-
At the largest budget reached in this run — UQ at n=1120,
302-
BAL at n=1100 — pressure RMS is **15.17 Pa (UQ) / 15.49 Pa (BAL)**
303-
against a full-data ceiling of **14.85 Pa**, i.e. +2.1% / +4.3% above
304-
the asymptote. Read the row as: *"to drive the
305-
surface-field RMS to within 5% of what training on every available
306-
sample would give us, we need to hand-label roughly two-thirds of the
307-
pool"* — and the +25% row as the more frugal *"with ~40% of the labels
308-
we already cut the gap-to-ceiling to a quarter of what it was."*
295+
| Within X% of ceiling | Joint-UQ labels | Class-bal random labels | Latent-novelty labels | Fraction of pool |
296+
|----------------------|-----------------|-------------------------|-----------------------|------------------|
297+
| 100% | 230 | 210 | 210 | ~14% |
298+
| 50% | 430 | 410 | 410 | ~27% |
299+
| 25% | 670 | 670 | 680 | ~43% |
300+
| 10% | 970 | 990 | 980 | ~63% |
301+
| 5% | 1090 | 1120 | 1120 | ~72% |
302+
303+
At the final round of each chain (UQ at n=1150; BAL and LN at n=1140),
304+
pressure RMS is **15.00 Pa (UQ) / 15.30 Pa (BAL) / 15.22 Pa (LN)**
305+
against a full-data ceiling of **14.57 Pa**, i.e. +3.0% / +5.0% / +4.5%
306+
above the asymptote. Read the +5% row as: *"to drive the surface-field
307+
RMS to within 5% of what training on every available sample would give
308+
us, we need to hand-label roughly two-thirds of the pool"* — and the
309+
+25% row as the more frugal *"with ~40% of the labels we already cut
310+
the gap-to-ceiling to a quarter of what it was."* Joint-UQ wins by a
311+
small but consistent margin at every threshold; latent novelty matches
312+
class-balanced random closely without using class labels at
313+
acquisition time, making it a viable drop-in for problems where class
314+
metadata is unavailable.
309315

310316
## Adapting to a new problem
311317

examples/cfd/external_aerodynamics/active_learning_aero/src/conf/al_config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ consistency_detach_transolver: false
5757
al_rounds: 5 # Number of query-train-evaluate rounds
5858
samples_per_round: 50 # Samples selected per round
5959
test_samples_per_class: 100 # Held out for fixed test set (rest goes to pool)
60-
acquisition: "joint_uq" # "joint_uq" or "random"
60+
acquisition: "joint_uq" # "joint_uq" | "random" | "class_balanced_random" | "latent_novelty"
6161
random_seed: 42 # Seed for random baseline (used when acquisition=random)
62+
latent_novelty_knn_k: 10 # k for OODGuard kNN scoring (only used when acquisition=latent_novelty); auto-clamped to labeled-pool size
6263

6364
# --- Fine-tuning schedule ---
6465
fine_tune_epochs: 20 # Epochs per AL round

examples/cfd/external_aerodynamics/active_learning_aero/src/run_al.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
ClassBalancedRandomQueryStrategy,
6969
DummyLabelStrategy,
7070
JointUQQueryStrategy,
71+
LatentNoveltyQueryStrategy,
7172
RandomQueryStrategy,
7273
)
7374
from aero_metrology import FieldMetrologyStrategy
@@ -304,23 +305,33 @@ def main(cfg: DictConfig) -> None:
304305
)
305306

306307
# ---- Strategies ----
307-
if acquisition == "joint_uq":
308-
query_strategy = JointUQQueryStrategy(
309-
max_samples=samples_per_round, precision=precision
310-
)
311-
elif acquisition == "class_balanced_random":
312-
query_strategy = ClassBalancedRandomQueryStrategy(
313-
max_samples=samples_per_round, seed=random_seed
314-
)
315-
elif acquisition == "random":
316-
query_strategy = RandomQueryStrategy(
317-
max_samples=samples_per_round, seed=random_seed
318-
)
319-
else:
320-
raise ValueError(
321-
f"Unknown acquisition strategy: {acquisition!r}. "
322-
f"Expected one of: 'joint_uq', 'random', 'class_balanced_random'."
323-
)
308+
match acquisition:
309+
case "joint_uq":
310+
query_strategy = JointUQQueryStrategy(
311+
max_samples=samples_per_round, precision=precision
312+
)
313+
case "class_balanced_random":
314+
query_strategy = ClassBalancedRandomQueryStrategy(
315+
max_samples=samples_per_round, seed=random_seed
316+
)
317+
case "random":
318+
query_strategy = RandomQueryStrategy(
319+
max_samples=samples_per_round, seed=random_seed
320+
)
321+
case "latent_novelty":
322+
knn_k = int(getattr(cfg, "latent_novelty_knn_k", 10))
323+
query_strategy = LatentNoveltyQueryStrategy(
324+
max_samples=samples_per_round,
325+
precision=precision,
326+
knn_k=knn_k,
327+
cold_start_seed=random_seed,
328+
)
329+
case _:
330+
raise ValueError(
331+
f"Unknown acquisition strategy: {acquisition!r}. "
332+
f"Expected one of: 'joint_uq', 'random', 'class_balanced_random', "
333+
f"'latent_novelty'."
334+
)
324335

325336
metrology = FieldMetrologyStrategy(precision=precision)
326337
label_strategy = DummyLabelStrategy()

0 commit comments

Comments
 (0)