Skip to content

Commit faa04d9

Browse files
authored
Merge branch 'main' into compile-cache
2 parents e84f209 + cdb5ff0 commit faa04d9

5 files changed

Lines changed: 240 additions & 17 deletions

File tree

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML = os.path.join(
4141
PREDEFINED_DATA_CONFIGS, "multi_turn_data_with_chat_template_granite_3_1B.yaml"
4242
)
43+
DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER = os.path.join(
44+
PREDEFINED_DATA_CONFIGS, "mt_data_granite_3_1B_tokenize_and_mask_handler.yaml"
45+
)
4346
DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT = os.path.join(
4447
PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking_streaming.yaml"
4548
)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
dataprocessor:
2+
type: default
3+
chat_template: |
4+
{%- if messages[0]['role'] == 'system' %}
5+
{%- set system_message = messages[0]['content'] %}
6+
{%- set loop_messages = messages[1:] %}
7+
{%- else %}
8+
{%- set system_message = "Knowledge Cutoff Date: April 2024.\nToday's Date: " + strftime_now('%B %d, %Y') + ".\nYou are Granite, developed by IBM." %}
9+
{%- if tools and documents %}
10+
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.\n\nWrite the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
11+
{%- elif tools %}
12+
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
13+
{%- elif documents %}
14+
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
15+
{%- else %}
16+
{%- set system_message = system_message + " You are a helpful AI assistant." %}
17+
{%- endif %}
18+
{%- if 'citations' in controls and documents %}
19+
{%- set system_message = system_message + '\n\nIn your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
20+
{%- endif %}
21+
{%- if 'hallucinations' in controls and documents %}
22+
{%- set system_message = system_message + '\n\nFinally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %}
23+
{%- endif %}
24+
{%- set loop_messages = messages %}
25+
{%- endif %}
26+
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>\n' }}
27+
{%- if tools %}
28+
{{- '<|start_of_role|>tools<|end_of_role|>' }}
29+
{{- tools | tojson(indent=4) }}
30+
{{- '<|end_of_text|>\n' }}
31+
{%- endif %}
32+
{%- if documents %}
33+
{{- '<|start_of_role|>documents<|end_of_role|>' }}
34+
{%- for document in documents %}
35+
{{- 'Document ' + loop.index0 | string + '\n' }}
36+
{{- document['text'] }}
37+
{%- if not loop.last %}
38+
{{- '\n\n'}}
39+
{%- endif%}
40+
{%- endfor %}
41+
{{- '<|end_of_text|>\n' }}
42+
{%- endif %}
43+
{%- for message in loop_messages %}
44+
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}
45+
{%- if loop.last and add_generation_prompt %}
46+
{{- '<|start_of_role|>assistant' }}
47+
{%- if controls %}
48+
{{- ' ' + controls | tojson()}}
49+
{%- endif %}
50+
{{- '<|end_of_role|>' }}
51+
{%- endif %}
52+
{%- endfor %}
53+
datasets:
54+
- name: dataset_1
55+
data_paths:
56+
- "FILE_PATH"
57+
data_handlers:
58+
- name: tokenize_and_apply_chat_template_with_masking
59+
arguments:
60+
remove_columns: all
61+
fn_kwargs:
62+
max_seq_length: 1024
63+
conversation_column: "messages"
64+
- name: dataset_2
65+
data_paths:
66+
- "FILE_PATH"
67+
data_handlers:
68+
- name: tokenize_and_apply_chat_template_with_masking
69+
arguments:
70+
remove_columns: all
71+
fn_kwargs:
72+
max_seq_length: 1024
73+
conversation_column: "messages"
74+
- name: dataset_3
75+
data_paths:
76+
- "FILE_PATH"
77+
data_handlers:
78+
- name: tokenize_and_apply_chat_template_with_masking
79+
arguments:
80+
remove_columns: all
81+
fn_kwargs:
82+
max_seq_length: 1024
83+
conversation_column: "messages"

tests/test_sft_trainer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from tests.artifacts.predefined_data_configs import (
4040
DATA_CONFIG_DUPLICATE_COLUMNS,
4141
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
42+
DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER,
4243
DATA_CONFIG_MULTITURN_DATA_YAML,
4344
DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML,
4445
DATA_CONFIG_RENAME_RETAIN_COLUMNS,
@@ -1258,6 +1259,14 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
12581259
],
12591260
DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML,
12601261
),
1262+
(
1263+
[
1264+
CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
1265+
CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
1266+
CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
1267+
],
1268+
DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER,
1269+
),
12611270
],
12621271
)
12631272
def test_run_chat_style_ft_using_dataconfig_for_chat_template(
@@ -1768,7 +1777,7 @@ def test_pretokenized_dataset_bad_args(dataset_text_field, response_template):
17681777
data_args = copy.deepcopy(DATA_ARGS)
17691778
data_args.dataset_text_field = dataset_text_field
17701779
data_args.response_template = response_template
1771-
data_args.training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL
1780+
data_args.training_data_path = TWITTER_COMPLAINTS_TOKENIZED_JSON
17721781
# We should raise an error since we should not have a dataset text
17731782
# field or a response template if we have pretokenized data
17741783
with pytest.raises(ValueError):

tuning/data/data_handlers.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError
2525
from jinja2.sandbox import SandboxedEnvironment, SecurityError
2626
from transformers import AutoTokenizer
27+
import torch
2728

2829
# Local
2930
from tuning.utils.config_utils import process_jinja_placeholders
@@ -381,6 +382,128 @@ def skip_large_text(element: Dict[str, str], column_name: str, max_length: int):
381382
return len(element[column_name]) < max_length
382383

383384

385+
def _render_chat_template_ids(
    tokenizer, conversation, max_seq_length, add_generation_prompt
):
    """Tokenize `conversation` through the tokenizer's chat template.

    Single place holding the tokenization arguments so that the full sample
    and every prefix used to locate message boundaries are rendered in
    exactly the same way (same truncation, no padding).

    Args:
        tokenizer: Object exposing `apply_chat_template`.
        conversation: List of chat messages to render.
        max_seq_length: Forwarded as `max_length` for truncation.
        add_generation_prompt: Whether the assistant generation prefix
            is appended after the last message.
    Returns:
        The result of `apply_chat_template` with `return_tensors="pt"`;
        callers index it as a 2-D tensor (`.shape[1]` is the token count).
    """
    return tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        padding=False,
        return_tensors="pt",
        truncation=True,
        max_length=max_seq_length,
        add_generation_prompt=add_generation_prompt,
    )


