Skip to content

Commit ea4e1c2

Browse files
authored
Reintroduce Atlas with performance flag (#897)
1 parent 1fbf9d6 commit ea4e1c2

3 files changed

Lines changed: 10 additions & 0 deletions

File tree

earth2studio/models/px/atlas.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ class Atlas(torch.nn.Module, AutoModelMixin, PrognosticMixin):
162162
performs autoregressive timestepping using a full-resolution physical state
163163
and an internal low-resolution latent state.
164164
165+
Note
166+
----
167+
For best inference performance, set the environment variable `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1`.
168+
This is on by default in NGC containers, but other environments may need to set it manually.
169+
165170
Badges
166171
------
167172
region:global class:mrf product:wind product:precip product:temp product:atmos year:2026

examples/02_medium_range/06_atlas_inference.disabled renamed to examples/02_medium_range/06_atlas_inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@
6363
import os
6464

6565
os.makedirs("outputs", exist_ok=True)
66+
# Performance optimization for Atlas model
67+
# This is on by default in NGC containers, but other environments may need to set it manually.
68+
os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] = "1"
6669
from dotenv import load_dotenv
6770

6871
load_dotenv() # TODO: make common example prep function

tox.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ setenv =
3939
# https://docs.jax.dev/en/latest/gpu_memory_allocation.html
4040
XLA_PYTHON_CLIENT_ALLOCATOR = platform
4141
XLA_PYTHON_CLIENT_PREALLOCATE = false
42+
# Performance optimization for Atlas model
43+
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE = 1
4244
# Unique coverage file per tox env so parallel CI jobs don't overwrite each other
4345
COVERAGE_FILE = .coverage.{envname}
4446

0 commit comments

Comments
 (0)