Skip to content

Commit 3681f9d

Browse files
authored
Merge pull request #700 from QData/oct-bug-fixes
fix bugs in AT, strings (Chinese), TedMulti dataset
2 parents fea5cb2 + 4179531 commit 3681f9d

3 files changed

Lines changed: 9 additions & 28 deletions

File tree

textattack/datasets/helpers/ted_multi.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class TedMultiTranslationDataset(HuggingFaceDataset):
2020
dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/
2121
"""
2222

23-
def __init__(self, source_lang="en", target_lang="de", split="test"):
23+
def __init__(self, source_lang="en", target_lang="de", split="test", shuffle=False):
2424
self._dataset = datasets.load_dataset("ted_multi")[split]
2525
self.examples = self._dataset["translations"]
2626
language_options = set(self.examples[0]["language"])
@@ -34,6 +34,9 @@ def __init__(self, source_lang="en", target_lang="de", split="test"):
3434
)
3535
self.source_lang = source_lang
3636
self.target_lang = target_lang
37+
self.shuffled = shuffle
38+
if shuffle:
39+
self._dataset.shuffle()
3740

3841
def _format_raw_example(self, raw_example):
3942
translations = np.array(raw_example["translation"])

textattack/shared/attacked_text.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,21 +82,6 @@ def __eq__(self, other):
8282
return False
8383
if len(self.attack_attrs) != len(other.attack_attrs):
8484
return False
85-
for key in self.attack_attrs:
86-
if key not in other.attack_attrs:
87-
return False
88-
elif isinstance(self.attack_attrs[key], np.ndarray):
89-
if not (self.attack_attrs[key].shape == other.attack_attrs[key].shape):
90-
return False
91-
elif not (self.attack_attrs[key] == other.attack_attrs[key]).all():
92-
return False
93-
else:
94-
if isinstance(self.attack_attrs[key], AttackedText):
95-
if (
96-
not self.attack_attrs[key]._text_input
97-
== other.attack_attrs[key]._text_input
98-
):
99-
return False
10085
return True
10186

10287
def __hash__(self):
@@ -576,7 +561,10 @@ def num_words(self):
576561

577562
@property
578563
def newly_swapped_words(self):
579-
return [self.words[i] for i in self.attack_attrs["newly_modified_indices"]]
564+
return [
565+
self.attack_attrs["prev_attacked_text"].words[i]
566+
for i in self.attack_attrs["newly_modified_indices"]
567+
]
580568

581569
def printable_text(self, key_color="bold", key_color_method=None):
582570
"""Represents full text input. Adds field descriptions.

textattack/shared/utils/strings.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import string
33

44
import flair
5-
import jieba
6-
import pycld2 as cld2
75

86
from .importing import LazyLoader
97

@@ -32,15 +30,7 @@ def add_indent(s_, numSpaces):
3230
def words_from_text(s, words_to_ignore=[]):
3331
"""Lowercases a string, removes all non-alphanumeric characters, and splits
3432
into words."""
35-
try:
36-
isReliable, textBytesFound, details = cld2.detect(s)
37-
if details[0][0] == "Chinese" or details[0][0] == "ChineseT":
38-
seg_list = jieba.cut(s, cut_all=False)
39-
s = " ".join(seg_list)
40-
else:
41-
s = " ".join(s.split())
42-
except Exception:
43-
s = " ".join(s.split())
33+
s = " ".join(s.split())
4434

4535
homos = """˗৭Ȣ𝟕бƼᏎƷᒿlO`ɑЬϲԁе𝚏ɡհіϳ𝒌ⅼmոорԛⲅѕ𝚝սѵԝ×уᴢ"""
4636
exceptions = """'-_*@"""

0 commit comments

Comments
 (0)