Skip to content

Commit 24a942e

Browse files
authored
Merge branch 'main' into dcp-hf-util
2 parents 5239d55 + 34c3a4f commit 24a942e

14 files changed

Lines changed: 145 additions & 24 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
- [Advanced Data Processing](./docs/advanced-data-preprocessing.md#data-config)
99
- [Guidelines on supported data formats](./docs/advanced-data-preprocessing.md#use-cases-supported-via-command-line-argument-training_data_path)
1010
- [Offline data processing](#offline-data-preprocessing)
11-
- [Online data mixing](./docs/online-data-mixing.md)
11+
- [Online data mixing](./docs/advanced-data-preprocessing.md#online-data-mixing-section)
1212
- [Additional Frameworks](#additional-frameworks)
1313
- [Inference](#inference)
1414
- [Validation](#validation)

build/Dockerfile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ ARG ENABLE_MLFLOW=false
2525
ARG ENABLE_FMS_ACCELERATION=true
2626
ARG ENABLE_SCANNER=false
2727
ARG ENABLE_CLEARML=false
28+
ARG ENABLE_RECOMMENDER=true
2829

2930
## Base Layer ##################################################################
3031
FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base
@@ -188,6 +189,9 @@ RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
188189
RUN if [[ "${ENABLE_CLEARML}" == "true" ]]; then \
189190
python -m pip install --user "$(head bdist_name)[clearml]"; \
190191
fi
192+
RUN if [[ "${ENABLE_RECOMMENDER}" == "true" ]]; then \
193+
python -m pip install --user "$(head bdist_name)[tuning-config-recommender]"; \
194+
fi
191195

192196
# Clean up the wheel module. It's only needed by flash-attn install
193197
RUN python -m pip uninstall wheel build -y && \

build/nvcr.Dockerfile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ ARG ENABLE_MLFLOW=false
3434
ARG ENABLE_SCANNER=false
3535
ARG ENABLE_CLEARML=true
3636
ARG ENABLE_TRITON_KERNELS=true
37+
ARG ENABLE_RECOMMENDER=true
3738

3839
# Ensures to always build mamba_ssm from source
3940
ENV PIP_NO_BINARY=mamba-ssm,mamba_ssm
@@ -76,6 +77,9 @@ RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
7677
RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
7778
pip install --no-cache-dir ${SOURCE_DIR}[scanner-dev]; \
7879
fi
80+
# Optionally install the tuning-config-recommender extra from the source tree, matching the other optional installs in this image.
RUN if [[ "${ENABLE_RECOMMENDER}" == "true" ]]; then \
81+
pip install --no-cache-dir ${SOURCE_DIR}[tuning-config-recommender]; \
82+
fi
7983

8084
# cleanup
8185
RUN rm -rf /root/.cache /tmp/* /opt/pytorch

docs/advanced-data-preprocessing.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ Each data handler has:
162162
- `sampling` (optional, float): The sampling ratio (0.0 to 1.0) with which to sample a dataset in case of interleaving.
163163
- `split` (optional, dict[str: float]): Defines how to split the dataset into training and validation sets. Requires both `train` and `validation` keys.
164164
- `data_handlers` (optional, list): A list of data handler configurations which preprocess the dataset.
165-
165+
- `dataset_split_name` (optional, str): Name of the dataset split to load. This is useful for loading Hugging Face datasets whose split names differ from the standard ones (e.g. `train_sft` instead of `train`). If no `dataset_split_name` is provided, `train` is used.
166+
- `shuffle` (optional, bool): Whether the dataset should be shuffled when it is split into train and validation sets. Defaults to `True`. Use caution with this field: only set it to `False` when the dataset is already shuffled.
166167

167168
We provide some sample `data_configs` here: [predefined_data_configs](../tests/artifacts/predefined_data_configs/).
168169

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ dependencies = [
3838
"peft>=0.18.0,< 0.19.0",
3939
"datasets>=4.0.0,<5.0.0",
4040
"simpleeval>=0.9.13,<2.0",
41-
"pillow>=11.0.0,<12.0",
41+
"pillow>=12.1.1",
4242
"kernels<=0.9.0",
4343
]
4444

4545
[project.optional-dependencies]
46-
dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0"]
46+
dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0", "hf_transfer>=0.1.9"]
4747
flash-attn = ["flash-attn>=2.8.3"]
4848
aim = ["aim>=3.19.0,<4.0"]
4949
mlflow = ["mlflow"]
@@ -60,6 +60,7 @@ fms-accel-all = [
6060
"fms-acceleration-moe",
6161
"fms-acceleration-odm"
6262
]
63+
tuning-config-recommender=["tuning-config-recommender>=0.1.5"]
6364

6465
[tool.setuptools.packages.find]
6566
exclude = ["tests", "tests.*"]

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,6 @@
8686
DATA_CONFIG_SKIP_LARGE_COLUMNS_HANDLER = os.path.join(
8787
PREDEFINED_DATA_CONFIGS, "skip_large_columns_data_handler_template.yaml"
8888
)
89+
DATA_CONFIG_CUSTOM_SPLIT_NAME = os.path.join(
90+
PREDEFINED_DATA_CONFIGS, "dataset_with_custom_split.yaml"
91+
)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
dataprocessor:
2+
type: default
3+
sampling_stopping_strategy: first_exhausted
4+
seed: 66
5+
datasets:
6+
- name: dataset_split_custom_split_name
7+
split:
8+
train: 0.8
9+
validation: 0.2
10+
sampling: 0.5
11+
dataset_split_name: "train_sft"
12+
data_paths:
13+
- "FILE_PATH"
14+
data_handlers:
15+
- name: tokenize_and_apply_chat_template_with_masking
16+
arguments:
17+
remove_columns: all
18+
batched: false
19+
fn_kwargs:
20+
formatted_text_column_name: "formatted_chat_data"
21+
conversation_column: "messages"
22+
- name: dataset_wo_split_custom_split_name
23+
sampling: 0.5
24+
dataset_split_name: "train_sft"
25+
data_paths:
26+
- "FILE_PATH"
27+
data_handlers:
28+
- name: tokenize_and_apply_chat_template_with_masking
29+
arguments:
30+
remove_columns: all
31+
batched: false
32+
fn_kwargs:
33+
formatted_text_column_name: "formatted_chat_data"
34+
conversation_column: "messages"

tests/artifacts/testdata/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
IMAGE_DATASET = os.path.join(JSONL_DATA_DIR, "image_dataset.jsonl")
8787
EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json")
8888
MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json")
89+
CHAT_DATA_HF_HOSTED_CUSTOM_SPLIT = "rom7/test-OpenHermes-2.5-H4"
8990

9091
# Other constants
9192
CUSTOM_TOKENIZER_TINYLLAMA = os.path.join(

tests/test_sft_trainer.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from tests.artifacts.language_models import MAYKEYE_TINY_LLAMA_CACHED, TINYMIXTRAL_MOE
4343
from tests.artifacts.predefined_data_configs import (
4444
CHAT_TEMPLATE_JINJA,
45+
DATA_CONFIG_CUSTOM_SPLIT_NAME,
4546
DATA_CONFIG_DUPLICATE_COLUMNS,
4647
DATA_CONFIG_INVALID_BASE64_CHAT_TEMPLATE,
4748
DATA_CONFIG_MULTIPLE_DATASETS_ODM_YAML,
@@ -61,6 +62,7 @@
6162
GRANITE_3_1_B_CHAT_TEMPLATE,
6263
)
6364
from tests.artifacts.testdata import (
65+
CHAT_DATA_HF_HOSTED_CUSTOM_SPLIT,
6466
CHAT_DATA_MULTI_TURN,
6567
CHAT_DATA_MULTI_TURN_CONVERSATIONS,
6668
CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
@@ -949,16 +951,16 @@ def test_run_causallm_lora_add_special_tokens():
949951
["lm_head"],
950952
["embed_tokens"],
951953
marks=pytest.mark.skipif(
952-
version.parse(peft.__version__) <= version.parse("0.18.0"),
953-
reason="Not released in PEFT <= 0.18.0",
954+
version.parse(peft.__version__) <= version.parse("0.18.1"),
955+
reason="Not released in PEFT <= 0.18.1",
954956
),
955957
),
956958
pytest.param(
957959
["embed_tokens", "lm_head"],
958960
["embed_tokens"],
959961
marks=pytest.mark.skipif(
960-
version.parse(peft.__version__) <= version.parse("0.18.0"),
961-
reason="Not released in PEFT <= 0.18.0",
962+
version.parse(peft.__version__) <= version.parse("0.18.1"),
963+
reason="Not released in PEFT <= 0.18.1",
962964
),
963965
),
964966
],
@@ -1010,8 +1012,8 @@ def test_run_causallm_lora_tied_weights_in_modules_to_save(modules_to_save, expe
10101012
],
10111013
)
10121014
@pytest.mark.skipif(
1013-
version.parse(peft.__version__) <= version.parse("0.18.0"),
1014-
reason="Not released in PEFT <= 0.18.0",
1015+
version.parse(peft.__version__) <= version.parse("0.18.1"),
1016+
reason="Not released in PEFT <= 0.18.1",
10151017
)
10161018
def test_run_causallm_lora_tied_weights_in_target_modules(target_modules, expected):
10171019
"""Check if a model with tied weights in target_modules is correctly trained"""
@@ -1916,6 +1918,61 @@ def test_run_moe_ft_with_save_model_dir(dataset_path):
19161918
assert os.path.exists(os.path.join(save_model_dir))
19171919

19181920

1921+
@pytest.mark.parametrize(
1922+
"datafiles, dataconfigfile",
1923+
[
1924+
(
1925+
[CHAT_DATA_HF_HOSTED_CUSTOM_SPLIT, CHAT_DATA_HF_HOSTED_CUSTOM_SPLIT],
1926+
DATA_CONFIG_CUSTOM_SPLIT_NAME,
1927+
)
1928+
],
1929+
)
1930+
def test_run_chat_style_ft_using_custom_split_name(datafiles, dataconfigfile):
1931+
"""Check if we can select custom split for a dataset."""
1932+
with tempfile.TemporaryDirectory() as tempdir:
1933+
data_args = copy.deepcopy(DATA_ARGS)
1934+
data_args.training_data_path = None
1935+
data_args.response_template = None
1936+
data_args.dataset_text_field = None
1937+
data_args.chat_template = CHAT_TEMPLATE_JINJA
1938+
1939+
model_args = copy.deepcopy(MODEL_ARGS)
1940+
model_args.model_name_or_path = TINYMIXTRAL_MOE
1941+
model_args.tokenizer_name_or_path = TINYMIXTRAL_MOE
1942+
1943+
train_args = copy.deepcopy(TRAIN_ARGS)
1944+
train_args.output_dir = tempdir
1945+
1946+
with tempfile.NamedTemporaryFile(
1947+
"w", delete=False, suffix=".yaml"
1948+
) as temp_yaml_file:
1949+
with open(dataconfigfile, "r", encoding="utf-8") as f:
1950+
data = yaml.safe_load(f)
1951+
datasets = data["datasets"]
1952+
for i, d in enumerate(datasets):
1953+
d["data_paths"] = [datafiles[i]]
1954+
yaml.dump(data, temp_yaml_file)
1955+
data_args.data_config_path = temp_yaml_file.name
1956+
1957+
sft_trainer.train(model_args, data_args, train_args)
1958+
1959+
# validate the configs
1960+
_validate_training(tempdir)
1961+
checkpoint_path = _get_checkpoint_path(tempdir)
1962+
1963+
# Load the model
1964+
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
1965+
1966+
# Run inference on the text
1967+
output_inference = loaded_model.run(
1968+
'<|user|>\nProvide two rhyming words for the word "love"\n\
1969+
<nopace></s><|assistant|>',
1970+
max_new_tokens=50,
1971+
)
1972+
assert len(output_inference) > 0
1973+
assert 'Provide two rhyming words for the word "love"' in output_inference
1974+
1975+
19191976
############################# Helper functions #############################
19201977
def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
19211978
train_args = copy.deepcopy(training_args)

tox.ini

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ envlist = py, lint, fmt
55
description = run unit tests
66
deps =
77
pytest>=7
8-
.[aim,mlflow,clearml,scanner-dev]
8+
.[mlflow,clearml,scanner-dev]
99
commands =
1010
pytest {posargs:tests}
1111

@@ -56,9 +56,9 @@ commands =
5656

5757
[testenv:accel]
5858
description = run all unit tests, including those requiring GPU support
59-
deps =
59+
deps =
6060
pytest>=7
61-
.[aim,mlflow,clearml,scanner-dev,fms-accel-all]
61+
.[mlflow,clearml,scanner-dev,fms-accel-all]
6262
setenv =
6363
CUDA_VISIBLE_DEVICES=0
6464
commands_pre =
@@ -74,9 +74,9 @@ commands =
7474

7575
[testenv:gpu]
7676
description = run all unit tests, including those requiring GPU support
77-
deps =
77+
deps =
7878
pytest>=7
79-
.[aim,mlflow,clearml,scanner-dev,fms-accel-all]
79+
.[mlflow,clearml,scanner-dev,fms-accel-all]
8080
setenv =
8181
CUDA_VISIBLE_DEVICES=0
8282
commands_pre =

0 commit comments

Comments
 (0)