Skip to content

Commit 3ca36f6

Browse files
authored
Merge pull request #28 from coreweave/es/torch-extras
feat(torch-extras): Add NVIDIA Apex
2 parents 36522e7 + 7b43111 commit 3ca36f6

2 files changed

Lines changed: 93 additions & 1 deletion

File tree

torch-extras/Dockerfile

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
ARG BASE_IMAGE
44
ARG DEEPSPEED_VERSION="0.9.4"
55
ARG FLASH_ATTN_VERSION="1.0.7"
6+
ARG APEX_COMMIT="7b2e71b0d4013f8e2f9f1c8dd21980ff1d76f1b6"
67

78
FROM alpine/git:2.36.3 as flash-attn-downloader
89
WORKDIR /git
@@ -11,6 +12,16 @@ RUN git clone --recurse-submodules --shallow-submodules -j8 --depth 1 \
1112
https://github.com/HazyResearch/flash-attention -b v${FLASH_ATTN_VERSION} && \
1213
rm -rf flash-attention/.git
1314

15+
FROM alpine/git:2.36.3 as apex-downloader
16+
WORKDIR /git
17+
ARG APEX_COMMIT
18+
RUN git clone --filter=blob:none --depth 1 --no-single-branch --no-checkout \
19+
https://github.com/NVIDIA/apex && \
20+
cd apex && \
21+
git checkout "${APEX_COMMIT}" && \
22+
git submodule update --init --recursive --depth 1 --jobs 8 && \
23+
find -type d -name docs -exec rm -r '{}' ';' 2> /dev/null
24+
1425

1526
# Dependencies requiring NVCC are built ahead of time in a separate stage
1627
# so that the ~2 GiB dev library installations don't have to be included
@@ -54,6 +65,8 @@ WORKDIR /build
5465
COPY compiler_wrapper.f95 .
5566
RUN gfortran -O3 ./compiler_wrapper.f95 -o ./compiler && rm ./compiler_wrapper.f95
5667

68+
COPY --chmod=755 effective_cpu_count.sh .
69+
5770

5871
FROM builder-base as deepspeed-builder
5972
# DeepSpeed build flags
@@ -92,6 +105,7 @@ RUN python3 -m pip install -U --no-cache-dir \
92105
do if [[ -z ${!VAR} ]]; then unset ${VAR}; fi; done; \
93106
} && \
94107
CC=$(realpath -e ./compiler) \
108+
MAX_JOBS=$(($(./effective_cpu_count.sh) + 2)) \
95109
python3 -m pip wheel -w /wheels \
96110
--no-cache-dir --no-build-isolation --no-deps \
97111
deepspeed==${DEEPSPEED_VERSION} && \
@@ -102,12 +116,12 @@ WORKDIR /wheels
102116

103117

104118
FROM builder-base as flash-attn-builder
105-
ARG FLASH_ATTN_VERSION
106119

107120
RUN --mount=type=bind,from=flash-attn-downloader,source=/git/flash-attention,target=flash-attention/,rw \
108121
python3 -m pip install -U --no-cache-dir \
109122
packaging setuptools wheel pip && \
110123
export CC=$(realpath -e ./compiler) && \
124+
export MAX_JOBS=$(($(./effective_cpu_count.sh) + 2)) && \
111125
cd flash-attention && \
112126
parallel 'cd {} && python3 setup.py bdist_wheel --dist-dir /wheels' ::: \
113127
. \
@@ -121,6 +135,50 @@ RUN --mount=type=bind,from=flash-attn-downloader,source=/git/flash-attention,tar
121135
WORKDIR /wheels
122136

123137

