Skip to content

Commit a2a9824

Browse files
committed
fix few errors
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent e15e842 commit a2a9824

3 files changed

Lines changed: 15 additions & 52 deletions

File tree

.github/workflows/pr-command.yaml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,6 @@ jobs:
6464
username: ${{ secrets.QUAY_USERNAME }}
6565
password: ${{ secrets.QUAY_ROBOT_TOKEN }}
6666

67-
- name: Run basic sanity checks/tests on the new image before pushing
68-
run: |
69-
echo 'check if accelerate is installed and in the PATH'
70-
IMAGE_NAME=${{ vars.QUAY_REPOSITORY }}fms-hf-tuning:main-nvcr-latest
71-
docker run --rm -it --entrypoint which "$IMAGE_NAME" accelerate
72-
echo 'checks done'
73-
7467
- name: Push docker image
7568
run: |
7669
docker push ${{ vars.QUAY_REPOSITORY }}fms-hf-tuning:pr-${{ github.event.issue.number }}-nvcr

build/nvcr.Dockerfile

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
## Global Args #################################################################
1616
## If the nvcr container is updated, be sure to check the torch and python
1717
## installation versions inside the Dockerfile before pushing changes.
18-
ARG NVCR_IMAGE_VERSION=25.02-py3
18+
ARG NVCR_IMAGE_VERSION=25.10-py3
1919

2020
# This is based on what is inside the NVCR image already
2121
ARG PYTHON_VERSION=3.12
@@ -28,58 +28,26 @@ ARG USER_UID=0
2828
ARG WORKDIR=/app
2929
ARG SOURCE_DIR=${WORKDIR}/fms-hf-tuning
3030

31-
ARG ENABLE_FMS_ACCELERATION=true
32-
ARG ENABLE_AIM=false
33-
ARG ENABLE_MLFLOW=false
34-
ARG ENABLE_SCANNER=false
35-
ARG ENABLE_CLEARML=true
3631
ARG ENABLE_TRITON_KERNELS=true
37-
ARG ENABLE_RECOMMENDER=true
3832

3933
# Ensure mamba_ssm is always built from source
4034
ENV PIP_NO_BINARY=mamba-ssm,mamba_ssm
4135

42-
# upgrade torch as the base layer contains only torch 2.7
43-
RUN python -m pip install --upgrade pip && \
44-
pip install --upgrade setuptools && \
45-
pip install --upgrade --force-reinstall torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu128
36+
# install triton kernels
37+
RUN pip install --no-cache-dir "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
4638

4739
# Install main package + flash attention
4840
COPY . ${SOURCE_DIR}
4941
RUN cd ${SOURCE_DIR}
5042

51-
RUN pip install --no-cache-dir ${SOURCE_DIR} && \
52-
pip install --no-cache-dir --no-build-isolation ${SOURCE_DIR}[flash-attn] && \
53-
pip install --no-cache-dir --no-build-isolation ${SOURCE_DIR}[mamba]
54-
55-
# Optional extras
56-
RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
57-
pip install --no-cache-dir ${SOURCE_DIR}[fms-accel] && \
58-
python -m fms_acceleration.cli install fms_acceleration_peft && \
59-
python -m fms_acceleration.cli install fms_acceleration_foak && \
60-
python -m fms_acceleration.cli install fms_acceleration_aadp && \
61-
python -m fms_acceleration.cli install fms_acceleration_moe && \
62-
python -m fms_acceleration.cli install fms_acceleration_odm; \
63-
fi
64-
65-
RUN if [[ "${ENABLE_TRITON_KERNELS}" == "true" ]]; then \
66-
pip install --no-cache-dir "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"; \
67-
fi
68-
RUN if [[ "${ENABLE_CLEARML}" == "true" ]]; then \
69-
pip install --no-cache-dir ${SOURCE_DIR}[clearml]; \
70-
fi
71-
RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
72-
pip install --no-cache-dir ${SOURCE_DIR}[aim]; \
73-
fi
74-
RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
75-
pip install --no-cache-dir ${SOURCE_DIR}[mlflow]; \
76-
fi
77-
RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
78-
pip install --no-cache-dir ${SOURCE_DIR}[scanner-dev]; \
79-
fi
80-
RUN if [[ "${ENABLE_RECOMMENDER}" == "true" ]]; then \
81-
pip install --no-cache-dir ${SOURCE_DIR}[tuning_config_recommender]; \
82-
fi
43+
RUN pip install --no-cache-dir ${SOURCE_DIR}[flash-attn,mamba,fms-accel,clearml,tuning_config_recommender]
44+
45+
# install fms-accel packages
46+
RUN python -m fms_acceleration.cli install fms_acceleration_peft && \
47+
python -m fms_acceleration.cli install fms_acceleration_foak && \
48+
python -m fms_acceleration.cli install fms_acceleration_aadp && \
49+
python -m fms_acceleration.cli install fms_acceleration_moe && \
50+
python -m fms_acceleration.cli install fms_acceleration_odm
8351

8452
# cleanup build artifacts and caches
8553
RUN rm -rf /root/.cache /tmp/pip-* \
@@ -141,4 +109,6 @@ ENV TRITON_DUMP_DIR="/tmp/triton_dump_dir"
141109
ENV TRITON_CACHE_DIR="/tmp/triton_cache_dir"
142110
ENV TRITON_OVERRIDE_DIR="/tmp/triton_override_dir"
143111

144-
CMD ["python", "/app/accelerate_launch.py"]
112+
RUN pip install -U accelerate
113+
114+
CMD ["python", "/app/accelerate_launch.py"]

tests/test_sft_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2192,7 +2192,7 @@ def test_empty_data():
21922192
data_args = copy.deepcopy(DATA_ARGS)
21932193
data_args.training_data_path = EMPTY_DATA
21942194

2195-
with pytest.raises((DatasetGenerationError, ValueError)):
2195+
with pytest.raises((DatasetGenerationError, ValueError, StopIteration)):
21962196
sft_trainer.train(
21972197
copy.deepcopy(MODEL_ARGS),
21982198
data_args,

0 commit comments

Comments
 (0)