Skip to content

Commit 73d6eef

Browse files
committed
Clean up datahandler processing.
Rename data handlers to match HF API names. Remove rename/retain features and make them data handlers under the new framework. Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent dc77c63 commit 73d6eef

12 files changed

Lines changed: 368 additions & 214 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ coverage*.xml
88
dist
99
htmlcov
1010
test
11+
error.log
1112

1213
# IDEs
1314
.vscode/

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@
6464
DATA_CONFIG_DUPLICATE_COLUMNS = os.path.join(
6565
PREDEFINED_DATA_CONFIGS, "duplicate_columns.yaml"
6666
)
67-
DATA_CONFIG_RENAME_RETAIN_COLUMNS = os.path.join(
68-
PREDEFINED_DATA_CONFIGS, "rename_retain_columns.yaml"
67+
DATA_CONFIG_RENAME_SELECT_COLUMNS = os.path.join(
68+
PREDEFINED_DATA_CONFIGS, "rename_select_columns.yaml"
6969
)
7070
DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER = os.path.join(
7171
PREDEFINED_DATA_CONFIGS, "tokenize_using_handler_and_train.yaml"
7272
)
73-
DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER = os.path.join(
74-
PREDEFINED_DATA_CONFIGS, "skip_large_text_data_handler_template.yaml"
73+
DATA_CONFIG_SKIP_LARGE_COLUMNS_HANDLER = os.path.join(
74+
PREDEFINED_DATA_CONFIGS, "skip_large_columns_data_handler_template.yaml"
7575
)

tests/artifacts/predefined_data_configs/rename_retain_columns.yaml renamed to tests/artifacts/predefined_data_configs/rename_select_columns.yaml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@ dataprocessor:
22
type: default
33
datasets:
44
- name: text_dataset_input_output_masking
5-
rename_columns:
6-
"input" : "instruction"
7-
"output" : "response"
8-
retain_columns:
9-
- "instruction"
10-
- "response"
115
data_paths:
126
- "FILE_PATH"
137
data_handlers:
8+
- name: rename_columns
9+
arguments:
10+
column_mapping:
11+
"input" : "instruction"
12+
"output" : "response"
13+
- name: select_columns
14+
arguments:
15+
column_names:
16+
- "instruction"
17+
- "response"
1418
- name: tokenize_and_apply_input_masking
1519
arguments:
1620
remove_columns: all

tests/artifacts/predefined_data_configs/skip_large_text_data_handler_template.yaml renamed to tests/artifacts/predefined_data_configs/skip_large_columns_data_handler_template.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
dataprocessor:
22
type: default
33
datasets:
4-
- name: pre_tokenized
4+
- name: non_tokenized
55
data_paths:
66
- "FILE_PATH"
77
data_handlers:
@@ -17,7 +17,7 @@ datasets:
1717
fn_kwargs:
1818
old_column: "input_ids"
1919
new_column: "labels"
20-
- name: skip_large_text
20+
- name: skip_large_columns
2121
arguments:
2222
fn_kwargs:
2323
column_name: "input_ids"

tests/data/test_data_handlers.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,9 @@
3535
apply_custom_jinja_template,
3636
combine_sequence,
3737
duplicate_columns,
38-
skip_large_text,
38+
skip_large_columns,
3939
tokenize,
4040
)
41-
from tuning.data.setup_dataprocessor import is_pretokenized_dataset
4241

4342

4443
def test_apply_custom_formatting_template():
@@ -287,19 +286,19 @@ def test_tokenizer_data_handler_tokenizes():
287286
("not_existing", "not_existing"),
288287
],
289288
)
290-
def test_skip_large_text_handler_throws_error_on_bad_args(column_name, max_length):
291-
"Ensure that skip large text handler throws error on bad arguments"
289+
def test_skip_large_columns_handler_throws_error_on_bad_args(column_name, max_length):
290+
"Ensure that skip large columns handler throws error on bad arguments"
292291
d = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA_JSONL)
293292
fn_kwargs = {}
294293
fn_kwargs["column_name"] = column_name
295294
fn_kwargs["max_length"] = max_length
296295

297296
with pytest.raises(ValueError):
298-
filtered = d.filter(skip_large_text, fn_kwargs=fn_kwargs)
297+
filtered = d.filter(skip_large_columns, fn_kwargs=fn_kwargs)
299298