138+
FROM builder-base as apex-builder
139+
140+
RUN LIBNCCL2_VERSION=$(dpkg-query --showformat='${Version}' --show libnccl2) && \
141+
apt-get -qq update && apt-get install -y --no-install-recommends \
142+
libnccl-dev=$LIBNCCL2_VERSION && \
143+
apt-get clean
144+
145+
RUN --mount=type=bind,from=apex-downloader,source=/git/apex,target=apex/,rw \
146+
python3 -m pip install -U --no-cache-dir \
147+
packaging setuptools wheel pip && \
148+
export CC=$(realpath -e ./compiler) && \
149+
export MAX_JOBS=$(($(./effective_cpu_count.sh) + 2)) && \
150+
EXTENSIONS=$(printf -- '--config-settings "--build-option=%s" ' $( \
151+
echo \
152+
--cpp_ext \
153+
--cuda_ext \
154+
--permutation_search \
155+
--xentropy \
156+
--focal_loss \
157+
--index_mul_2d \
158+
--deprecated_fused_adam \
159+
--deprecated_fused_lamb \
160+
--fast_layer_norm \
161+
--fmha \
162+
--fast_multihead_attn \
163+
--transducer \
164+
--peer_memory \
165+
--nccl_p2p \
166+
--fast_bottleneck && \
167+
if dpkg-query --status libcudnn8-dev > /dev/null 2> /dev/null; then \
168+
echo \
169+
--bnp \
170+
--cudnn_gbn \
171+
--fused_conv_bias_relu; \
172+
fi; \
173+
)) && \
174+
cd apex && \
175+
python3 -m pip wheel -w /wheels -v \
176+
--no-cache-dir --no-build-isolation --no-deps \
177+
$EXTENSIONS ./
178+
179+
WORKDIR /wheels
180+
181+
124182
FROM ${BASE_IMAGE}
125183

126184
RUN apt-get -qq update && \
@@ -131,4 +189,6 @@ RUN --mount=type=bind,from=deepspeed-builder,source=/wheels,target=/tmp/wheels \
131189
python3 -m pip install --no-cache-dir /tmp/wheels/*.whl
132190
RUN --mount=type=bind,from=flash-attn-builder,source=/wheels,target=/tmp/wheels \
133191
python3 -m pip install --no-cache-dir /tmp/wheels/*.whl
192+
RUN --mount=type=bind,from=apex-builder,source=/wheels,target=/tmp/wheels \
193+
python3 -m pip install --no-cache-dir /tmp/wheels/*.whl
134194
RUN rm -r /tmp/wheels
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/bin/sh
2+
3+
CPU_QUOTA() (
4+
CGROUP='/sys/fs/cgroup';
5+
CGROUP_V1="$CGROUP/cpu,cpuacct";
6+
CGROUP_V1_QUOTA="$CGROUP_V1/cpu.cfs_quota_us";
7+
CGROUP_V1_PERIOD="$CGROUP_V1/cpu.cfs_period_us";
8+
CGROUP_V2="$CGROUP/user.slice/cpu.max";
9+
if [ ! -d "$CGROUP" ]; then
10+
return 1;
11+
elif [ -f "$CGROUP_V1_QUOTA" ] && [ -f "$CGROUP_V1_PERIOD" ]; then
12+
IFS='' read -r QUOTA 2> /dev/null < "$CGROUP_V1_QUOTA" || return 1;
13+
IFS='' read -r PERIOD 2> /dev/null < "$CGROUP_V1_PERIOD" || return 1;
14+
elif [ -f "$CGROUP_V2" ]; then
15+
IFS=' ' read -r QUOTA PERIOD 2> /dev/null < "$CGROUP_V2" || return 1;
16+
else
17+
return 1;
18+
fi;
19+
20+
if [ "$QUOTA" -gt 0 ] 2> /dev/null && [ "$PERIOD" -gt 0 ] 2> /dev/null; then
21+
echo $((QUOTA / PERIOD));
22+
return 0;
23+
else
24+
return 1;
25+
fi;
26+
)
27+
28+
EFFECTIVE_CPU_COUNT() {
29+
CPU_QUOTA || getconf _NPROCESSORS_ONLN;
30+
}
31+
32+
EFFECTIVE_CPU_COUNT;

0 commit comments

Comments
 (0)