Skip to content

Commit bb56f61

Browse files
authored
Merge pull request #701 from QData/oct-bug-fixes
Fix bugs with t5
2 parents 5ac125a + 6e60ae6 commit bb56f61

3 files changed

Lines changed: 17 additions & 6 deletions

File tree

textattack/datasets/helpers/ted_multi.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,19 @@ def __init__(self, source_lang="en", target_lang="de", split="test", shuffle=Fal
3535
self.source_lang = source_lang
3636
self.target_lang = target_lang
3737
self.shuffled = shuffle
38+
self.label_map = None
39+
self.output_scale_factor = None
40+
self.label_names = None
41+
# self.input_columns = ("Source",)
42+
# self.output_column = "Translation"
43+
3844
if shuffle:
3945
self._dataset.shuffle()
4046

41-
def _format_raw_example(self, raw_example):
42-
translations = np.array(raw_example["translation"])
43-
languages = np.array(raw_example["language"])
47+
def _format_as_dict(self, raw_example):
48+
example = raw_example["translations"]
49+
translations = np.array(example["translation"])
50+
languages = np.array(example["language"])
4451
source = translations[languages == self.source_lang][0]
4552
target = translations[languages == self.target_lang][0]
4653
source_dict = collections.OrderedDict([("Source", source)])

textattack/goal_functions/text/text_to_text_goal_function.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
-------------------------------------------------------
55
"""
66

7+
import numpy as np
78

89
from textattack.goal_function_results import TextToTextGoalFunctionResult
910
from textattack.goal_functions import GoalFunction
@@ -22,7 +23,10 @@ def _goal_function_result_type(self):
2223

2324
def _process_model_outputs(self, _, outputs):
2425
"""Processes and validates a list of model outputs."""
25-
return outputs.flatten()
26+
if isinstance(outputs, np.ndarray):
27+
return outputs.flatten()
28+
else:
29+
return outputs
2630

2731
def _get_displayed_output(self, raw_output):
2832
return raw_output

textattack/models/tokenizers/t5_tokenizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, mode="english_to_german", max_length=64):
3838
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
3939
"t5-base", use_fast=True
4040
)
41-
self.max_length = max_length
41+
self.model_max_length = max_length
4242

4343
def __call__(self, text, *args, **kwargs):
4444
"""
@@ -55,7 +55,7 @@ def __call__(self, text, *args, **kwargs):
5555
else:
5656
for i in range(len(text)):
5757
text[i] = self.tokenization_prefix + text[i]
58-
return self.tokenizer(text, *args, max_length=self.max_length, **kwargs)
58+
return self.tokenizer(text, *args, **kwargs)
5959

6060
def decode(self, ids):
6161
"""Converts IDs (typically generated by the model) back to a string."""

0 commit comments

Comments
 (0)