@@ -76,6 +76,10 @@ conflicts = [
7676 { extra = " natten-cu12" },
7777 { extra = " natten-cu13" },
7878 ],
79+ [
80+ { extra = " transformer-engine-cu12" },
81+ { extra = " transformer-engine-cu13" },
82+ ],
7983]
8084
8185[tool .uv .extra-build-dependencies ]
@@ -152,6 +156,28 @@ natten = [
152156 { index = " natten-cu130-whl" , extra = " natten-cu13" },
153157]
154158
159+ # transformer-engine-torch is an sdist whose setup.py dynamically pins
160+ # transformer_engine_cu{12,13} based on torch.version.cuda in the build
161+ # environment. Without this override, uv lock freezes whichever CUDA
162+ # variant happened to be in the build env, which broke CI when PyTorch
163+ # 2.12 flipped the PyPI default to cu130 (the cu13 binary then got
164+ # pulled into a CUDA 12.8 runtime, producing missing libcublas.so.13
165+ # errors at import time). We provide the CUDA-agnostic subset of the
166+ # dependencies here; the matching cuXX binary is pulled by the explicit
167+ # transformer-engine-cu12 / transformer-engine-cu13 extras below.
168+ # Static deps mirror TE 2.15's build_tools/pytorch.py::install_requirements().
169+ [[tool .uv .dependency-metadata ]]
170+ name = " transformer-engine-torch"
171+ requires-dist = [
172+ " torch>=2.1" ,
173+ " einops" ,
174+ " onnxscript" ,
175+ " onnx" ,
176+ " packaging" ,
177+ " pydantic" ,
178+ " nvdlfw-inspect" ,
179+ ]
180+
155181# ####################################################################
156182# Flags Controlling the local build of physicsnemo
157183# ####################################################################
@@ -248,6 +274,21 @@ natten-cu12 = [
248274natten-cu13 = [
249275 " natten>=0.21.5" ,
250276]
277+ # Transformer Engine extras (mutually exclusive via [tool.uv] conflicts).
278+ # Pair with the matching CUDA backend extra, e.g.:
279+ # pip install nvidia-physicsnemo[cu12,transformer-engine-cu12]
280+ # pip install nvidia-physicsnemo[cu13,transformer-engine-cu13]
281+ # Uses the metapackage's `core_cu12` / `core_cu13` extras (introduced in
282+ # transformer-engine 2.14.0) to explicitly pin the CUDA backend binary,
283+ # so the resolution doesn't depend on torch.version.cuda in the build
284+ # env. See [[tool.uv.dependency-metadata]] above for the matching
285+ # override on transformer-engine-torch's dynamic CUDA dep.
286+ transformer-engine-cu12 = [
287+ " transformer_engine[pytorch,core_cu12]>=2.14.0" ,
288+ ]
289+ transformer-engine-cu13 = [
290+ " transformer_engine[pytorch,core_cu13]>=2.14.0" ,
291+ ]
251292utils-extras = [
252293 " wandb" ,
253294 " mlflow>=3.12.0" ,
@@ -295,10 +336,6 @@ sym = [
295336 " sympy>=1.12" ,
296337]
297338
298- perf = [
299- " transformer_engine[pytorch]" ,
300- ]
301-
302339
303340# ####################################################################
304341# Linting configuration
0 commit comments