Skip to content

Commit 0f67235

Browse files
committed
remove extra data formatting handler
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent b09a132 commit 0f67235

7 files changed

Lines changed: 7 additions & 138 deletions

File tree

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
### Constants used for data
2121
PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__))
2222
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
23-
PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
24-
)
25-
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML = os.path.join(
2623
PREDEFINED_DATA_CONFIGS, "apply_custom_jinja_template.yaml"
2724
)
2825
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join(

tests/artifacts/predefined_data_configs/apply_custom_template_streaming.yaml renamed to tests/artifacts/predefined_data_configs/apply_custom_jinja_template_streaming.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ dataprocessor:
22
type: default
33
streaming: true
44
datasets:
5-
- name: apply_custom_data_template
5+
- name: apply_custom_jinja_template
66
data_paths:
77
- "FILE_PATH"
88
data_handlers:
9-
- name: apply_custom_data_formatting_template
9+
- name: apply_custom_jinja_template
1010
arguments:
1111
remove_columns: all
1212
batched: false

tests/artifacts/predefined_data_configs/apply_custom_template.yaml

Lines changed: 0 additions & 15 deletions
This file was deleted.

tests/data/test_data_handlers.py

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
# Local
3333
from tuning.data.data_handlers import (
34-
apply_custom_data_formatting_template,
3534
apply_custom_jinja_template,
3635
combine_sequence,
3736
duplicate_columns,
@@ -40,34 +39,6 @@
4039
)
4140

4241

43-
def test_apply_custom_formatting_template():
44-
"""Tests custom formatting data handler returns correct formatted response"""
45-
json_dataset = datasets.load_dataset(
46-
"json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
47-
)
48-
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
49-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
50-
formatted_dataset_field = "formatted_data_field"
51-
formatted_dataset = json_dataset.map(
52-
apply_custom_data_formatting_template,
53-
fn_kwargs={
54-
"tokenizer": tokenizer,
55-
"formatted_text_column_name": formatted_dataset_field,
56-
"template": template,
57-
},
58-
)
59-
# First response from the data file that is read.
60-
expected_response = (
61-
"### Input: @HMRCcustomers No this is my first job"
62-
+ " \n\n ### Response: no complaint"
63-
+ tokenizer.eos_token
64-
)
65-
66-
# a new column is created in Dataset
67-
assert formatted_dataset_field in formatted_dataset["train"][0]
68-
assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response
69-
70-
7142
def test_apply_custom_formatting_jinja_template():
7243
"""Tests custom formatting data handler with jinja template dataset returns correct formatted response"""
7344
json_dataset = datasets.load_dataset(
@@ -95,7 +66,7 @@ def test_apply_custom_formatting_jinja_template():
9566
assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response
9667

9768

98-
def test_apply_custom_formatting_template_iterable():
69+
def test_apply_custom_formatting_jinja_template_iterable():
9970
"""Tests custom formatting data handler with iterable dataset returns correct formatted response"""
10071
json_dataset = datasets.load_dataset(
10172
"json", data_files=TWITTER_COMPLAINTS_DATA_JSONL, streaming=True
@@ -104,7 +75,7 @@ def test_apply_custom_formatting_template_iterable():
10475
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10576
formatted_dataset_field = "formatted_data_field"
10677
formatted_dataset = json_dataset.map(
107-
apply_custom_data_formatting_template,
78+
apply_custom_jinja_template,
10879
fn_kwargs={
10980
"tokenizer": tokenizer,
11081
"formatted_text_column_name": formatted_dataset_field,
@@ -127,25 +98,6 @@ def test_apply_custom_formatting_template_iterable():
12798
assert first_sample[formatted_dataset_field] == expected_response
12899

129100

130-
def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
131-
"""Tests that the formatting function will throw error if wrong keys are passed to template"""
132-
json_dataset = datasets.load_dataset(
133-
"json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
134-
)
135-
template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
136-
formatted_dataset_field = "formatted_data_field"
137-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
138-
with pytest.raises(KeyError):
139-
json_dataset.map(
140-
apply_custom_data_formatting_template,
141-
fn_kwargs={
142-
"tokenizer": tokenizer,
143-
"formatted_text_column_name": formatted_dataset_field,
144-
"template": template,
145-
},
146-
)
147-
148-
149101
@pytest.mark.parametrize(
150102
"template",
151103
[

tests/data/test_data_preprocessing.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
save_dataset_shards,
3636
)
3737
from tests.artifacts.predefined_data_configs import (
38-
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
3938
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
4039
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
4140
DATA_CONFIG_MULTITURN_DATA_YAML,
@@ -887,10 +886,6 @@ def test_process_dataconfig_file_with_streaming_and_multipack_throws_error(
887886
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
888887
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
889888
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
890-
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
891-
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
892-
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
893-
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
894889
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
895890
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
896891
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
@@ -972,15 +967,10 @@ def test_process_dataconfig_file(data_config_path, data_path):
972967
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON, True),
973968
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON, False),
974969
(
975-
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
970+
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
976971
TWITTER_COMPLAINTS_DATA_JSON,
977972
True,
978973
),
979-
(
980-
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
981-
TWITTER_COMPLAINTS_DATA_JSON,
982-
False,
983-
),
984974
(
985975
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
986976
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,

tuning/data/data_handlers.py

Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
# import copy
2222
import logging
23-
import re
2423

2524
# Third Party
2625
from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError
@@ -203,54 +202,6 @@ def add_tokenizer_eos_token(
203202
return {f"{text_column_name}": element[f"{text_column_name}"] + tokenizer.eos_token}
204203

205204

206-
def apply_custom_data_formatting_template(
207-
element: Dict[str, str],
208-
tokenizer: AutoTokenizer,
209-
formatted_text_column_name: str,
210-
template: str,
211-
add_eos_token: bool = True,
212-
**kwargs,
213-
):
214-
"""Function to format datasets with Alpaca style / other templates.
215-
Expects to be run as a HF Map API function.
216-
Args:
217-
element: the HF Dataset element.
218-
tokenizer: Tokenizer to be used for the EOS token, which will be appended
219-
when formatting the data into a single sequence. Defaults to empty.
220-
formatted_text_column_name: Name of the dataset column where formatted
221-
text is to be saved. If doesn't exist a new
222-
column will be created.
223-
template: Template to format data with. Features of Dataset
224-
should be referred to by {{key}}
225-
add_eos_token: should add tokenizer.eos_token to text or not, defaults to True
226-
Returns:
227-
Formatted Dataset element by formatting dataset with template+tokenizer.EOS_TOKEN
228-
Saves the result to formatted_text_column_name argument.
229-
"""
230-
231-
if add_eos_token:
232-
template += tokenizer.eos_token
233-
234-
def replace_text(match_obj):
235-
captured_groups = match_obj.groups()
236-
if len(captured_groups) != 1:
237-
raise ValueError(
238-
"Unexpectedly captured multiple groups in template formatting"
239-
)
240-
241-
index_object = captured_groups[0]
242-
if index_object not in element:
243-
raise KeyError("Requested template string is not a valid key in dict")
244-
245-
return str(element[index_object])
246-
247-
return {
248-
f"{formatted_text_column_name}": re.sub(
249-
r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template
250-
)
251-
}
252-
253-
254205
def apply_custom_jinja_template(
255206
element: Dict[str, str],
256207
tokenizer: AutoTokenizer,
@@ -259,7 +210,7 @@ def apply_custom_jinja_template(
259210
add_eos_token: bool = True,
260211
**kwargs,
261212
):
262-
"""Function to format datasets with jinja templates.
213+
"""Function to format datasets with Alpaca style / any other jinja templates.
263214
Expects to be run as a HF Map API function.
264215
Args:
265216
element: the HF Dataset element
@@ -670,12 +621,6 @@ def tokenize_and_apply_chat_template_with_masking(
670621
allows_batching=False,
671622
desc="Adding EOS token to text dataset",
672623
),
673-
"apply_custom_data_formatting_template": DataHandler(
674-
op=apply_custom_data_formatting_template,
675-
handler_type=DataHandlerType.MAP,
676-
allows_batching=False,
677-
desc="Formatting dataset with given formatter template",
678-
),
679624
"apply_custom_jinja_template": DataHandler(
680625
op=apply_custom_jinja_template,
681626
handler_type=DataHandlerType.MAP,

tuning/data/setup_dataprocessor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False):
189189
fn_kwargs["formatted_text_column_name"] = data_args.dataset_text_field
190190
fn_kwargs["template"] = data_args.data_formatter_template
191191
handler = DataHandlerConfig(
192-
"apply_custom_data_formatting_template",
192+
"apply_custom_jinja_template",
193193
arguments={"fn_kwargs": fn_kwargs, "batched": False},
194194
)
195195
return [handler], data_args.dataset_text_field

0 commit comments

Comments
 (0)