Skip to content

Commit edcde50

Browse files
committed
Contribution standards: ruff clean, CHANGELOG, FGNUNet docstring
- Run ruff format + fix across all fgn/ Python files
- Remove unused imports (Sequence, Callable, ShardTensor, math, torch)
- Replace assert with if/raise (S101), fix import order (I001), simplify loops to list-comprehension/extend (PERF401/PERF102)
- Add noqa: E402 on intentional post-path-insert imports in stage4
- Upgrade FGNUNet docstring to MOD-003 (r-string, NumPy sections, Parameters/Forward/Outputs with LaTeX shapes, Examples)
- Add CHANGELOG.md entry under [2.1.0a0]

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
1 parent d89a70e commit edcde50

15 files changed

Lines changed: 282 additions & 134 deletions

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
140140
- Added support for Batched radius search, which enables Domino
141141
and GeoTransolver with local features and batch size > 1.
142142
- Added the underfill recipe.
143+
- Added Functional Generative Networks (FGN) weather training example
144+
(`examples/weather/fgn`). Implements the latent-conditioned U-Net
145+
stochastic generator from
146+
[arXiv:2506.10772](https://arxiv.org/abs/2506.10772) (WeatherNext 2)
147+
as a PhysicsNeMo ``Module``, trained with fair-CRPS loss on ERA5 via the
148+
earth2studio ARCO data source. Supports autoregressive rollout training
149+
with per-channel normalization, FSDP + ShardTensor domain parallelism,
150+
deep-ensemble inference (paper §2.2.1), and validation diagnostics
151+
(CRPS, RMSE, spread-skill, rank histograms, power spectra).
143152

144153
### Changed
145154

examples/weather/fgn/datasets/arco.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
from datetime import datetime, timedelta
17-
from typing import Any, Sequence
17+
from typing import Any
1818

1919
import numpy as np
2020
import torch
@@ -24,7 +24,19 @@
2424
# Paper Table A.1 atmospheric schema
2525
PAPER_ATMOS_VARS: tuple[str, ...] = ("z", "q", "t", "u", "v", "w")
2626
PAPER_LEVELS: tuple[int, ...] = (
27-
50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000,
27+
50,
28+
100,
29+
150,
30+
200,
31+
250,
32+
300,
33+
400,
34+
500,
35+
600,
36+
700,
37+
850,
38+
925,
39+
1000,
2840
)
2941
PAPER_SURFACE_IN_OUT: tuple[str, ...] = ("t2m", "u10m", "v10m", "msl", "sst")
3042

@@ -96,8 +108,14 @@ class ArcoFGNDataset(FGNDataset):
96108

97109
def __init__(self, params: Any, train: bool) -> None:
98110
def _get(name: str, default: Any) -> Any:
99-
return getattr(params, name, default) if hasattr(params, name) else (
100-
params[name] if isinstance(params, dict) and name in params else default
111+
return (
112+
getattr(params, name, default)
113+
if hasattr(params, name)
114+
else (
115+
params[name]
116+
if isinstance(params, dict) and name in params
117+
else default
118+
)
101119
)
102120

103121
state = _get("state_variables", None)
@@ -232,9 +250,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
232250
# slot from _fetch_tp_accumulation below.
233251
if self._tp_channel_idx is not None:
234252
ci = self._tp_channel_idx
235-
arco_vars = [
236-
v for i, v in enumerate(self._state_variables) if i != ci
237-
]
253+
arco_vars = [v for i, v in enumerate(self._state_variables) if i != ci]
238254
else:
239255
ci = None
240256
arco_vars = list(self._state_variables)
@@ -368,7 +384,10 @@ def _fetch_tp_accumulation(self, frame_times: list[datetime]) -> np.ndarray:
368384
fetch all distinct hourly stamps required by any frame in a single
369385
earth2studio call to minimise GCS round-trips.
370386
"""
371-
assert self.tp_accumulation_hours is not None
387+
if self.tp_accumulation_hours is None:
388+
raise RuntimeError(
389+
"_fetch_tp_accumulation called without tp_accumulation_hours set"
390+
)
372391
N = self.tp_accumulation_hours
373392

374393
# Union of hours we need across all frames, sorted.
@@ -387,9 +406,7 @@ def _fetch_tp_accumulation(self, frame_times: list[datetime]) -> np.ndarray:
387406
if self.stride > 1:
388407
hourly = hourly[:, :: self.stride, :: self.stride]
389408

390-
acc = np.zeros(
391-
(len(frame_times), self.height, self.width), dtype=np.float32
392-
)
409+
acc = np.zeros((len(frame_times), self.height, self.width), dtype=np.float32)
393410
for k, hours_k in enumerate(per_frame_hours):
394411
acc[k] = sum(hourly[hour_to_idx[h]] for h in hours_k)
395412
return acc
@@ -418,9 +435,7 @@ def _load_stats(self, stats_path: str) -> None:
418435
self._mean = mean
419436
self._std = std
420437

421-
def _broadcast_stats_for(
422-
self, x: np.ndarray | torch.Tensor
423-
) -> tuple[Any, Any]:
438+
def _broadcast_stats_for(self, x: np.ndarray | torch.Tensor) -> tuple[Any, Any]:
424439
"""Reshape `(V,)` stats to broadcast along the channel axis of ``x``.
425440
426441
Supports ``x`` of shape ``(V, H, W)``, ``(T, V, H, W)``, or
@@ -435,8 +450,12 @@ def _broadcast_stats_for(
435450
else:
436451
raise ValueError(f"unsupported state tensor ndim {x.ndim}")
437452
if isinstance(x, torch.Tensor):
438-
mean = torch.as_tensor(self._mean, dtype=x.dtype, device=x.device).reshape(shape)
439-
std = torch.as_tensor(self._std, dtype=x.dtype, device=x.device).reshape(shape)
453+
mean = torch.as_tensor(self._mean, dtype=x.dtype, device=x.device).reshape(
454+
shape
455+
)
456+
std = torch.as_tensor(self._std, dtype=x.dtype, device=x.device).reshape(
457+
shape
458+
)
440459
else:
441460
mean = self._mean.reshape(shape)
442461
std = self._std.reshape(shape)

examples/weather/fgn/inference.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,14 @@
1414
from pathlib import Path
1515

1616
import hydra
17-
from omegaconf import DictConfig
1817
import torch
18+
from datasets import dataset_classes
19+
from omegaconf import DictConfig
20+
from utils.trainer import find_latest_model_checkpoint
1921

2022
from physicsnemo.core import Module
2123
from physicsnemo.distributed import DistributedManager
2224

23-
from datasets import dataset_classes
24-
from utils.trainer import find_latest_model_checkpoint
25-
2625

2726
def _resolve_checkpoints(cfg: DictConfig) -> list[str]:
2827
"""Resolve the inference config to an ordered list of checkpoint paths.
@@ -31,9 +30,11 @@ def _resolve_checkpoints(cfg: DictConfig) -> list[str]:
3130
back to ``inference.checkpoint`` (single path or ``"latest"``). Single-
3231
model inference is just the length-1 deep-ensemble case.
3332
"""
34-
checkpoints = cfg.inference.get("checkpoints", None) if hasattr(
35-
cfg.inference, "get"
36-
) else getattr(cfg.inference, "checkpoints", None)
33+
checkpoints = (
34+
cfg.inference.get("checkpoints", None)
35+
if hasattr(cfg.inference, "get")
36+
else getattr(cfg.inference, "checkpoints", None)
37+
)
3738
if checkpoints:
3839
return [str(c) for c in checkpoints]
3940

@@ -139,19 +140,25 @@ def run_inference(cfg: DictConfig) -> dict[str, float | str | int | list[int]]:
139140
# only uses the first step, so collapse to (B, C, H, W).
140141
if target.ndim == 5:
141142
target = target[:, 0]
142-
background = sample["background"].unsqueeze(0).to(device=device, dtype=torch.float32)
143+
background = (
144+
sample["background"].unsqueeze(0).to(device=device, dtype=torch.float32)
145+
)
143146

144147
invariants = dataset.get_invariants()
145148
if invariants is not None:
146-
invariants = torch.from_numpy(invariants).unsqueeze(0).to(
147-
device=device, dtype=torch.float32
149+
invariants = (
150+
torch.from_numpy(invariants)
151+
.unsqueeze(0)
152+
.to(device=device, dtype=torch.float32)
148153
)
149154

150155
all_trajectories: list[torch.Tensor] = []
151156
num_steps = int(cfg.inference.num_steps)
152157
output_only = dataset.output_only_channels()
153158
with torch.no_grad():
154-
for ckpt_path, n_members in zip(checkpoint_paths, members_per_model, strict=True):
159+
for ckpt_path, n_members in zip(
160+
checkpoint_paths, members_per_model, strict=True
161+
):
155162
if n_members <= 0:
156163
continue
157164
model = Module.from_checkpoint(ckpt_path).to(device).eval()

examples/weather/fgn/scripts/compute_arco_stats.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,19 @@
3131
# Paper Table A.1 defaults -- keep in sync with datasets.arco.DEFAULT_STATE.
3232
DEFAULT_ATMOS_VARS = ("z", "q", "t", "u", "v", "w")
3333
DEFAULT_LEVELS = (
34-
50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000,
34+
50,
35+
100,
36+
150,
37+
200,
38+
250,
39+
300,
40+
400,
41+
500,
42+
600,
43+
700,
44+
850,
45+
925,
46+
1000,
3547
)
3648
DEFAULT_SURFACE = ("t2m", "u10m", "v10m", "msl", "sst")
3749
DEFAULT_STATE = tuple(
@@ -50,43 +62,61 @@ def parse_args() -> argparse.Namespace:
5062
"arrays to be laid out. Default: paper Table A.1 (83 channels).",
5163
)
5264
p.add_argument(
53-
"--start", type=str, default="1979-01-01",
65+
"--start",
66+
type=str,
67+
default="1979-01-01",
5468
help="Earliest sample time (ISO date).",
5569
)
5670
p.add_argument(
57-
"--end", type=str, default="2018-01-01",
71+
"--end",
72+
type=str,
73+
default="2018-01-01",
5874
help="Latest sample time (ISO date, exclusive).",
5975
)
6076
p.add_argument(
61-
"--step-hours", type=int, default=6,
77+
"--step-hours",
78+
type=int,
79+
default=6,
6280
help="Sampling cadence; restricts timestamps to a 6h grid by default.",
6381
)
6482
p.add_argument(
65-
"--samples", type=int, default=256,
83+
"--samples",
84+
type=int,
85+
default=256,
6686
help="Number of random timestamps to average over.",
6787
)
6888
p.add_argument(
69-
"--stride", type=int, default=1,
89+
"--stride",
90+
type=int,
91+
default=1,
7092
help="Spatial stride applied to the 721x1440 grid to cut fetch cost.",
7193
)
7294
p.add_argument(
73-
"--tp-accumulation-hours", type=int, default=None,
95+
"--tp-accumulation-hours",
96+
type=int,
97+
default=None,
7498
help="If set to N, any variable named tp{N:02d} (e.g. tp06 for N=6) "
75-
"in --variables is treated as a paper §3 N-hour accumulation of "
76-
"ARCO hourly ``tp``, not a native ARCOLexicon key. Matches "
77-
"ArcoFGNDataset.tp_accumulation_hours so stats and training see "
78-
"the same representation.",
99+
"in --variables is treated as a paper §3 N-hour accumulation of "
100+
"ARCO hourly ``tp``, not a native ARCOLexicon key. Matches "
101+
"ArcoFGNDataset.tp_accumulation_hours so stats and training see "
102+
"the same representation.",
79103
)
80104
p.add_argument("--seed", type=int, default=0)
81105
p.add_argument(
82-
"--output", type=Path, required=True,
106+
"--output",
107+
type=Path,
108+
required=True,
83109
help="Destination .npz path.",
84110
)
85111
return p.parse_args()
86112

87113

88114
def iter_sample_times(
89-
start: datetime, end: datetime, step_hours: int, samples: int, rng: np.random.Generator
115+
start: datetime,
116+
end: datetime,
117+
step_hours: int,
118+
samples: int,
119+
rng: np.random.Generator,
90120
):
91121
total_hours = int((end - start).total_seconds() // 3600)
92122
max_offset = total_hours // step_hours

examples/weather/fgn/scripts/prefetch_arco.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
from __future__ import annotations
3737

3838
import argparse
39+
import sys
3940
from datetime import datetime, timedelta
4041
from pathlib import Path
41-
import sys
4242

4343
_EXAMPLE_DIR = Path(__file__).resolve().parents[1]
4444
if str(_EXAMPLE_DIR) not in sys.path:
@@ -47,12 +47,18 @@
4747

4848
def parse_args() -> argparse.Namespace:
4949
p = argparse.ArgumentParser(description=__doc__)
50-
p.add_argument("--start", default="2024-01-01", help="Window start (ISO, inclusive).")
50+
p.add_argument(
51+
"--start", default="2024-01-01", help="Window start (ISO, inclusive)."
52+
)
5153
p.add_argument("--end", default="2025-01-01", help="Window end (ISO, exclusive).")
5254
p.add_argument("--step-hours", type=int, default=6)
5355
p.add_argument("--history-frames", type=int, default=2)
54-
p.add_argument("--tp-accumulation-hours", type=int, default=6,
55-
help="N for tp{N:02d} accumulation. 0 = skip tp fetch.")
56+
p.add_argument(
57+
"--tp-accumulation-hours",
58+
type=int,
59+
default=6,
60+
help="N for tp{N:02d} accumulation. 0 = skip tp fetch.",
61+
)
5662
p.add_argument(
5763
"--variables",
5864
nargs="+",
@@ -65,20 +71,28 @@ def parse_args() -> argparse.Namespace:
6571
help="Date for one-off invariants fetch (z, lsm).",
6672
)
6773
p.add_argument(
68-
"--batch-days", type=int, default=31,
74+
"--batch-days",
75+
type=int,
76+
default=31,
6977
help="Days of data per time-batch. Default: 31.",
7078
)
7179
p.add_argument(
72-
"--var-group-size", type=int, default=10,
80+
"--var-group-size",
81+
type=int,
82+
default=10,
7383
help="Variables per sub-batch. Limits concurrent GCS requests to "
74-
"batch_days_timestamps × var_group_size. Default: 10.",
84+
"batch_days_timestamps × var_group_size. Default: 10.",
7585
)
7686
p.add_argument(
77-
"--no-tp", dest="include_tp", action="store_false",
87+
"--no-tp",
88+
dest="include_tp",
89+
action="store_false",
7890
help="Skip hourly tp prefetch.",
7991
)
8092
p.add_argument(
81-
"--no-invariants", dest="include_invariants", action="store_false",
93+
"--no-invariants",
94+
dest="include_invariants",
95+
action="store_false",
8296
help="Skip invariants prefetch.",
8397
)
8498
p.set_defaults(include_tp=True, include_invariants=True)
@@ -94,7 +108,9 @@ def _window_times(start: datetime, end: datetime, step_hours: int) -> list[datet
94108
return out
95109

96110

97-
def _batch(times: list[datetime], batch_days: int, step_hours: int = 1) -> list[list[datetime]]:
111+
def _batch(
112+
times: list[datetime], batch_days: int, step_hours: int = 1
113+
) -> list[list[datetime]]:
98114
"""Split a time list into chunks covering at most batch_days of real time."""
99115
n = max(1, batch_days * 24 // step_hours)
100116
return [times[i : i + n] for i in range(0, len(times), n)]
@@ -104,7 +120,9 @@ def main() -> int:
104120
args = parse_args()
105121
# Silence earth2studio's per-fetch DEBUG lines — they flood log files at scale.
106122
import sys
123+
107124
from loguru import logger
125+
108126
logger.remove()
109127
logger.add(sys.stderr, level="INFO")
110128

@@ -169,7 +187,7 @@ def main() -> int:
169187
f"[state {ti}/{n_tb} vars {vi}/{n_vg}] "
170188
f"{batch_times[0].date()}{batch_times[-1].date()}"
171189
f" ({len(batch_times)} steps × {len(vgroup)} vars"
172-
f" = {len(batch_times)*len(vgroup)} requests)"
190+
f" = {len(batch_times) * len(vgroup)} requests)"
173191
)
174192
arco(time=batch_times, variable=vgroup)
175193

examples/weather/fgn/scripts/stage4_ar_schedule.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@
5353
if str(_EXAMPLE_DIR) not in sys.path:
5454
sys.path.insert(0, str(_EXAMPLE_DIR))
5555

56-
from hydra import compose, initialize
57-
from omegaconf import DictConfig, OmegaConf
56+
from hydra import compose, initialize # noqa: E402
57+
from omegaconf import DictConfig, OmegaConf # noqa: E402
5858

5959
# Stage 4 of paper Table A.2.
6060
PAPER_STAGES: list[dict] = [
@@ -143,7 +143,6 @@ def _seed_from_prev_stage(
143143
stage_checkpoint_dir.mkdir(parents=True, exist_ok=True)
144144
shutil.copy2(last_mdlus, stage_checkpoint_dir / last_mdlus.name)
145145
# Copy optimizer/scheduler state if present so the resume is exact.
146-
opt_pt = last_mdlus.with_suffix("").with_suffix(".pt")
147146
# physicsnemo names these ``checkpoint.{mp_rank}.{epoch}.pt``; find the
148147
# matching epoch by filename suffix (``.<epoch>.mdlus``).
149148
epoch_suffix = last_mdlus.stem.split(".")[-1]

0 commit comments

Comments
 (0)