# Global build arguments.
# An ARG declared before the first FROM is only visible to FROM lines;
# a stage that needs one of these values must redeclare it (without a value)
# to inherit the default given here.
#
# BASE_IMAGE has no default and is consumed by the final `FROM ${BASE_IMAGE}`.
ARG BASE_IMAGE
# Release of DeepSpeed to build a wheel for.
ARG DEEPSPEED_VERSION="0.9.4"
# Release tag (v<version>) of flash-attention to clone and build.
ARG FLASH_ATTN_VERSION="1.0.7"
# Exact NVIDIA Apex commit to check out and build.
ARG APEX_COMMIT="7b2e71b0d4013f8e2f9f1c8dd21980ff1d76f1b6"
67
78FROM alpine/git:2.36.3 as flash-attn-downloader
89WORKDIR /git
@@ -11,6 +12,16 @@ RUN git clone --recurse-submodules --shallow-submodules -j8 --depth 1 \
1112 https://github.com/HazyResearch/flash-attention -b v${FLASH_ATTN_VERSION} && \
1213 rm -rf flash-attention/.git
1314
# Downloads the pinned NVIDIA Apex commit, including its submodules,
# into /git/apex so that a later stage can bind-mount the sources.
FROM alpine/git:2.36.3 as apex-downloader
WORKDIR /git
# Redeclared without a value to inherit the default from the global ARG.
ARG APEX_COMMIT
# The clone is blobless and checkout-free; the requested commit is then
# checked out explicitly and submodules are fetched shallowly.
# Documentation directories are removed afterwards. `-prune` stops find from
# trying to descend into a directory that the -exec has just deleted, which
# would otherwise produce an error and a non-zero exit status for the RUN.
RUN git clone --filter=blob:none --depth 1 --no-single-branch --no-checkout \
      https://github.com/NVIDIA/apex && \
    cd apex && \
    git checkout "${APEX_COMMIT}" && \
    git submodule update --init --recursive --depth 1 --jobs 8 && \
    find . -type d -name docs -prune -exec rm -r '{}' ';'
24+
1425
1526# Dependencies requiring NVCC are built ahead of time in a separate stage
1627# so that the ~2 GiB dev library installations don't have to be included
@@ -54,6 +65,8 @@ WORKDIR /build
5465COPY compiler_wrapper.f95 .
5566RUN gfortran -O3 ./compiler_wrapper.f95 -o ./compiler && rm ./compiler_wrapper.f95
5667
68+ COPY --chmod=755 effective_cpu_count.sh .
69+
5770
5871FROM builder-base as deepspeed-builder
5972# DeepSpeed build flags
@@ -92,6 +105,7 @@ RUN python3 -m pip install -U --no-cache-dir \
92105 do if [[ -z ${!VAR} ]]; then unset ${VAR}; fi; done; \
93106 } && \
94107 CC=$(realpath -e ./compiler) \
108+ MAX_JOBS=$(($(./effective_cpu_count.sh) + 2)) \
95109 python3 -m pip wheel -w /wheels \
96110 --no-cache-dir --no-build-isolation --no-deps \
97111 deepspeed==${DEEPSPEED_VERSION} && \
@@ -102,12 +116,12 @@ WORKDIR /wheels
102116
103117
104118FROM builder-base as flash-attn-builder
105- ARG FLASH_ATTN_VERSION
106119
107120RUN --mount=type=bind,from=flash-attn-downloader,source=/git/flash-attention,target=flash-attention/,rw \
108121 python3 -m pip install -U --no-cache-dir \
109122 packaging setuptools wheel pip && \
110123 export CC=$(realpath -e ./compiler) && \
124+ export MAX_JOBS=$(($(./effective_cpu_count.sh) + 2)) && \
111125 cd flash-attention && \
112126 parallel 'cd {} && python3 setup.py bdist_wheel --dist-dir /wheels' ::: \
113127 . \
@@ -121,6 +135,50 @@ RUN --mount=type=bind,from=flash-attn-downloader,source=/git/flash-attention,tar
121135WORKDIR /wheels
122136
123137
# Builds an NVIDIA Apex wheel, with its native extensions, into /wheels.
FROM builder-base as apex-builder

# Install the NCCL development package at exactly the version of the
# libnccl2 runtime package that is already installed, so the two match.
RUN LIBNCCL2_VERSION="$(dpkg-query --showformat='${Version}' --show libnccl2)" && \
    apt-get -qq update && apt-get install -y --no-install-recommends \
      "libnccl-dev=${LIBNCCL2_VERSION}" && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# The Apex sources are bind-mounted (read-write, changes discarded) from the
# apex-downloader stage rather than copied into a layer.
#
# Each requested extension is passed to pip as its own
#   --config-settings "--build-option=<flag>"
# pair. The pairs are accumulated in the positional parameters ("$@") so
# that every pair reaches pip as two intact arguments; quote characters
# stored inside a plain string variable would be passed through literally
# after unquoted expansion and corrupt the option name.
#
# The cuDNN-dependent extensions are only requested when the
# libcudnn8-dev package is installed.
RUN --mount=type=bind,from=apex-downloader,source=/git/apex,target=apex/,rw \
    python3 -m pip install -U --no-cache-dir \
      packaging setuptools wheel pip && \
    export CC="$(realpath -e ./compiler)" && \
    export MAX_JOBS="$(($(./effective_cpu_count.sh) + 2))" && \
    set -- && \
    for EXTENSION in \
      --cpp_ext \
      --cuda_ext \
      --permutation_search \
      --xentropy \
      --focal_loss \
      --index_mul_2d \
      --deprecated_fused_adam \
      --deprecated_fused_lamb \
      --fast_layer_norm \
      --fmha \
      --fast_multihead_attn \
      --transducer \
      --peer_memory \
      --nccl_p2p \
      --fast_bottleneck; \
    do \
      set -- "$@" --config-settings "--build-option=${EXTENSION}"; \
    done && \
    if dpkg-query --status libcudnn8-dev > /dev/null 2> /dev/null; then \
      for EXTENSION in \
        --bnp \
        --cudnn_gbn \
        --fused_conv_bias_relu; \
      do \
        set -- "$@" --config-settings "--build-option=${EXTENSION}"; \
      done; \
    fi && \
    cd apex && \
    python3 -m pip wheel -w /wheels -v \
      --no-cache-dir --no-build-isolation --no-deps \
      "$@" ./

WORKDIR /wheels

180+
181+
124182FROM ${BASE_IMAGE}
125183
126184RUN apt-get -qq update && \
@@ -131,4 +189,6 @@ RUN --mount=type=bind,from=deepspeed-builder,source=/wheels,target=/tmp/wheels \
131189 python3 -m pip install --no-cache-dir /tmp/wheels/*.whl
132190RUN --mount=type=bind,from=flash-attn-builder,source=/wheels,target=/tmp/wheels \
133191 python3 -m pip install --no-cache-dir /tmp/wheels/*.whl
192+ RUN --mount=type=bind,from=apex-builder,source=/wheels,target=/tmp/wheels \
193+ python3 -m pip install --no-cache-dir /tmp/wheels/*.whl
134194RUN rm -r /tmp/wheels
0 commit comments