Skip to content

Commit 49e40e9

Browse files
feat: expose conversation_column in data_args
Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent 37fda79 commit 49e40e9

6 files changed

Lines changed: 97 additions & 0 deletions

File tree

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,27 @@ For the [granite model above](https://huggingface.co/ibm-granite/granite-3.0-8b-
182182
--response_template "<|start_of_role|>assistant<|end_of_role|>"
183183
```
184184

185+
186+
When working with multi-turn datasets, it's often necessary to extract specific fields from the data depending on the format. For example, in many multi-turn datasets, conversations may be stored under a dedicated key (e.g., `conversations`), and you may only need the content of that key for processing.
187+
188+
```
189+
{
190+
"conversations": [
191+
{"content": "You are an AI...", "role": "system"},
192+
{"content": "Look up a word...", "role": "user"},
193+
{"content": "A word that rhymes is 'mist'", "role": "assistant"}
194+
],
195+
"group": "lab_extension",
196+
"dataset": "base/full-extension",
197+
"metadata": "{\"num_turns\": 2}"
198+
}
199+
200+
```
201+
To extract and use the conversations field, pass the following flag when running:
202+
```
203+
--conversation_column "conversations"
204+
```
205+
185206
The code internally uses [`DataCollatorForCompletionOnlyLM`](https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L93) to perform masking of text ensuring model learns only on the `assistant` responses for both single and multi turn chat.
186207

187208
Depending on various scenarios users might need to decide on how to use chat template with their data or which chat template to use for their use case.

tests/artifacts/testdata/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
)
6969
CHAT_DATA_SINGLE_TURN = os.path.join(JSONL_DATA_DIR, "single_turn_chat.jsonl")
7070
CHAT_DATA_MULTI_TURN = os.path.join(JSONL_DATA_DIR, "multi_turn_chat.jsonl")
71+
CHAT_DATA_MULTI_TURN_CONVERSATIONS = os.path.join(
72+
JSONL_DATA_DIR, "multi_turn_chat_conversations.jsonl"
73+
)
7174
CHAT_DATA_MULTI_TURN_GRANITE_3_1B = os.path.join(
7275
JSONL_DATA_DIR, "multi_turn_chat_granite_instruct.jsonl"
7376
)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Using the word \"grace\", come up with a word that rhymes and has the same number of syllables\n<nopace>", "role": "user"}, {"content": "Certainly! Here's a word that rhymes with \"grace\" and has the same number of syllables:\n1\\. Space", "role": "assistant"}, {"content": "Using the word \"grace\", come up with a word that rhymes and has the same number of syllables\n<nopace>", "role": "user"}, {"content": "Certainly! Here's a word that rhymes with \"grace\" and has the same number of syllables:\n1\\. Space", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
2+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Tell me one word that rhymes and has the opposite meaning of \"open\".", "role": "user"}, {"content": "Sure! Here's a word that rhymes with \"open\" and has the opposite meaning:\n1\\. Shut", "role": "assistant"}, {"content": "Tell me one word that rhymes and has the opposite meaning of \"open\".", "role": "user"}, {"content": "Sure! Here's a word that rhymes with \"open\" and has the opposite meaning:\n1\\. Shut", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
3+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "For the word \"dream\", give an example of a word that rhymes with it and its synonym.", "role": "user"}, {"content": "Here's an example for \"dream\" that includes a word that rhymes with it and a synonym:\n1\\. Word that rhymes with \"dream\": \"beam\"\nSynonym: \"ideal\"", "role": "assistant"}, {"content": "For the word \"dream\", give an example of a word that rhymes with it and its synonym.", "role": "user"}, {"content": "Here's an example for \"dream\" that includes a word that rhymes with it and a synonym:\n1\\. Word that rhymes with \"dream\": \"beam\"\nSynonym: \"ideal\"", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
4+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Can you think of a word that rhymes with tree?", "role": "user"}, {"content": "Of course, I can help! Here's one word that rhymes with \"tree\":\n1\\. Three", "role": "assistant"}, {"content": "Can you think of a word that rhymes with tree?", "role": "user"}, {"content": "Of course, I can help! Here's one word that rhymes with \"tree\":\n1\\. Three", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
5+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Look up a word that rhymes with exist", "role": "user"}, {"content": "I found a word that rhymes with \"exist\":\n1\\. Mist", "role": "assistant"}, {"content": "Look up a word that rhymes with exist", "role": "user"}, {"content": "I found a word that rhymes with \"exist\":\n1\\. Mist", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
6+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Provide five words that rhyme with \"green\".", "role": "user"}, {"content": "Here are five words that rhyme with \"green\":\n1\\. Bean\n2\\. Clean\n3\\. Keen\n4\\. Mean\n5\\. Seen", "role": "assistant"}, {"content": "Provide five words that rhyme with \"green\".", "role": "user"}, {"content": "Here are five words that rhyme with \"green\":\n1\\. Bean\n2\\. Clean\n3\\. Keen\n4\\. Mean\n5\\. Seen", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
7+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "What are three words that rhyme with \"peace\"?", "role": "user"}, {"content": "Here are three words that rhyme with \"peace\":\n1\\. Breeze\n2\\. Pieces\n3\\. Increase", "role": "assistant"}, {"content": "What are three words that rhyme with \"peace\"?", "role": "user"}, {"content": "Here are three words that rhyme with \"peace\":\n1\\. Breeze\n2\\. Pieces\n3\\. Increase", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
8+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Can you name a word that rhymes with high and starts with the letter \"t\"?", "role": "user"}, {"content": "Certainly! Here's a word that rhymes with \"high\" and starts with \"t\":\n1\\. Sky", "role": "assistant"}, {"content": "Can you name a word that rhymes with high and starts with the letter \"t\"?", "role": "user"}, {"content": "Certainly! Here's a word that rhymes with \"high\" and starts with \"t\":\n1\\. Sky", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
9+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Identify the word that rhymes with \"love\" and has the same vowel sounds in the same order as the word \"busy\".", "role": "user"}, {"content": "The word that rhymes with \"love\" and has the same vowel sounds in the same order as the word \"busy\" is:\n1\\. \"Shoe\".", "role": "assistant"}, {"content": "Identify the word that rhymes with \"love\" and has the same vowel sounds in the same order as the word \"busy\".", "role": "user"}, {"content": "The word that rhymes with \"love\" and has the same vowel sounds in the same order as the word \"busy\" is:\n1\\. \"Shoe\".", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
10+
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Provide two rhyming words for the word \"love\".\n<no above>", "role": "user"}, {"content": "Here are two rhyming words for \"love\":\n1\\. Dove\n2\\. Glove", "role": "assistant"}, {"content": "Provide two rhyming words for the word \"love\".\n<no above>", "role": "user"}, {"content": "Here are two rhyming words for \"love\":\n1\\. Dove\n2\\. Glove", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}

tests/test_sft_trainer.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
)
5151
from tests.artifacts.testdata import (
5252
CHAT_DATA_MULTI_TURN,
53+
CHAT_DATA_MULTI_TURN_CONVERSATIONS,
5354
CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
5455
CHAT_DATA_SINGLE_TURN,
5556
CUSTOM_TOKENIZER_TINYLLAMA,
@@ -1138,6 +1139,59 @@ def test_run_chat_style_ft(dataset_path):
11381139
assert 'Provide two rhyming words for the word "love"' in output_inference
11391140

11401141

1142+
def test_run_chat_style_ft_dataset_conversation_field():
1143+
"""Check if we can perform an e2e run with chat template and multi turn chat training."""
1144+
with tempfile.TemporaryDirectory() as tempdir:
1145+
1146+
data_args = copy.deepcopy(DATA_ARGS)
1147+
data_args.training_data_path = CHAT_DATA_MULTI_TURN_CONVERSATIONS
1148+
1149+
# sampled chat template from granite3.1 instruct model
1150+
data_args.chat_template = "{%- for message in messages %}\
1151+
{%- if message['role'] == 'system' %}\
1152+
{{- '<|system|>\n' + message['content'] + '\n' }}\
1153+
{%- elif message['role'] == 'user' %}\
1154+
{{- '<|user|>\n' + message['content'] + '\n' }}\
1155+
{%- elif message['role'] == 'assistant' %}\
1156+
{%- if not loop.last %}\
1157+
{{- '<|assistant|>\n' + message['content'] + eos_token + '\n' }}\
1158+
{%- else %}\
1159+
{{- '<|assistant|>\n' + message['content'] + eos_token }}\
1160+
{%- endif %}\
1161+
{%- endif %}\
1162+
{%- if loop.last and add_generation_prompt %}\
1163+
{{- '<|assistant|>\n' }}\
1164+
{%- endif %}\
1165+
{%- endfor %}"
1166+
data_args.response_template = "<|assistant|>"
1167+
data_args.instruction_template = "<|user|>"
1168+
data_args.conversation_column = "conversations"
1169+
1170+
model_args = copy.deepcopy(MODEL_ARGS)
1171+
model_args.tokenizer_name_or_path = CUSTOM_TOKENIZER_TINYLLAMA
1172+
1173+
train_args = copy.deepcopy(TRAIN_ARGS)
1174+
train_args.output_dir = tempdir
1175+
1176+
sft_trainer.train(model_args, data_args, train_args)
1177+
1178+
# validate the configs
1179+
_validate_training(tempdir)
1180+
checkpoint_path = _get_checkpoint_path(tempdir)
1181+
1182+
# Load the model
1183+
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
1184+
1185+
# Run inference on the text
1186+
output_inference = loaded_model.run(
1187+
'<|user|>\nProvide two rhyming words for the word "love"\n\
1188+
<nopace></s><|assistant|>',
1189+
max_new_tokens=50,
1190+
)
1191+
assert len(output_inference) > 0
1192+
assert 'Provide two rhyming words for the word "love"' in output_inference
1193+
1194+
11411195
def test_run_chat_style_add_special_tokens_ft():
11421196
"""Test to check an e2e multi turn chat training by adding special tokens via command line."""
11431197
with tempfile.TemporaryDirectory() as tempdir:

tuning/config/configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ class DataArguments:
7777
or data_formatter_template need to be supplied."
7878
},
7979
)
80+
dataset_conversation_field: str = field(
81+
default=None,
82+
metadata={
83+
"help": "Training dataset text field containing multi-turn chat data. \
84+
Used as key to point multi-turn data field."
85+
},
86+
)
8087
validation_data_path: str = field(
8188
default=None,
8289
metadata={"help": "Path to the validation data in JSON/JSONL format."},

tuning/data/setup_dataprocessor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ def _get_chat_dataset_handlers(data_args, tokenizer_kwargs):
190190
fn_kwargs = {}
191191
fn_kwargs["dataset_text_field"] = data_args.dataset_text_field
192192
fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs
193+
if data_args.dataset_conversation_field is not None:
194+
fn_kwargs["conversation_column"] = data_args.dataset_conversation_field
193195

194196
kwargs = {"fn_kwargs": fn_kwargs, "batched": False, "remove_columns": "all"}
195197

0 commit comments

Comments
 (0)