1414from __future__ import annotations
1515
1616from datetime import datetime , timedelta
17- from typing import Any , Sequence
17+ from typing import Any
1818
1919import numpy as np
2020import torch
2424# Paper Table A.1 atmospheric schema
2525PAPER_ATMOS_VARS : tuple [str , ...] = ("z" , "q" , "t" , "u" , "v" , "w" )
2626PAPER_LEVELS : tuple [int , ...] = (
27- 50 , 100 , 150 , 200 , 250 , 300 , 400 , 500 , 600 , 700 , 850 , 925 , 1000 ,
27+ 50 ,
28+ 100 ,
29+ 150 ,
30+ 200 ,
31+ 250 ,
32+ 300 ,
33+ 400 ,
34+ 500 ,
35+ 600 ,
36+ 700 ,
37+ 850 ,
38+ 925 ,
39+ 1000 ,
2840)
2941PAPER_SURFACE_IN_OUT : tuple [str , ...] = ("t2m" , "u10m" , "v10m" , "msl" , "sst" )
3042
@@ -96,8 +108,14 @@ class ArcoFGNDataset(FGNDataset):
96108
97109 def __init__ (self , params : Any , train : bool ) -> None :
98110 def _get (name : str , default : Any ) -> Any :
99- return getattr (params , name , default ) if hasattr (params , name ) else (
100- params [name ] if isinstance (params , dict ) and name in params else default
111+ return (
112+ getattr (params , name , default )
113+ if hasattr (params , name )
114+ else (
115+ params [name ]
116+ if isinstance (params , dict ) and name in params
117+ else default
118+ )
101119 )
102120
103121 state = _get ("state_variables" , None )
@@ -232,9 +250,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
232250 # slot from _fetch_tp_accumulation below.
233251 if self ._tp_channel_idx is not None :
234252 ci = self ._tp_channel_idx
235- arco_vars = [
236- v for i , v in enumerate (self ._state_variables ) if i != ci
237- ]
253+ arco_vars = [v for i , v in enumerate (self ._state_variables ) if i != ci ]
238254 else :
239255 ci = None
240256 arco_vars = list (self ._state_variables )
@@ -368,7 +384,10 @@ def _fetch_tp_accumulation(self, frame_times: list[datetime]) -> np.ndarray:
368384 fetch all distinct hourly stamps required by any frame in a single
369385 earth2studio call to minimise GCS round-trips.
370386 """
371- assert self .tp_accumulation_hours is not None
387+ if self .tp_accumulation_hours is None :
388+ raise RuntimeError (
389+ "_fetch_tp_accumulation called without tp_accumulation_hours set"
390+ )
372391 N = self .tp_accumulation_hours
373392
374393 # Union of hours we need across all frames, sorted.
@@ -387,9 +406,7 @@ def _fetch_tp_accumulation(self, frame_times: list[datetime]) -> np.ndarray:
387406 if self .stride > 1 :
388407 hourly = hourly [:, :: self .stride , :: self .stride ]
389408
390- acc = np .zeros (
391- (len (frame_times ), self .height , self .width ), dtype = np .float32
392- )
409+ acc = np .zeros ((len (frame_times ), self .height , self .width ), dtype = np .float32 )
393410 for k , hours_k in enumerate (per_frame_hours ):
394411 acc [k ] = sum (hourly [hour_to_idx [h ]] for h in hours_k )
395412 return acc
@@ -418,9 +435,7 @@ def _load_stats(self, stats_path: str) -> None:
418435 self ._mean = mean
419436 self ._std = std
420437
421- def _broadcast_stats_for (
422- self , x : np .ndarray | torch .Tensor
423- ) -> tuple [Any , Any ]:
438+ def _broadcast_stats_for (self , x : np .ndarray | torch .Tensor ) -> tuple [Any , Any ]:
424439 """Reshape `(V,)` stats to broadcast along the channel axis of ``x``.
425440
426441 Supports ``x`` of shape ``(V, H, W)``, ``(T, V, H, W)``, or
@@ -435,8 +450,12 @@ def _broadcast_stats_for(
435450 else :
436451 raise ValueError (f"unsupported state tensor ndim { x .ndim } " )
437452 if isinstance (x , torch .Tensor ):
438- mean = torch .as_tensor (self ._mean , dtype = x .dtype , device = x .device ).reshape (shape )
439- std = torch .as_tensor (self ._std , dtype = x .dtype , device = x .device ).reshape (shape )
453+ mean = torch .as_tensor (self ._mean , dtype = x .dtype , device = x .device ).reshape (
454+ shape
455+ )
456+ std = torch .as_tensor (self ._std , dtype = x .dtype , device = x .device ).reshape (
457+ shape
458+ )
440459 else :
441460 mean = self ._mean .reshape (shape )
442461 std = self ._std .reshape (shape )
0 commit comments