300299

301-
def test_skip_large_text_handler():
302-
"Ensure that skip large text handler skips dataset as intended"
300+
def test_skip_large_columns_handler():
301+
"Ensure that skip large columns handler skips dataset as intended"
303302

304303
def test_dataset_generator():
305304
for i in range(0, 100):
@@ -308,7 +307,7 @@ def test_dataset_generator():
308307
d = Dataset.from_generator(test_dataset_generator)
309308
fn_kwargs = {}
310309
fn_kwargs["column_name"] = "input"
311-
fn_kwargs["max_length"] = 61
310+
fn_kwargs["max_length"] = 60
312311

313-
filtered = d.filter(skip_large_text, fn_kwargs=fn_kwargs)
312+
filtered = d.filter(skip_large_columns, fn_kwargs=fn_kwargs)
314313
assert len(filtered) == 60

tests/data/test_data_preprocessing.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
3838
DATA_CONFIG_MULTITURN_DATA_YAML,
3939
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
40-
DATA_CONFIG_RENAME_RETAIN_COLUMNS,
40+
DATA_CONFIG_RENAME_SELECT_COLUMNS,
4141
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
4242
DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT,
4343
DATA_CONFIG_YAML_STREAMING_PRETOKENIZED,
@@ -70,7 +70,11 @@
7070
from tuning.config import configs
7171
from tuning.config.acceleration_configs import AttentionAndDistributedPackingConfig
7272
from tuning.data.collators import VisionDataCollator
73-
from tuning.data.data_config import DataPreProcessorConfig, DataSetConfig
73+
from tuning.data.data_config import (
74+
DataHandlerConfig,
75+
DataPreProcessorConfig,
76+
DataSetConfig,
77+
)
7478
from tuning.data.data_preprocessing_utils import get_data_collator
7579
from tuning.data.data_processors import DataPreProcessor, get_datapreprocessor
7680
from tuning.data.setup_dataprocessor import (
@@ -1674,33 +1678,33 @@ def test_process_dataset_configs_with_sampling_error(
16741678

16751679

16761680
@pytest.mark.parametrize(
1677-
"datafile, rename, retain, final, datasetconfigname",
1681+
"datafile, rename, select, final, datasetconfigname",
16781682
[
16791683
(
16801684
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
16811685
{"input": "instruction", "output": "response"},
16821686
None,
16831687
["ID", "Label", "instruction", "response"],
1684-
DATA_CONFIG_RENAME_RETAIN_COLUMNS,
1688+
DATA_CONFIG_RENAME_SELECT_COLUMNS,
16851689
),
16861690
(
16871691
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
16881692
None,
16891693
["ID", "input", "output"],
16901694
["ID", "input", "output"],
1691-
DATA_CONFIG_RENAME_RETAIN_COLUMNS,
1695+
DATA_CONFIG_RENAME_SELECT_COLUMNS,
16921696
),
16931697
(
16941698
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
16951699
{"input": "instruction", "output": "response"},
16961700
["Label", "instruction", "response"],
16971701
["Label", "instruction", "response"],
1698-
DATA_CONFIG_RENAME_RETAIN_COLUMNS,
1702+
DATA_CONFIG_RENAME_SELECT_COLUMNS,
16991703
),
17001704
],
17011705
)
1702-
def test_rename_and_retain_dataset_columns(
1703-
datafile, rename, retain, final, datasetconfigname
1706+
def test_rename_and_select_dataset_columns(
1707+
datafile, rename, select, final, datasetconfigname
17041708
):
17051709
"""Test process_dataset_configs for expected output."""
17061710
dataprocessor_config = DataPreProcessorConfig()
@@ -1709,12 +1713,23 @@ def test_rename_and_retain_dataset_columns(
17091713
processor_config=dataprocessor_config,
17101714
tokenizer=tokenizer,
17111715
)
1716+
1717+
handlers = []
1718+
if rename:
1719+
handlers.append(
1720+
DataHandlerConfig(
1721+
name="rename_columns", arguments={"column_mapping": rename}
1722+
)
1723+
)
1724+
if select:
1725+
handlers.append(
1726+
DataHandlerConfig(name="select_columns", arguments={"column_names": select})
1727+
)
1728+
data_paths = [datafile]
1729+
17121730
datasetconfig = [
17131731
DataSetConfig(
1714-
name=datasetconfigname,
1715-
data_paths=[datafile],
1716-
rename_columns=rename,
1717-
retain_columns=retain,
1732+
name=datasetconfigname, data_paths=data_paths, data_handlers=handlers
17181733
)
17191734
]
17201735
train_dataset = processor.process_dataset_configs(dataset_configs=datasetconfig)

tests/test_sft_trainer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER,
4444
DATA_CONFIG_MULTITURN_DATA_YAML,
4545
DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML,
46-
DATA_CONFIG_RENAME_RETAIN_COLUMNS,
47-
DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER,
46+
DATA_CONFIG_RENAME_SELECT_COLUMNS,
47+
DATA_CONFIG_SKIP_LARGE_COLUMNS_HANDLER,
4848
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
4949
DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER,
5050
DATA_CONFIG_VALID_BASE64_CHAT_TEMPLATE,
@@ -925,7 +925,7 @@ def test_run_causallm_ft_and_inference_streaming(datasetconfigname, datafiles):
925925
),
926926
(
927927
[TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON],
928-
DATA_CONFIG_RENAME_RETAIN_COLUMNS,
928+
DATA_CONFIG_RENAME_SELECT_COLUMNS,
929929
),
930930
],
931931
)
@@ -1064,8 +1064,8 @@ def test_run_training_with_data_tokenized_using_tokenizer_handler():
10641064
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
10651065

