Skip to content

Commit 16d8602

Browse files
committed
build(torch): fix 10.3f TransformerEngine build
1 parent fe2ecca commit 16d8602

1 file changed

Lines changed: 17 additions & 3 deletions

File tree

torch/Dockerfile

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -584,15 +584,18 @@ RUN --mount=type=bind,from=transformerengine-downloader,source=/git/TransformerE
584584
--mount=type=tmpfs,target=/tmp \
585585
. /opt/sccache-start.sh && \
586586
{ \
587-
sed -i -E 's@list\(APPEND NVTE_SPECIFIC_ARCHS "103a"\)@list(APPEND NVTE_SPECIFIC_ARCHS "103a" "103f")@' TransformerEngine/transformer_engine/common/CMakeLists.txt && \
588-
grep -qF '"103a" "103f"' TransformerEngine/transformer_engine/common/CMakeLists.txt \
589-
|| { echo 'Failed to apply sed patch' >&2; exit 1; }; \
587+
awk '{ print } /list\(APPEND NVTE_SPECIFIC_ARCHS "103a"\)/ { print " list(APPEND NVTE_GENERIC_ARCHS \"103f\")"; found = 1 } END { if (!found) exit 1 }' \
588+
TransformerEngine/transformer_engine/common/CMakeLists.txt > /tmp/NVTE.CMakeLists.txt && \
589+
mv /tmp/NVTE.CMakeLists.txt TransformerEngine/transformer_engine/common/CMakeLists.txt && \
590+
grep -qF 'list(APPEND NVTE_GENERIC_ARCHS "103f")' TransformerEngine/transformer_engine/common/CMakeLists.txt \
591+
|| { echo 'Failed to patch TransformerEngine' >&2; exit 1; }; \
590592
} && \
591593
export MAX_JOBS=$(./scale.sh "$(./effective_cpu_count.sh)" 7 32) && \
592594
export MAX_JOBS="${BUILD_MAX_JOBS:-$MAX_JOBS}" && \
593595
export NVTE_BUILD_THREADS_PER_JOB=7 && \
594596
echo "MAX_JOBS: ${MAX_JOBS}; NVTE_BUILD_THREADS_PER_JOB: ${NVTE_BUILD_THREADS_PER_JOB}" && \
595597
export NVCC_APPEND_FLAGS="$(cat /build/nvcc.conf)" && \
598+
NVCC_APPEND_FLAGS="$(echo "${NVCC_APPEND_FLAGS}" | sed -E 's@[[:space:]]?-gencode=[^[:space:]]+@@g')" && \
596599
echo "NVCC_APPEND_FLAGS: ${NVCC_APPEND_FLAGS}" && \
597600
case "${CUDA_VERSION}" in 12.[0123456].*) \
598601
export NVTE_CUDA_ARCHS="${NVTE_CUDA_ARCHS%;100*}" ;; \
@@ -759,6 +762,17 @@ ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#13.?.?}"
759762
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#||*||}"
760763
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST%||*}"
761764
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
765+
# Expose 10.3f to downstream extension builds when the CUDA toolkit can build it.
766+
ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##12.9.*}||${TORCH_CUDA_ARCH_LIST}||${TORCH_CUDA_ARCH_LIST} 10.3f"
767+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#12.9.?}"
768+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#||*||}"
769+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST%||*}"
770+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
771+
ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##13.*}||${TORCH_CUDA_ARCH_LIST}||${TORCH_CUDA_ARCH_LIST} 10.3f"
772+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#13.?.?}"
773+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#||*||}"
774+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST%||*}"
775+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
762776

763777
COPY --link --from=torch-common --chmod=755 install_cudnn.sh /tmp/install_cudnn.sh
764778
# - libnvjitlink-X-Y only exists for CUDA versions >= 12-0.

0 commit comments

Comments
 (0)