Skip to content

Commit 20f6bb1

Browse files
authored
Enhance dependency management in pyproject.toml and CI workflows by adding transformer-engine extras for CUDA 12 and 13. Update EXTRAS_TAG in CI configurations to include transformer-engine-cu12, ensuring compatibility with the latest dependencies. This change addresses potential issues with dynamic CUDA dependency resolution in transformer-engine-torch. (#1683)
1 parent f9c4513 commit 20f6bb1

5 files changed

Lines changed: 1610 additions & 2108 deletions

File tree

.github/regen-ci-deps-lock.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ set -euo pipefail
4343

4444
# These MUST match the workflow `env:` values in
4545
# .github/workflows/github-{pr,nightly-uv}.yml. Bump in lockstep.
46-
EXTRAS_TAG="${EXTRAS_TAG:-cu12,natten-cu12,utils-extras,mesh-extras,nn-extras,model-extras,datapipes-extras,uq-extras,gnns,sym,perf}"
46+
EXTRAS_TAG="${EXTRAS_TAG:-cu12,natten-cu12,utils-extras,mesh-extras,nn-extras,model-extras,datapipes-extras,uq-extras,gnns,sym,transformer-engine-cu12}"
4747
UV_VERSION="${UV_VERSION:-0.11.7}"
4848
# Matches the `--find-links` URL committed to .github/ci-requirements.txt.
4949
# Bump the torch-X.Y.Z+cu128 segment in lockstep with the locked torch version.

.github/workflows/github-nightly-uv.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ env:
8686
# with no extras home (moto, scikit-image, pyg_lib, earth2grid, ...) are
8787
# installed by the `Install CI-only test dependencies` step inside the
8888
# setup-uv-env composite action.
89-
EXTRAS_TAG: "cu12,natten-cu12,utils-extras,mesh-extras,nn-extras,model-extras,datapipes-extras,uq-extras,gnns,sym,perf"
89+
EXTRAS_TAG: "cu12,natten-cu12,utils-extras,mesh-extras,nn-extras,model-extras,datapipes-extras,uq-extras,gnns,sym,transformer-engine-cu12"
9090

9191
# ---- Cache key prefixes ------------------------------------------------
9292
# Inlined literally because GitHub Actions does not allow env-to-env

.github/workflows/github-pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ env:
6363
# with no extras home (moto, scikit-image, pyg_lib, earth2grid, ...) are
6464
# installed by the `Install CI-only test dependencies` step inside the
6565
# setup-uv-env composite action.
66-
EXTRAS_TAG: "cu12,natten-cu12,utils-extras,mesh-extras,nn-extras,model-extras,datapipes-extras,uq-extras,gnns,sym,perf"
66+
EXTRAS_TAG: "cu12,natten-cu12,utils-extras,mesh-extras,nn-extras,model-extras,datapipes-extras,uq-extras,gnns,sym,transformer-engine-cu12"
6767

6868
# ---- Cache key prefixes (shared with nightly) --------------------------
6969
# The `-fullextras` suffix is bumped relative to the previous prefix so

pyproject.toml

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ conflicts = [
7676
{ extra = "natten-cu12" },
7777
{ extra = "natten-cu13" },
7878
],
79+
[
80+
{ extra = "transformer-engine-cu12" },
81+
{ extra = "transformer-engine-cu13" },
82+
],
7983
]
8084

8185
[tool.uv.extra-build-dependencies]
@@ -152,6 +156,28 @@ natten = [
152156
{ index = "natten-cu130-whl", extra = "natten-cu13" },
153157
]
154158

159+
# transformer-engine-torch is an sdist whose setup.py dynamically pins
160+
# transformer_engine_cu{12,13} based on torch.version.cuda in the build
161+
# environment. Without this override, uv lock freezes whichever CUDA
162+
# variant happened to be in the build env, which broke CI when PyTorch
163+
# 2.12 flipped the PyPI default to cu130 (the cu13 binary then got
164+
# pulled into a CUDA 12.8 runtime, producing missing libcublas.so.13
165+
# errors at import time). We provide the CUDA-agnostic subset of the
166+
# dependencies here; the matching cuXX binary is pulled by the explicit
167+
# transformer-engine-cu12 / transformer-engine-cu13 extras below.
168+
# Static deps mirror TE 2.15's build_tools/pytorch.py::install_requirements().
169+
[[tool.uv.dependency-metadata]]
170+
name = "transformer-engine-torch"
171+
requires-dist = [
172+
"torch>=2.1",
173+
"einops",
174+
"onnxscript",
175+
"onnx",
176+
"packaging",
177+
"pydantic",
178+
"nvdlfw-inspect",
179+
]
180+
155181
#####################################################################
156182
# Flags Controlling the local build of physicsnemo
157183
#####################################################################
@@ -248,6 +274,21 @@ natten-cu12 = [
248274
natten-cu13 = [
249275
"natten>=0.21.5",
250276
]
277+
# Transformer Engine extras (mutually exclusive via [tool.uv] conflicts).
278+
# Pair with the matching CUDA backend extra, e.g.:
279+
# pip install nvidia-physicsnemo[cu12,transformer-engine-cu12]
280+
# pip install nvidia-physicsnemo[cu13,transformer-engine-cu13]
281+
# Uses the metapackage's `core_cu12` / `core_cu13` extras (introduced in
282+
# transformer-engine 2.14.0) to explicitly pin the CUDA backend binary,
283+
# so the resolution doesn't depend on torch.version.cuda in the build
284+
# env. See [[tool.uv.dependency-metadata]] above for the matching
285+
# override on transformer-engine-torch's dynamic CUDA dep.
286+
transformer-engine-cu12 = [
287+
"transformer_engine[pytorch,core_cu12]>=2.14.0",
288+
]
289+
transformer-engine-cu13 = [
290+
"transformer_engine[pytorch,core_cu13]>=2.14.0",
291+
]
251292
utils-extras = [
252293
"wandb",
253294
"mlflow>=3.12.0",
@@ -295,10 +336,6 @@ sym = [
295336
"sympy>=1.12",
296337
]
297338

298-
perf = [
299-
"transformer_engine[pytorch]",
300-
]
301-
302339

303340
#####################################################################
304341
# Linting configuration

0 commit comments

Comments
 (0)