Skip to content

Commit 567034a

Browse files
authored
Add tokenizer data handler and test case via data config (#487)
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent da8ae5d commit 567034a

6 files changed

Lines changed: 132 additions & 1 deletion

File tree

docs/advanced-data-preprocessing.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ This library currently supports the following [preexisting data handlers](https:
233233
Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.
234234
- `duplicate_columns`:
235235
Duplicates one column of the dataset to another column.
236+
- `tokenize`:
237+
Tokenizes one column of the dataset passed as input `dataset_text_field`.
236238

237239
These handlers could be requested by their same name and users can lookup the function args from [here](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py)
238240

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,6 @@
4949
DATA_CONFIG_RENAME_RETAIN_COLUMNS = os.path.join(
5050
PREDEFINED_DATA_CONFIGS, "rename_retain_columns.yaml"
5151
)
52+
DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER = os.path.join(
53+
PREDEFINED_DATA_CONFIGS, "tokenize_using_handler_and_train.yaml"
54+
)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
dataprocessor:
2+
type: default
3+
datasets:
4+
- name: non_tokenized_dataset
5+
data_paths:
6+
- "FILE_PATH"
7+
data_handlers:
8+
- name: tokenize
9+
arguments:
10+
remove_columns: all
11+
batched: true
12+
fn_kwargs:
13+
dataset_text_field: "output"
14+
truncation: True
15+
max_length: 1024
16+
- name: duplicate_columns
17+
arguments:
18+
remove_columns: all
19+
batched: true
20+
fn_kwargs:
21+
old_column: "input_ids"
22+
new_column: "labels"

tests/data/test_data_handlers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
apply_custom_jinja_template,
3636
combine_sequence,
3737
duplicate_columns,
38+
tokenize,
3839
)
40+
from tuning.data.setup_dataprocessor import is_pretokenized_dataset
3941

4042

4143
def test_apply_custom_formatting_template():
@@ -250,3 +252,26 @@ def test_duplicate_columns_copies_columns():
250252
assert new in first_element
251253
assert old in first_element
252254
assert first_element[new] == first_element[old]
255+
256+
257+
def test_tokenizer_data_handler_tokenizes():
258+
"Ensure tokenizer data handler tokenizes the input properly with proper truncation"
259+
d = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA_JSONL)
260+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
261+
dataset_text_field = "output"
262+
truncation = True
263+
max_length = 10
264+
265+
updated_dataaset = d.map(
266+
tokenize,
267+
fn_kwargs={
268+
"tokenizer": tokenizer,
269+
"dataset_text_field": dataset_text_field,
270+
"truncation": truncation,
271+
"max_length": max_length,
272+
},
273+
)
274+
275+
assert "input_ids" in updated_dataaset["train"][0]
276+
for element in updated_dataaset["train"]:
277+
assert len(element["input_ids"]) <= max_length

tests/test_sft_trainer.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
DATA_CONFIG_MULTITURN_DATA_YAML,
4343
DATA_CONFIG_RENAME_RETAIN_COLUMNS,
4444
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
45+
DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER,
4546
DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT,
4647
DATA_CONFIG_YAML_STREAMING_PRETOKENIZED,
4748
)
@@ -996,6 +997,52 @@ def test_run_training_with_pretokenised_dataset_containing_input_ids():
996997
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
997998

998999

1000+
def test_run_training_with_data_tokenized_using_tokenizer_handler():
1001+
"""Ensure that we can train on non tokenized dataset works by tokenizing using
1002+
tokenizer data handler via data config."""
1003+
with tempfile.TemporaryDirectory() as tempdir:
1004+
1005+
data_args = copy.deepcopy(DATA_ARGS)
1006+
1007+
# set training_data_path and response_template to none
1008+
data_args.response_template = None
1009+
data_args.training_data_path = None
1010+
1011+
dataconfigfile = DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER
1012+
datapath = TWITTER_COMPLAINTS_DATA_JSONL
1013+
1014+
# add data_paths in data_config file
1015+
with tempfile.NamedTemporaryFile(
1016+
"w", delete=False, suffix=".yaml"
1017+
) as temp_yaml_file:
1018+
with open(dataconfigfile, "r", encoding="utf-8") as f:
1019+
data = yaml.safe_load(f)
1020+
datasets = data["datasets"]
1021+
for _, d in enumerate(datasets):
1022+
d["data_paths"] = [datapath]
1023+
yaml.dump(data, temp_yaml_file)
1024+
data_args.data_config_path = temp_yaml_file.name
1025+
1026+
train_args = copy.deepcopy(TRAIN_ARGS)
1027+
train_args.output_dir = tempdir
1028+
1029+
sft_trainer.train(MODEL_ARGS, data_args, train_args)
1030+
1031+
# validate full ft configs
1032+
_validate_training(tempdir)
1033+
checkpoint_path = _get_checkpoint_path(tempdir)
1034+
1035+
# Load the model
1036+
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
1037+
1038+
# Run inference on the text
1039+
output_inference = loaded_model.run(
1040+
"### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
1041+
)
1042+
assert len(output_inference) > 0
1043+
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
1044+
1045+
9991046
@pytest.mark.parametrize(
10001047
"dataset_path",
10011048
[CHAT_DATA_SINGLE_TURN, CHAT_DATA_MULTI_TURN],

tuning/data/data_handlers.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Definition of some predefined data preprocessing functions that we need.
1616

1717
# Standard
18-
from typing import Dict, List
18+
from typing import Dict, List, Union
1919
import copy
2020
import re
2121

@@ -257,6 +257,37 @@ def apply_tokenizer_chat_template(
257257
}
258258

259259

260+
def tokenize(
261+
element: Union[Dict[str, str], Dict[str, List]],
262+
tokenizer: AutoTokenizer,
263+
dataset_text_field: str,
264+
truncation: Union[bool, str] = None,
265+
max_length: int = None,
266+
**kwargs,
267+
):
268+
"""Function (data handler) to tokenize dataset columns.
269+
Expects to be run as a HF Map API function.
270+
Args:
271+
element: the HF Dataset element.
272+
tokenizer: Tokenizer to be used.
273+
dataset_text_field: the dataset field to tokenize
274+
truncation: Truncation strategy to use, refer the link
275+
(https://huggingface.co/docs/transformers/en/pad_truncation)
276+
max_length: Max length to truncate the samples to.
277+
kwargs: Any additional kwargs that need to be passed to the tokenizer can be passed as
278+
kwargs['tokenizer_kwargs']
279+
Returns:
280+
tokenized dataset elemenent field "dataset_text_field"
281+
"""
282+
tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
283+
return tokenizer(
284+
element[dataset_text_field],
285+
truncation=truncation,
286+
max_length=max_length,
287+
**tokenizer_kwargs,
288+
)
289+
290+
260291
def duplicate_columns(
261292
element: Dict[str, str],
262293
old_column: str,
@@ -298,4 +329,5 @@ def duplicate_columns(
298329
"apply_custom_jinja_template": apply_custom_jinja_template,
299330
"apply_tokenizer_chat_template": apply_tokenizer_chat_template,
300331
"duplicate_columns": duplicate_columns,
332+
"tokenize": tokenize,
301333
}

0 commit comments

Comments
 (0)