def tokenize_and_apply_chat_template_with_masking(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    max_seq_length: int = None,
    conversation_column: str = "messages",
    **kwargs,
):
    """Apply the chat template to a dataset element, tokenize it and mask
    every non-assistant token so the model is trained only on completions.

    Assumes the dataset is modelled according to a ChatML style format
    like,
        {"messages": [{"role": "user", "content": "blah"}, ...]}

    `max_seq_length` should be passed so that extra large samples are
    truncated. If over-long samples should be skipped instead of truncated,
    run a filter data handler before this one.

    Expects to be run as a HF Map function.
    If used with `remove_columns=all` the dataset can be used
    directly to train.

    Args:
        element: the HF Dataset sample.
        tokenizer: Tokenizer to be used; its chat template is applied.
        max_seq_length: Max seq length of the tokens allowed; forwarded to
            the tokenizer as `max_length` with truncation enabled.
        conversation_column: Name of the column which contains conversations
            Typically `messages`
        kwargs: Unused by this function.
    Returns:
        Dict with flattened `input_ids`, `labels` and `attention_mask`
        where the labels of non-assistant tokens are set to -100.
    Raises:
        ValueError: if the conversation column is `None` or empty.
    """

    # This function is taken from OpenInstruct
    # https://github.com/allenai/open-instruct/blob/\
    # d208aa371976a09152f61991951e981573e7582f/open_instruct/\
    # dataset_transformation.py#L632

    messages = element[conversation_column]

    # Treat a missing (None) conversation like an empty one so the caller
    # gets the descriptive ValueError rather than a TypeError from len().
    if messages is None or len(messages) == 0:
        raise ValueError(
            f"Contents of the column {conversation_column} must not be empty."
        )

    # Tokenize the whole sample
    input_ids = _render_chat_template_ids(
        tokenizer, messages, max_seq_length, add_generation_prompt=False
    )

    # clone labels from input ids
    labels = input_ids.clone()

    # mask the non-assistant part for avoiding loss
    for message_idx, message in enumerate(messages):
        if message["role"] == "assistant":
            continue

        # Start of this message == number of tokens rendered for everything
        # that precedes it (nothing precedes the first message).
        if message_idx == 0:
            message_start_idx = 0
        else:
            message_start_idx = _render_chat_template_ids(
                tokenizer,
                messages[:message_idx],
                max_seq_length,
                add_generation_prompt=False,
            ).shape[1]

        # When the next message is an assistant turn, render with
        # `add_generation_prompt=True` so the assistant generation prefix
        # (e.g., `<|assistant|>`) is also excluded from the loss. For the
        # last message, or one not followed by an assistant turn, no prefix
        # is added.
        followed_by_assistant = (
            message_idx < len(messages) - 1
            and messages[message_idx + 1]["role"] == "assistant"
        )
        message_end_idx = _render_chat_template_ids(
            tokenizer,
            messages[: message_idx + 1],
            max_seq_length,
            add_generation_prompt=followed_by_assistant,
        ).shape[1]

        # set the label to -100 for the non-assistant part
        labels[:, message_start_idx:message_end_idx] = -100

        # Everything past max_seq_length was truncated away, so later
        # messages have no tokens left to mask.
        if max_seq_length and message_end_idx >= max_seq_length:
            break

    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids.flatten(),
        "labels": labels.flatten(),
        "attention_mask": attention_mask.flatten(),
    }