10661066

1067-
def test_run_training_with_skip_large_text_handler():
1068-
"""Ensure that we can train succesfully after using skip large text handler."""
1067+
def test_run_training_with_skip_large_column_handler():
1068+
"""Ensure that we can train succesfully after using skip large column handler."""
10691069
with tempfile.TemporaryDirectory() as tempdir:
10701070

10711071
data_args = copy.deepcopy(DATA_ARGS)
@@ -1074,8 +1074,8 @@ def test_run_training_with_skip_large_text_handler():
10741074
data_args.response_template = None
10751075
data_args.training_data_path = None
10761076

1077-
dataconfigfile = DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER
1078-
datapath = TWITTER_COMPLAINTS_TOKENIZED_JSON
1077+
dataconfigfile = DATA_CONFIG_SKIP_LARGE_COLUMNS_HANDLER
1078+
datapath = TWITTER_COMPLAINTS_DATA_JSONL
10791079

10801080
# add data_paths in data_config file
10811081
with tempfile.NamedTemporaryFile(

tuning/config/configs.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ class DataArguments:
7272
dataset_text_field: str = field(
7373
default=None,
7474
metadata={
75-
"help": "Training dataset text field containing single sequence. \
75+
"help": "[DEPRECATED] "
76+
"Use text_column_name to specify this argument going forward\n"\
77+
"Training dataset text field containing single sequence. \
7678
Either the dataset_text_field \
7779
or data_formatter_template need to be supplied. \
7880
For running vision language model tuning pass the column name for text data."
@@ -85,6 +87,14 @@ class DataArguments:
8587
Used as key to point multi-turn data field."
8688
},
8789
)
90+
text_column_name : str = field(
91+
default=None,
92+
metadata={
93+
"help": "Training dataset text column name containing single sequence. \
94+
Either the text_column_name \
95+
or data_formatter_template need to be supplied."
96+
},
97+
)
8898
validation_data_path: str = field(
8999
default=None,
90100
metadata={"help": "Path to the validation data in JSON/JSONL format."},
@@ -157,6 +167,9 @@ def unescape(s):
157167
self.response_template = unescape(self.response_template)
158168
self.instruction_template = unescape(self.instruction_template)
159169

170+
# Initialise deprecated field
171+
if self.dataset_text_field:
172+
self.text_column_name = self.dataset_text_field
160173

161174
@dataclass
162175
class TrainingArguments(transformers.TrainingArguments):

tuning/data/data_config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ class DataSetConfig:
3737
data_paths: List[str]
3838
builder: Optional[str] = None # Referring to Hugging Face dataset builder
3939
sampling: Optional[float] = None
40-
rename_columns: Optional[Dict] = None
41-
retain_columns: Optional[List] = None
4240
data_handlers: Optional[List[DataHandlerConfig]] = None
4341

4442

0 commit comments

Comments
 (0)