Skip to content

Commit c854702

Browse files
committed
upgrade support to trl 0.27+
Signed-off-by: Yash Mehan <yashmehan@gmail.com>
1 parent fd3b977 commit c854702

4 files changed

Lines changed: 248 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dependencies = [
3434
"sentencepiece>=0.1.99,<0.3",
3535
"tokenizers<=0.23.0",
3636
"tqdm>=4.66.2,<5.0",
37-
"trl>=0.19.1,<0.20.0",
37+
"trl>=0.27.0,<0.29.0",
3838
"peft>=0.18.1,<0.19.0",
3939
"datasets>=4.0.0,<5.0.0",
4040
"simpleeval>=0.9.13,<2.0",

tests/data/test_data_preprocessing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from datasets import Dataset, DatasetDict, IterableDataset
2323
from PIL import Image
2424
from transformers import AutoProcessor, AutoTokenizer, DataCollatorForSeq2Seq
25-
from trl import DataCollatorForCompletionOnlyLM
2625
import datasets
2726
import numpy as np
2827
import pyarrow
@@ -69,7 +68,7 @@
6968
# Local
7069
from tuning.config import configs
7170
from tuning.config.acceleration_configs import AttentionAndDistributedPackingConfig
72-
from tuning.data.collators import VisionDataCollator
71+
from tuning.data.collators import DataCollatorForCompletionOnlyLM, VisionDataCollator
7372
from tuning.data.data_config import (
7473
DataHandlerConfig,
7574
DataPreProcessorConfig,

tuning/data/collators.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# Standard
16+
from typing import Any, Optional, Union
17+
import logging
18+
19+
# Third Party
20+
from transformers import DataCollatorForLanguageModeling
21+
import numpy as np
22+
import torch
23+
1524
# Local
1625
from tuning.data.utils import try_convert_bytes_dict_to_pil
1726

@@ -91,3 +100,239 @@ def __call__(self, features):
91100
batch["labels"] = labels
92101

93102
return batch
103+
104+
105+
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    """
    Data collator used for completion tasks.

    It ensures that all the tokens of the labels are set to `ignore_index`
    when they do not come from the assistant. This ensures that the loss is
    only calculated on the completion made by the assistant.

    Args:
        response_template (`Union[str, list[int]]`):
            The template form that indicates the start of the response,
            typically something like '### Response:\\n'. It can also be
            passed as tokenized ids, which can be useful when using a
            tokenizer that encodes the response differently if it does not
            have proper context. Must not be empty.
        instruction_template (`Union[str, list[int]]`, *optional*):
            The template form that indicates the start of the human
            instruction, typically something like '### Human:\\n'. Useful
            for assistant-style conversation datasets. It can also be passed
            as tokenized ids. When given, it must not be empty.
        mlm (`bool`, *optional*, defaults to `False`):
            Whether to use masked language modeling in the underlying
            `DataCollatorForLanguageModeling` class. Note that this option
            currently has no effect but is present for flexibility and
            backwards-compatibility.
        ignore_index (`int`, *optional*, defaults to `-100`):
            The label value written over every token that must not
            contribute to the loss.
        padding_free (`bool`, *optional*, defaults to `False`):
            When `True`, the padded batch is flattened into a single row:
            `attention_mask` is removed and `position_ids`,
            `cu_seq_lens_q`/`cu_seq_lens_k` and
            `max_length_q`/`max_length_k` are added to the batch.

    Raises:
        ValueError: If the response template, or a provided instruction
            template, resolves to zero token ids.
    """

    def __init__(
        self,
        *args,
        response_template: Union[str, list[int]],
        instruction_template: Optional[Union[str, list[int]]] = None,
        mlm: bool = False,
        ignore_index: int = -100,
        padding_free: bool = False,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)

        self.instruction_template = instruction_template
        self.instruction_token_ids = self._to_token_ids(instruction_template)
        # An empty template cannot be located in a sequence; fail here with a
        # clear message rather than with an IndexError on the first batch.
        if instruction_template is not None and len(self.instruction_token_ids) == 0:
            raise ValueError(
                "instruction_template must contain at least one token when provided."
            )

        self.response_template = response_template
        self.response_token_ids = self._to_token_ids(response_template)
        if self.response_token_ids is None or len(self.response_token_ids) == 0:
            raise ValueError("response_template must contain at least one token.")

        if (
            not self.mlm
            and self.instruction_template
            and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
        ):
            logging.warning(
                "The pad_token_id and eos_token_id values "
                "of this tokenizer are identical. "
                "If you are planning for multi-turn training, "
                "it can result in the model continuously generating "
                "questions and answers without eos token. "
                "To avoid this, set the pad_token_id to a different value.",
            )

        self.ignore_index = ignore_index
        self.padding_free = padding_free

    def _to_token_ids(self, template):
        """Tokenize a string template; return any other value unchanged."""
        if isinstance(template, str):
            return self.tokenizer.encode(template, add_special_tokens=False)
        # The user already provides the token ids (or None).
        return template

    @staticmethod
    def _find_template_starts(labels_row, template_ids) -> list[int]:
        """Return every index in `labels_row` where `template_ids` begins."""
        width = len(template_ids)
        if width == 0:
            return []
        # Cheap pre-filter on the first token, then confirm the full match.
        candidates = np.where(labels_row == template_ids[0])[0]
        return [
            int(start)
            for start in candidates
            if template_ids == labels_row[start : start + width].tolist()
        ]

    def _warn_and_mask_row(self, batch, row: int, message: str, template) -> None:
        """Log that `template` was not found and exclude the row from the loss."""
        logging.warning(
            message,
            template,
            self.tokenizer.decode(batch["input_ids"][row]),
        )
        batch["labels"][row, :] = self.ignore_index

    def _mask_single_turn(self, batch, row: int) -> None:
        """Mask everything up to the end of the last response key in the row."""
        labels = batch["labels"]
        starts = self._find_template_starts(labels[row], self.response_token_ids)
        if not starts:
            self._warn_and_mask_row(
                batch,
                row,
                "Could not find response key %s in the following instance: "
                "%s. This instance will be ignored in loss "
                "calculation. Note, if this happens often, "
                "consider increasing the `max_length`.",
                self.response_template,
            )
            return

        # When the key occurs several times, the last occurrence wins.
        response_end = starts[-1] + len(self.response_token_ids)
        # Make pytorch loss function ignore all
        # tokens up through the end of the response key
        labels[row, :response_end] = self.ignore_index

    def _mask_multi_turn(self, batch, row: int) -> None:
        """Mask every non-response span of a multi-turn conversation row."""
        labels = batch["labels"]
        response_width = len(self.response_token_ids)

        # Locate both templates before any masking so that a missing response
        # key does not hide instruction keys that are actually present.
        response_ends = [
            start + response_width
            for start in self._find_template_starts(
                labels[row], self.response_token_ids
            )
        ]
        human_starts = self._find_template_starts(
            labels[row], self.instruction_token_ids
        )

        if not response_ends:
            self._warn_and_mask_row(
                batch,
                row,
                "Could not find response key %s in the following instance: "
                "%s. This instance will be ignored in loss "
                "calculation. Note, if this happens often, "
                "consider increasing the `max_length`.",
                self.response_template,
            )

        if not human_starts:
            self._warn_and_mask_row(
                batch,
                row,
                "Could not find instruction key `%s` in the following instance: "
                "%s. This instance will be ignored in loss "
                "calculation. Note, if this happens often, "
                "consider increasing the `max_length`.",
                self.instruction_template,
            )

        # The row opens with a response that has no preceding instruction
        # key: pair that first response with position 0.
        if human_starts and response_ends and human_starts[0] > response_ends[0]:
            human_starts = [0] + human_starts

        for turn, (start, end) in enumerate(zip(human_starts, response_ends)):
            # Make pytorch loss function ignore all non response tokens
            if turn != 0:
                labels[row, start:end] = self.ignore_index
            else:
                labels[row, :end] = self.ignore_index

        # A trailing instruction without a response is masked to the end.
        if len(response_ends) < len(human_starts):
            labels[row, human_starts[-1] :] = self.ignore_index

    def _flatten_padding_free(self, batch) -> None:
        """Drop padding in place and add position ids and sequence-length metadata."""
        # remove padding, `attention_mask` and add `position_ids`
        attn_mask = batch.pop("attention_mask")
        keep = attn_mask.bool()
        batch["input_ids"] = batch["input_ids"][keep].unsqueeze(0)
        batch["position_ids"] = attn_mask.cumsum(1)[keep].unsqueeze(0) - 1
        batch["labels"] = batch["labels"][keep].unsqueeze(0)
        # The first token of each original sequence never contributes to loss.
        batch["labels"][batch["position_ids"] == 0] = self.ignore_index

        # Calculate cumulative sequence lengths for queries and
        # keys to prevent graph breaks during further computations.
        flattened_position_ids = batch["position_ids"].flatten()
        indices_q = torch.arange(
            flattened_position_ids.size(0),
            device=flattened_position_ids.device,
            dtype=torch.int32,
        )
        batch["cu_seq_lens_q"] = torch.cat(
            (
                indices_q[flattened_position_ids == 0],
                torch.tensor(
                    flattened_position_ids.size(),
                    device=flattened_position_ids.device,
                    dtype=torch.int32,
                ),
            )
        ).unsqueeze(0)
        batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]

        # Determine maximum sequence lengths to
        # prevent graph breaks during further computations.
        batch["max_length_k"] = torch.tensor(
            [flattened_position_ids.max().item() + 1]
        )
        batch["max_length_q"] = batch["max_length_k"]

    def torch_call(
        self, examples: list[Union[list[int], Any, dict[str, Any]]]
    ) -> dict[str, Any]:
        """
        Collate `examples` and mask every label outside the completions.

        Args:
            examples: The features handed to the parent collator's
                `torch_call`.

        Returns:
            The parent collator's batch with `labels` masked by
            `ignore_index`; when `padding_free` is set the batch is
            additionally flattened and carries `position_ids`,
            `cu_seq_lens_q`/`cu_seq_lens_k` and
            `max_length_q`/`max_length_k` instead of `attention_mask`.
        """
        batch = super().torch_call(examples)

        # Without an instruction template each row is one prompt/response
        # pair; with one, rows are treated as multi-turn conversations.
        if self.instruction_template is None:
            mask_row = self._mask_single_turn
        else:
            mask_row = self._mask_multi_turn

        for row in range(len(examples)):
            mask_row(batch, row)

        if self.padding_free:
            self._flatten_padding_free(batch)

        return batch

tuning/data/data_preprocessing_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,10 @@
2222
DataCollatorForSeq2Seq,
2323
LlavaProcessor,
2424
)
25-
from trl import DataCollatorForCompletionOnlyLM
2625

2726
# Local
2827
from tuning.config import configs
29-
from tuning.data.collators import VisionDataCollator
28+
from tuning.data.collators import DataCollatorForCompletionOnlyLM, VisionDataCollator
3029

3130
logger = logging.getLogger(__name__)
3231

0 commit comments

Comments
 (0)