Skip to content

Commit d89a70e

Browse files
committed
Address reviewer feedback: rollout bug, dead amp field, _step_ensemble indirection
- Fix inference _rollout: torch.cat([history[:, 1:], next_frame.unsqueeze(1)]) so history window slides correctly for any history_frames value, not just 2
- Remove unimplemented amp config field from TrainingConfig and default.yaml
- Inline model call in _loss AR loop instead of routing through _step_ensemble with num_samples=1 (each member needs its own history, so the single-call collapse doesn't apply; direct call is cleaner)

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
1 parent b1e9152 commit d89a70e

4 files changed

Lines changed: 16 additions & 11 deletions

File tree

examples/weather/fgn/config/training/default.yaml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@ checkpoint_freq: 50
1212
validation_freq: 25
1313
resume_checkpoint: latest
1414
clip_grad_norm: -1.0
15-
amp: false
1615

1716
# Autoregressive rollout depth per training step.
1817
# Must match dataset.future_frames. 1 = single-step training (paper Stages

examples/weather/fgn/inference.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -106,8 +106,8 @@ def _rollout(
106106
next_frame = next_frame.clone()
107107
for ci in output_only_channels:
108108
next_frame[:, ci].zero_()
109-
rollout_history = torch.stack(
110-
[rollout_history[:, 1], next_frame],
109+
rollout_history = torch.cat(
110+
[rollout_history[:, 1:], next_frame.unsqueeze(1)],
111111
dim=1,
112112
)
113113
trajectories.append(torch.stack(states, dim=1))

examples/weather/fgn/utils/config.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,6 @@ class TrainingConfig:
5858
validation_freq: int = Field(default=25, ge=1)
5959
resume_checkpoint: int | Literal["latest"] | None = "latest"
6060
clip_grad_norm: float = -1.0
61-
amp: bool = False
6261
ar_steps: int = Field(default=1, ge=1, le=8)
6362
# Data + domain parallelism knobs. Mirrors StormCast's convention.
6463
# - domain_parallel_size=1 & force_sharding=False → pure single-process

examples/weather/fgn/utils/trainer.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -320,13 +320,20 @@ def _loss(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
320320
members = []
321321
for n in range(num_samples):
322322
hist_n = per_member_hist[:, n]
323-
pred_n = self._step_ensemble(
324-
history=hist_n,
325-
background=background,
326-
invariants=invariants,
327-
num_samples=1,
328-
)[:, 0]
329-
members.append(pred_n)
323+
latent = torch.randn(
324+
hist_n.shape[0],
325+
int(self.cfg.model.latent_dim),
326+
device=self.device,
327+
dtype=torch.float32,
328+
)
329+
members.append(
330+
self.model(
331+
history=hist_n,
332+
latent=latent,
333+
background=background,
334+
invariants=invariants,
335+
)
336+
)
330337
preds = torch.stack(members, dim=1) # (B, N, C, H, W)
331338

332339
step_loss = fair_crps(preds, target[:, k], weights=self.loss_weights)

0 commit comments

Comments
 (0)