@@ -77,6 +77,10 @@ conflicts = [
7777 { extra = " natten-cu12" },
7878 { extra = " natten-cu13" },
7979 ],
80+ [
81+ { extra = " transformer-engine-cu12" },
82+ { extra = " transformer-engine-cu13" },
83+ ],
8084]
8185
8286[tool .uv .extra-build-dependencies ]
@@ -153,6 +157,28 @@ natten = [
153157 { index = " natten-cu130-whl" , extra = " natten-cu13" },
154158]
155159
160+ # transformer-engine-torch is an sdist whose setup.py dynamically pins
161+ # transformer_engine_cu{12,13} based on torch.version.cuda in the build
162+ # environment. Without this override, uv lock freezes whichever CUDA
163+ # variant happened to be in the build env, which broke CI when PyTorch
164+ # 2.12 flipped the PyPI default to cu130 (the cu13 binary then got
165+ # pulled into a CUDA 12.8 runtime, producing missing libcublas.so.13
166+ # errors at import time). We provide the CUDA-agnostic subset of the
167+ # dependencies here; the matching cuXX binary is pulled by the explicit
168+ # transformer-engine-cu12 / transformer-engine-cu13 extras below.
169+ # Static deps mirror TE 2.15's build_tools/pytorch.py::install_requirements().
170+ [[tool .uv .dependency-metadata ]]
171+ name = " transformer-engine-torch"
172+ requires-dist = [
173+ " torch>=2.1" ,
174+ " einops" ,
175+ " onnxscript" ,
176+ " onnx" ,
177+ " packaging" ,
178+ " pydantic" ,
179+ " nvdlfw-inspect" ,
180+ ]
181+
156182# ####################################################################
157183# Flags Controlling the local build of physicsnemo
158184# ####################################################################
@@ -249,6 +275,21 @@ natten-cu12 = [
249275natten-cu13 = [
250276 " natten>=0.21.5" ,
251277]
278+ # Transformer Engine extras (mutually exclusive via [tool.uv] conflicts).
279+ # Pair with the matching CUDA backend extra, e.g.:
280+ # pip install nvidia-physicsnemo[cu12,transformer-engine-cu12]
281+ # pip install nvidia-physicsnemo[cu13,transformer-engine-cu13]
282+ # Uses the metapackage's `core_cu12` / `core_cu13` extras (introduced in
283+ # transformer-engine 2.14.0) to explicitly pin the CUDA backend binary,
284+ # so the resolution doesn't depend on torch.version.cuda in the build
285+ # env. See [[tool.uv.dependency-metadata]] above for the matching
286+ # override on transformer-engine-torch's dynamic CUDA dep.
287+ transformer-engine-cu12 = [
288+ " transformer_engine[pytorch,core_cu12]>=2.14.0" ,
289+ ]
290+ transformer-engine-cu13 = [
291+ " transformer_engine[pytorch,core_cu13]>=2.14.0" ,
292+ ]
252293utils-extras = [
253294 " wandb" ,
254295 " mlflow>=3.12.0" ,
@@ -296,10 +337,6 @@ sym = [
296337 " sympy>=1.12" ,
297338]
298339
299- perf = [
300- " transformer_engine[pytorch]" ,
301- ]
302-
303340
304341# ####################################################################
305342# Linting configuration
0 commit comments