|
| 1 | +# Copyright The FMS HF Tuning Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +## Global Args ################################################################# |
| 16 | +## If the nvcr container is updated, ensure to check the torch and python |
| 17 | +## installation version inside the dockerfile before pushing changes. |
| 18 | +ARG NVCR_IMAGE_VERSION=25.02-py3 |
| 19 | + |
| 20 | +# This is based on what is inside the NVCR image already |
| 21 | +ARG PYTHON_VERSION=3.12 |
| 22 | + |
| 23 | +## Base Layer ################################################################## |
| 24 | +FROM nvcr.io/nvidia/pytorch:${NVCR_IMAGE_VERSION} AS dev |
| 25 | + |
| 26 | +ARG USER=root |
| 27 | +ARG USER_UID=0 |
| 28 | +ARG WORKDIR=/app |
| 29 | +ARG SOURCE_DIR=${WORKDIR}/fms-hf-tuning |
| 30 | + |
| 31 | +ARG ENABLE_FMS_ACCELERATION=true |
| 32 | +ARG ENABLE_AIM=true |
| 33 | +ARG ENABLE_ALORA=true |
| 34 | +ARG ENABLE_MLFLOW=true |
| 35 | +ARG ENABLE_SCANNER=true |
| 36 | +ARG ENABLE_CLEARML=true |
| 37 | +ARG ENABLE_TRITON_KERNELS=true |
| 38 | +ARG ENABLE_MAMBA_SUPPORT=true |
| 39 | + |
| 40 | +# Ensures to always build mamba_ssm from source |
| 41 | +ENV PIP_NO_BINARY=mamba-ssm,mamba_ssm |
| 42 | + |
| 43 | +RUN python -m pip install --upgrade pip |
| 44 | + |
| 45 | +# upgrade torch as the base layer contains only torch 2.7 |
| 46 | +RUN pip install --upgrade --force-reinstall torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu128 |
| 47 | + |
| 48 | +# Install main package + flash attention |
| 49 | +RUN COPY . ${SOURCE_DIR} |
| 50 | +RUN cd ${SOURCE_DIR} |
| 51 | +RUN pip install --no-cache-dir ${SOURCE_DIR} && \ |
| 52 | + pip install --no-cache-dir ${SOURCE_DIR}[flash-attn] |
| 53 | + |
| 54 | +# Optional extras |
| 55 | +RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \ |
| 56 | + pip install --no-cache-dir ${SOURCE_DIR}[fms-accel] && \ |
| 57 | + python -m fms_acceleration.cli install fms_acceleration_peft && \ |
| 58 | + python -m fms_acceleration.cli install fms_acceleration_foak && \ |
| 59 | + python -m fms_acceleration.cli install fms_acceleration_aadp && \ |
| 60 | + python -m fms_acceleration.cli install fms_acceleration_moe; \ |
| 61 | + fi |
| 62 | + |
| 63 | +RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \ |
| 64 | + pip install --no-cache-dir ${SOURCE_DIR}[activated-lora]; \ |
| 65 | + fi |
| 66 | +RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \ |
| 67 | + pip install --no-cache-dir ${SOURCE_DIR}[aim]; \ |
| 68 | + fi |
| 69 | +RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \ |
| 70 | + pip install --no-cache-dir ${SOURCE_DIR}[mlflow]; \ |
| 71 | + fi |
| 72 | +RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \ |
| 73 | + pip install --no-cache-dir ${SOURCE_DIR}[scanner-dev]; \ |
| 74 | + fi |
| 75 | +RUN if [[ "${ENABLE_CLEARML}" == "true" ]]; then \ |
| 76 | + pip install --no-cache-dir ${SOURCE_DIR}[clearml]; \ |
| 77 | + fi |
| 78 | +RUN if [[ "${ENABLE_MAMBA_SUPPORT}" == "true" ]]; then \ |
| 79 | + pip install --no-cache-dir ${SOURCE_DIR}[mamba]; \ |
| 80 | + fi |
| 81 | +RUN if [[ "${ENABLE_TRITON_KERNELS}" == "true" ]]; then \ |
| 82 | + pip install --no-cache-dir "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"; \ |
| 83 | + fi |
| 84 | + |
| 85 | +RUN chmod -R g+rwX $WORKDIR /tmp |
| 86 | +RUN mkdir -p /.cache && chmod -R 777 /.cache |
| 87 | + |
| 88 | +# Set Triton environment variables for qLoRA |
| 89 | +ENV TRITON_HOME="/tmp/triton_home" |
| 90 | +ENV TRITON_DUMP_DIR="/tmp/triton_dump_dir" |
| 91 | +ENV TRITON_CACHE_DIR="/tmp/triton_cache_dir" |
| 92 | +ENV TRITON_OVERRIDE_DIR="/tmp/triton_override_dir" |
| 93 | + |
| 94 | +WORKDIR $WORKDIR |
| 95 | + |
| 96 | +CMD ["${SOURCE_DIR}/build/accelerate_launch.py"] |
0 commit comments