Skip to content

Commit ca6f273

Browse files
committed
merge
2 parents 8dee82c + 5043fb6 commit ca6f273

3 files changed

Lines changed: 13 additions & 17 deletions

File tree

textattack/datasets/helpers/ted_multi.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ class TedMultiTranslationDataset(HuggingFaceDataset):
2020
dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/
2121
"""
2222

23-
def __init__(self, source_lang="en", target_lang="de", split="test"):
23+
def __init__(self, source_lang="en", target_lang="de", split="test", shuffle=False):
2424
self._dataset = datasets.load_dataset("ted_multi")[split]
2525
self.examples = self._dataset["translations"]
2626
language_options = set(self.examples[0]["language"])
@@ -34,6 +34,9 @@ def __init__(self, source_lang="en", target_lang="de", split="test"):
3434
)
3535
self.source_lang = source_lang
3636
self.target_lang = target_lang
37+
self.shuffled = shuffle
38+
if shuffle:
39+
self._dataset.shuffle()
3740

3841
def _format_raw_example(self, raw_example):
3942
translations = np.array(raw_example["translation"])

textattack/shared/attacked_text.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@ def __eq__(self, other: AttackedText) -> bool:
8282
and it's actually much faster (cache-wise) to just compare
8383
by the text, and this works for lots of use cases.
8484
"""
85-
return self.text == other.text
85+
if not (self.text == other.text):
86+
return False
87+
if len(self.attack_attrs) != len(other.attack_attrs):
88+
return False
89+
return True
8690

8791
def __hash__(self) -> int:
8892
return hash(self.text)
@@ -466,9 +470,6 @@ def generate_new_attacked_text(self, new_words: Iterable[str]) -> AttackedText:
466470
perturbed_text += adv_word_seq
467471
perturbed_text += original_text # Add all of the ending punctuation.
468472

469-
# Add pointer to self so chain of replacements can be reconstructed.
470-
new_attack_attrs["prev_attacked_text"] = self
471-
472473
# Reform perturbed_text into an OrderedDict.
473474
perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
474475
perturbed_input = OrderedDict(
@@ -570,7 +571,10 @@ def num_words(self) -> int:
570571

571572
@property
572573
def newly_swapped_words(self) -> List[str]:
573-
return [self.words[i] for i in self.attack_attrs["newly_modified_indices"]]
574+
return [
575+
self.attack_attrs["prev_attacked_text"].words[i]
576+
for i in self.attack_attrs["newly_modified_indices"]
577+
]
574578

575579
def printable_text(self, key_color="bold", key_color_method=None) -> str:
576580
"""Represents full text input. Adds field descriptions.

textattack/shared/utils/strings.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import string
33

44
import flair
5-
import jieba
6-
import pycld2 as cld2
75

86
from .importing import LazyLoader
97

@@ -32,15 +30,6 @@ def add_indent(s_, numSpaces):
3230
def words_from_text(s, words_to_ignore=[]):
3331
"""Lowercases a string, removes all non-alphanumeric characters, and splits
3432
into words."""
35-
# try:
36-
# isReliable, textBytesFound, details = cld2.detect(s)
37-
# if details[0][0] == "Chinese" or details[0][0] == "ChineseT":
38-
# seg_list = jieba.cut(s, cut_all=False)
39-
# s = " ".join(seg_list)
40-
# else:
41-
# s = " ".join(s.split())
42-
# except Exception:
43-
# s = " ".join(s.split())
4433
s = " ".join(s.split())
4534

4635
homos = """˗৭Ȣ𝟕бƼᏎƷᒿlO`ɑЬϲԁе𝚏ɡհіϳ𝒌ⅼmոорԛⲅѕ𝚝սѵԝ×уᴢ"""

0 commit comments

Comments
 (0)