# compatibility_functions.py
import random
import warnings
import numpy as np
import torch
import datasets
from datasets import IterableDataset, Dataset, Value
from typing import Any, Literal, Optional, Union
from transformers import AutoModel, AutoTokenizer, DataCollatorForLanguageModeling, PreTrainedModel, PreTrainedTokenizer
from dataclasses import dataclass
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
"""
Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
when they do not come from the assistant. This ensures that the loss is only calculated on the completion made by
the assistant.
Args:
response_template (`Union[str, list[int]]`):
The template that indicates the start of the response, typically something like '### Response:\n'. It
can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
differently if it does not have proper context.
instruction_template (`Union[str, list[int]]`):
The template that indicates the start of the human instruction, typically something like '###
Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids.
mlm (`bool`, *optional*, defaults to `False`): Whether to use masked language modeling in the underlying
`DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
for flexibility and backwards-compatibility.
ignore_index (`int`, *optional*, defaults to `-100`):
The label index assigned to tokens that should be ignored in the loss computation.
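Example:

A minimal usage sketch. The checkpoint name, the response template, and the sample text below are
illustrative assumptions, not values required by this class:

```python
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
>>> collator = DataCollatorForCompletionOnlyLM(response_template=" ### Answer:", tokenizer=tokenizer)
>>> examples = [tokenizer("### Question: What is 2+2? ### Answer: 4")["input_ids"]]
>>> batch = collator(examples)
>>> # Labels up to and including the response template are set to `ignore_index` (-100),
>>> # so the loss is computed only on the completion tokens.
```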
"""
def __init__(
self,
response_template: Union[str, list[int]],
instruction_template: Optional[Union[str, list[int]]] = None,
*args,
mlm: bool = False,
ignore_index: int = -100,
padding_free: bool = False,
**kwargs,
):
super().__init__(*args, mlm=mlm, **kwargs)
warnings.warn(
"This class is deprecated and will be removed in version 0.20.0. To train on completion only, please use "
"the parameter `completion_only_loss` of `SFTConfig` instead.",
DeprecationWarning,
)
self.instruction_template = instruction_template
if isinstance(instruction_template, str):
# The user provides a string, must tokenize
self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
else:
# The user already provides the token ids
self.instruction_token_ids = instruction_template
self.response_template = response_template
if isinstance(response_template, str):
# The user provides a string, must tokenize
self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
else:
# The user already provides the token ids
self.response_token_ids = response_template
if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
warnings.warn(
"The pad_token_id and eos_token_id values of this tokenizer are identical. "
"If you are planning for multi-turn training, "
"it can result in the model continuously generating questions and answers without eos token. "
"To avoid this, set the pad_token_id to a different value.",
UserWarning,
)
self.ignore_index = ignore_index
self.padding_free = padding_free
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
batch = super().torch_call(examples)
if self.instruction_template is None:
for i in range(len(examples)):
response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
# `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
if (
self.response_token_ids
== batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist()
):
response_token_ids_start_idx = idx
if response_token_ids_start_idx is None:
warnings.warn(
f"Could not find response key `{self.response_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
"calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index
else:
response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)
# Make pytorch loss function ignore all tokens up through the end of the response key
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
else:
for i in range(len(examples)):
response_token_ids_idxs = []
human_token_ids_idxs = []
for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
# find the indexes of the start of a response.
if (
self.response_token_ids
== batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist()
):
response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
if len(response_token_ids_idxs) == 0:
warnings.warn(
f"Could not find response key `{self.response_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
"calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index
human_token_ids = self.instruction_token_ids
for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
# find the indexes of the start of a human answer.
if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
human_token_ids_idxs.append(human_idx)
if len(human_token_ids_idxs) == 0:
warnings.warn(
f"Could not find instruction key `{self.instruction_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
"calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index
if (
len(human_token_ids_idxs) > 0
and len(response_token_ids_idxs) > 0
and human_token_ids_idxs[0] > response_token_ids_idxs[0]
):
human_token_ids_idxs = [0] + human_token_ids_idxs
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
# Make pytorch loss function ignore all non response tokens
if idx != 0:
batch["labels"][i, start:end] = self.ignore_index
else:
batch["labels"][i, :end] = self.ignore_index
if len(response_token_ids_idxs) < len(human_token_ids_idxs):
batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index
if self.padding_free:
# remove padding, `attention_mask` and add `position_ids`
attn_mask = batch.pop("attention_mask")
batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
batch["labels"][batch["position_ids"] == 0] = self.ignore_index
# Calculate cumulative sequence lengths for queries and keys to prevent graph breaks during further computations.
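# Illustrative toy example (assumed values, not produced by this module): for two packed
# sequences of lengths 3 and 2, position_ids is [[0, 1, 2, 0, 1]], so cu_seq_lens_q/k become
# [[0, 3, 5]] and max_length_q/k become [3].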
flattened_position_ids = batch["position_ids"].flatten()
indices_q = torch.arange(
flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32
)
batch["cu_seq_lens_q"] = torch.cat(
(
indices_q[flattened_position_ids == 0],
torch.tensor(
flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32
),
)
).unsqueeze(0)
batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]
# Determine maximum sequence lengths to prevent graph breaks during further computations.
batch["max_length_k"] = torch.tensor([flattened_position_ids.max().item() + 1])
batch["max_length_q"] = batch["max_length_k"]
return batch
class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant-length chunks of tokens from a stream of text files. The dataset also formats
the text before tokenization with a specific format that is provided by the user.
Args:
tokenizer (`transformers.PreTrainedTokenizer`):
The processor used for processing the data.
dataset (`dataset.Dataset`):
Dataset with text files.
dataset_text_field (`str` or `None`, *optional*, defaults to `None`):
Name of the field in the dataset that contains the text. Only one of `dataset_text_field` and
`formatting_func` should be provided.
formatting_func (`Callable`, *optional*):
Function that formats the text before tokenization. Usually it is recommended to follow a certain pattern
such as `"### Question: {question} ### Answer: {answer}"`. Only one of `dataset_text_field` and
`formatting_func` should be provided.
infinite (`bool`, *optional*, defaults to `False`):
If `True`, the iterator restarts from the beginning of the dataset once it is exhausted; otherwise iteration stops.
seq_length (`int`, *optional*, defaults to `1024`):
Length of token sequences to return.
num_of_sequences (`int`, *optional*, defaults to `1024`):
Number of token sequences to keep in buffer.
chars_per_token (`float`, *optional*, defaults to `3.6`):
Number of characters per token, used to estimate the number of tokens in the text buffer.
eos_token_id (`int`, *optional*, defaults to `0`):
Id of the end-of-sequence token, used if the passed tokenizer does not have an EOS token.
shuffle (`bool`, *optional*, defaults to `True`):
Whether to shuffle the examples before they are returned.
append_concat_token (`bool`, *optional*, defaults to `True`):
If `True`, appends `eos_token_id` at the end of each sample being packed.
add_special_tokens (`bool`, *optional*, defaults to `True`):
If `True`, the tokenizer adds special tokens to each sample being packed.
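Example:

A minimal sketch of packing a toy text dataset into fixed-length chunks. The checkpoint name and
the toy data are illustrative assumptions:

```python
>>> from datasets import Dataset
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> dataset = Dataset.from_dict({"text": ["hello world"] * 100})
>>> packed = ConstantLengthDataset(tokenizer, dataset, dataset_text_field="text", seq_length=64)
>>> example = next(iter(packed))
>>> example["input_ids"].shape  # every yielded chunk has exactly `seq_length` tokens
torch.Size([64])
```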
"""
def __init__(
self,
tokenizer,
dataset,
dataset_text_field=None,
formatting_func=None,
infinite=False,
seq_length=1024,
num_of_sequences=1024,
chars_per_token=3.6,
eos_token_id=0,
shuffle=True,
append_concat_token=True,
add_special_tokens=True,
):
warnings.warn(
"This class is deprecated and will be removed in version 0.20.0. To use packing, use the argument "
"`packing` of `SFTConfig` instead.",
DeprecationWarning,
)
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id
self.dataset = dataset
self.seq_length = seq_length
self.infinite = infinite
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
self.shuffle = shuffle
self.append_concat_token = append_concat_token
self.add_special_tokens = add_special_tokens
if dataset_text_field is not None and formatting_func is not None:
warnings.warn(
"Only one of `dataset_text_field` and `formatting_func` should be provided. "
"Ignoring `dataset_text_field` and using `formatting_func`.",
UserWarning,
)
if formatting_func is not None:
self.formatting_func = formatting_func
elif dataset_text_field is not None:
self.formatting_func = lambda x: x[dataset_text_field]
else: # neither is provided
raise ValueError("Either `dataset_text_field` or `formatting_func` should be provided.")
self.pretokenized = False
column_names = (
dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
)
if column_names is not None and "input_ids" in column_names:
self.pretokenized = True
# since the dataset is tokenized, the unit of buffer size should be tokens
self.max_buffer_size = seq_length * num_of_sequences
def __len__(self):
return len(self.dataset)
def __iter__(self):
iterator = iter(self.dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.max_buffer_size:
break
try:
buffer.append(self.formatting_func(next(iterator)))
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
iterator = iter(self.dataset)
else:
more_examples = False
break
if self.shuffle:
random.shuffle(buffer)
if self.pretokenized:
tokenized_inputs = buffer
else:
tokenized_inputs = self.tokenizer(
buffer, add_special_tokens=self.add_special_tokens, truncation=False
)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
if self.append_concat_token:
tokenized_input = tokenized_input + [self.concat_token_id]
all_token_ids.extend(tokenized_input)
examples = []
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) == self.seq_length:
examples.append(input_ids)
if self.shuffle:
# Shuffle again, otherwise split examples occur in consecutive tensors.
random.shuffle(examples)
for example in examples:
self.current_size += 1
yield {
"input_ids": torch.LongTensor(example),
"labels": torch.LongTensor(example),
}
# TODO: Add Abstract Base Class if more formats are added
@dataclass
class ChatMlSpecialTokens:
"""Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""
bos_token: str = "<|im_start|>"
eos_token: str = "<|im_end|>"
pad_token: str = "<|im_end|>"
@property
def system(self):
return f"{self.bos_token}system"
@property
def user(self):
return f"{self.bos_token}user"
@property
def assistant(self):
return f"{self.bos_token}assistant"
@property
def chat_template(self):
return (
"{% for message in messages %}"
f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
f"{{{{ '{self.assistant}\n' }}}}"
"{% endif %}"
)
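# Illustrative sketch of what the ChatML template above renders (the message text is a made-up
# example, not something defined in this module):
#
#   ChatMlSpecialTokens().user        -> "<|im_start|>user"
#   one rendered user turn            -> "<|im_start|>user\nHello!<|im_end|>\n"
#   with add_generation_prompt=True   -> the output ends with "<|im_start|>assistant\n"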
FORMAT_MAPPING_CHAT = {"chatml": ChatMlSpecialTokens}
def setup_chat_format(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
format: Literal["chatml"] | None = "chatml",
resize_to_multiple_of: int | None = None,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
# docstyle-ignore
"""
Set up the chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
embedding layer of the model based on the new special tokens.
> [!WARNING]
> This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead.
If the tokenizer already has a chat template, this will raise an error. If you want to overwrite it, please set
`tokenizer.chat_template` to `None`.
Args:
model ([`~transformers.PreTrainedModel`]): The model to be modified.
tokenizer ([`~transformers.PreTrainedTokenizer`]): The tokenizer to be modified.
format (`Literal["chatml"] | None`): The format to be set. Defaults to "chatml".
resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None.
Returns:
model ([`~transformers.PreTrainedModel`]):
The modified model.
tokenizer ([`~transformers.PreTrainedTokenizer`]):
The modified tokenizer.
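Example:

A minimal sketch. The checkpoint name is an illustrative assumption; any causal LM whose tokenizer
has no chat template behaves similarly:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model, tokenizer = setup_chat_format(model, tokenizer)
>>> messages = [{"role": "user", "content": "Hello!"}]
>>> tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
'<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n'
```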
"""
warnings.warn(
"The `setup_chat_format` function is deprecated and will be removed in version 0.26.0. Please use "
"`clone_chat_template` instead.",
FutureWarning,
)
# check if the tokenizer already has a chat template
if tokenizer.chat_template is not None:
raise ValueError(
"Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None"
)
# check if format available and retrieve
if format not in FORMAT_MAPPING_CHAT:
raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")
chat_format = FORMAT_MAPPING_CHAT[format]()
# set special tokens and add them to the tokenizer
tokenizer.eos_token = chat_format.eos_token
tokenizer.pad_token = chat_format.pad_token
tokenizer.bos_token = chat_format.bos_token
tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
# set chat format for tokenizer
tokenizer.chat_template = chat_format.chat_template
# resize the embedding layer, optionally padding it to a multiple of `resize_to_multiple_of` (see https://x.com/karpathy/status/1621578354024677377)
model.resize_token_embeddings(
# After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the vocab
# size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
# as handling of special and added tokens varies across tokenizers.
new_num_tokens=len(tokenizer.vocab),
pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
)
# Update the model config to use the new eos & bos tokens
if getattr(model, "config", None) is not None:
model.config.pad_token_id = tokenizer.pad_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# Update the generation config to use the new eos & bos token
if getattr(model, "generation_config", None) is not None:
model.generation_config.bos_token_id = tokenizer.bos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
def conversations_formatting_function(
tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: list | None = None
):
r"""
Return a callable that takes a "messages" example (or batch of examples) and formats it by applying the
tokenizer's chat template, optionally passing the schemas of the functions listed in `tools`.
<Deprecated version="0.24.0">
`conversations_formatting_function` is deprecated and will be removed in version 0.27. Please use
`tokenizer.apply_chat_template()` directly instead.
</Deprecated>
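Example:

A minimal sketch of the old usage next to its recommended replacement. The checkpoint name and the
toy conversation are illustrative assumptions:

```python
>>> from datasets import Dataset
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
>>> dataset = Dataset.from_dict(
...     {"messages": [[{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]]}
... )
>>> formatting_func = conversations_formatting_function(tokenizer, "messages")
>>> text = formatting_func(dataset[0])
>>> # Equivalent direct call, as recommended by the deprecation notice:
>>> text == tokenizer.apply_chat_template(dataset[0]["messages"], tokenize=False)
True
```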
"""
warnings.warn(
"`conversations_formatting_function` is deprecated and will be removed in TRL 0.27. "
"Please use `tokenizer.apply_chat_template()` directly instead.",
FutureWarning,
stacklevel=2,
)
def format_dataset(examples):
if isinstance(examples[messages_field][0], list):
output_texts = []
for i in range(len(examples[messages_field])):
output_texts.append(
tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False, tools=tools)
)
return output_texts
else:
return tokenizer.apply_chat_template(examples[messages_field], tokenize=False, tools=tools)
return format_dataset
def instructions_formatting_function(tokenizer: AutoTokenizer):
r"""
Return a callable that takes a prompt/completion example (or batch of examples) and formats it by applying the
tokenizer's chat template.
<Deprecated version="0.24.0">
`instructions_formatting_function` is deprecated and will be removed in version 0.27. Please use
`tokenizer.apply_chat_template()` directly instead.
</Deprecated>
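Example:

A minimal sketch of the old usage next to its recommended replacement. The checkpoint name and the
toy prompt/completion pair are illustrative assumptions:

```python
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
>>> formatting_func = instructions_formatting_function(tokenizer)
>>> text = formatting_func({"prompt": "What is 2+2?", "completion": "4"})
>>> # Equivalent direct call, as recommended by the deprecation notice:
>>> text == tokenizer.apply_chat_template(
...     [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}],
...     tokenize=False,
... )
True
```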
"""
warnings.warn(
"`instructions_formatting_function` is deprecated and will be removed in TRL 0.27. "
"Please use `tokenizer.apply_chat_template()` directly instead.",
FutureWarning,
stacklevel=2,
)
def format_dataset(examples):
if isinstance(examples["prompt"], list):
output_texts = []
for i in range(len(examples["prompt"])):
converted_sample = [
{"role": "user", "content": examples["prompt"][i]},
{"role": "assistant", "content": examples["completion"][i]},
]
output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
return output_texts
else:
converted_sample = [
{"role": "user", "content": examples["prompt"]},
{"role": "assistant", "content": examples["completion"]},
]
return tokenizer.apply_chat_template(converted_sample, tokenize=False)
return format_dataset