505+
506+
384507
AVAILABLE_DATA_HANDLERS = {
385508
"tokenize_and_apply_input_masking": DataHandler(
386509
op=tokenize_and_apply_input_masking,
@@ -407,6 +530,11 @@ def skip_large_text(element: Dict[str, str], column_name: str, max_length: int):
407530
handler_type=DataHandlerType.MAP,
408531
allows_batching=False,
409532
),
533+
"tokenize_and_apply_chat_template_with_masking": DataHandler(
534+
op=tokenize_and_apply_chat_template_with_masking,
535+
handler_type=DataHandlerType.MAP,
536+
allows_batching=False,
537+
),
410538
"duplicate_columns": DataHandler(
411539
op=duplicate_columns,
412540
handler_type=DataHandlerType.MAP,

tuning/data/data_preprocessing_utils.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,22 @@ def get_data_collator(
6565
# packing for non tokenized dataset doesn't require a collator with SFTrainer.
6666
return None
6767

68+
if is_padding_free:
69+
# when packing is false but padding_free is used and
70+
# no response template is used then its a pretrained scenario.
71+
# Current plugin in fms-acceleration is compatible with
72+
# `DataCollatorForSeq2Seq` collator hence we use this.
73+
return DataCollatorForSeq2Seq(
74+
tokenizer=tokenizer, padding=False, max_length=max_seq_length
75+
)
76+
77+
if is_traindata_tokenized:
78+
# Note that this automatically pads labels with -100
79+
# TODO check if this is sufficient for preprocessed
80+
return DataCollatorForSeq2Seq(
81+
tokenizer=tokenizer, padding=True, max_length=max_seq_length
82+
)
83+
6884
# TODO: near term - how response template ids are parsed out needs to be cleaned.
6985
# The [2:] here applies if response template has \n prefix, it is needed to strip \n,
7086
# otherwise template is not found. We will create issue to clean this out after we discuss
@@ -88,22 +104,6 @@ def get_data_collator(
88104
ignore_index=configs.IGNORE_INDEX,
89105
)
90106

91-
if is_padding_free:
92-
# when packing is false but padding_free is used and
93-
# no response template is used then its a pretrained scenario.
94-
# Current plugin in fms-acceleration is compatible with
95-
# `DataCollatorForSeq2Seq` collator hence we use this.
96-
return DataCollatorForSeq2Seq(
97-
tokenizer=tokenizer, padding=False, max_length=max_seq_length
98-
)
99-
100-
if is_traindata_tokenized:
101-
# Note that this automatically pads labels with -100
102-
# TODO check if this is sufficient for preprocessed
103-
return DataCollatorForSeq2Seq(
104-
tokenizer=tokenizer, padding=True, max_length=max_seq_length
105-
)
106-
107107
raise ValueError(
108108
"Could not pick a data collator. Please refer to supported data formats"
109109
)

0 commit comments

Comments
 (0)