Skip to content

Commit 8dee82c

Browse files
committed
add typing for AttackedText
1 parent 899ea46 commit 8dee82c

2 files changed

Lines changed: 72 additions & 73 deletions

File tree

textattack/shared/attacked_text.py

Lines changed: 62 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
A helper class that represents a string that can be attacked.
77
"""
88

9+
from __future__ import annotations
10+
911
from collections import OrderedDict
1012
import math
13+
from typing import Iterable, List, Optional, Set, Tuple
1114

1215
import flair
1316
from flair.data import Sentence
@@ -71,31 +74,17 @@ def __init__(self, text_input, attack_attrs=None):
7174
# A list of all indices in *this* text that have been modified.
7275
self.attack_attrs.setdefault("modified_indices", set())
7376

74-
def __eq__(self, other):
75-
"""Compares two text instances to make sure they have the same attack
76-
attributes.
77+
def __eq__(self, other: AttackedText) -> bool:
78+
"""Compares two AttackedText instances.
7779
78-
Since some elements stored in ``self.attack_attrs`` may be numpy
79-
arrays, we have to take special care when comparing them.
80+
Note: Does not compute true equality across attack attributes.
81+
We found this caused large performance issues with caching,
82+
and it's actually much faster (cache-wise) to just compare
83+
by the text, and this works for lots of use cases.
8084
"""
81-
if not (self.text == other.text):
82-
return False
83-
if len(self.attack_attrs) != len(other.attack_attrs):
84-
return False
85-
for key in self.attack_attrs:
86-
if key not in other.attack_attrs:
87-
return False
88-
elif isinstance(self.attack_attrs[key], np.ndarray):
89-
if not (self.attack_attrs[key].shape == other.attack_attrs[key].shape):
90-
return False
91-
elif not (self.attack_attrs[key] == other.attack_attrs[key]).all():
92-
return False
93-
else:
94-
if not self.attack_attrs[key] == other.attack_attrs[key]:
95-
return False
96-
return True
85+
return self.text == other.text
9786

98-
def __hash__(self):
87+
def __hash__(self) -> int:
9988
return hash(self.text)
10089

10190
def free_memory(self):
@@ -113,7 +102,7 @@ def free_memory(self):
113102
if isinstance(self.attack_attrs[key], torch.Tensor):
114103
self.attack_attrs.pop(key, None)
115104

116-
def text_window_around_index(self, index, window_size):
105+
def text_window_around_index(self, index: int, window_size: int) -> str:
117106
"""The text window of ``window_size`` words centered around
118107
``index``."""
119108
length = self.num_words
@@ -131,10 +120,12 @@ def text_window_around_index(self, index, window_size):
131120
text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
132121
return self.text[text_idx_start:text_idx_end]
133122

134-
def pos_of_word_index(self, desired_word_idx):
123+
def pos_of_word_index(self, desired_word_idx: int) -> str:
135124
"""Returns the part-of-speech of the word at index `word_idx`.
136125
137126
Uses FLAIR part-of-speech tagger.
127+
128+
Throws: ValueError, if no POS tag found for index.
138129
"""
139130
if not self._pos_tags:
140131
sentence = Sentence(
@@ -162,10 +153,12 @@ def pos_of_word_index(self, desired_word_idx):
162153
f"Did not find word from index {desired_word_idx} in flair POS tag"
163154
)
164155

165-
def ner_of_word_index(self, desired_word_idx, model_name="ner"):
156+
def ner_of_word_index(self, desired_word_idx: int, model_name="ner") -> str:
166157
"""Returns the ner tag of the word at index `word_idx`.
167158
168159
Uses FLAIR ner tagger.
160+
161+
Throws: ValueError, if not NER tag found for index.
169162
"""
170163
if not self._ner_tags:
171164
sentence = Sentence(
@@ -190,7 +183,7 @@ def ner_of_word_index(self, desired_word_idx, model_name="ner"):
190183
f"Did not find word from index {desired_word_idx} in flair POS tag"
191184
)
192185

193-
def _text_index_of_word_index(self, i):
186+
def _text_index_of_word_index(self, i: int) -> int:
194187
"""Returns the index of word ``i`` in self.text."""
195188
pre_words = self.words[: i + 1]
196189
lower_text = self.text.lower()
@@ -203,20 +196,20 @@ def _text_index_of_word_index(self, i):
203196
look_after_index -= len(self.words[i])
204197
return look_after_index
205198

206-
def text_until_word_index(self, i):
199+
def text_until_word_index(self, i: int) -> str:
207200
"""Returns the text before the beginning of word at index ``i``."""
208201
look_after_index = self._text_index_of_word_index(i)
209202
return self.text[:look_after_index]
210203

211-
def text_after_word_index(self, i):
204+
def text_after_word_index(self, i: int) -> str:
212205
"""Returns the text after the end of word at index ``i``."""
213206
# Get index of beginning of word then jump to end of word.
214207
look_after_index = self._text_index_of_word_index(i) + len(self.words[i])
215208
return self.text[look_after_index:]
216209

217-
def first_word_diff(self, other_attacked_text):
210+
def first_word_diff(self, other_attacked_text: AttackedText) -> Optional[str]:
218211
"""Returns the first word in self.words that differs from
219-
other_attacked_text.
212+
other_attacked_text, or None if all words are the same.
220213
221214
Useful for word swap strategies.
222215
"""
@@ -227,7 +220,7 @@ def first_word_diff(self, other_attacked_text):
227220
return w1[i]
228221
return None
229222

230-
def first_word_diff_index(self, other_attacked_text):
223+
def first_word_diff_index(self, other_attacked_text: AttackedText) -> Optional[int]:
231224
"""Returns the index of the first word in self.words that differs from
232225
other_attacked_text.
233226
@@ -240,7 +233,7 @@ def first_word_diff_index(self, other_attacked_text):
240233
return i
241234
return None
242235

243-
def all_words_diff(self, other_attacked_text):
236+
def all_words_diff(self, other_attacked_text: AttackedText) -> Set[int]:
244237
"""Returns the set of indices for which this and other_attacked_text
245238
have different words."""
246239
indices = set()
@@ -251,16 +244,17 @@ def all_words_diff(self, other_attacked_text):
251244
indices.add(i)
252245
return indices
253246

254-
def ith_word_diff(self, other_attacked_text, i):
255-
"""Returns whether the word at index i differs from
247+
def ith_word_diff(self, other_attacked_text: AttackedText, i: int) -> bool:
248+
"""Returns bool representing whether the word at index i differs from
256249
other_attacked_text."""
257250
w1 = self.words
258251
w2 = other_attacked_text.words
259252
if len(w1) - 1 < i or len(w2) - 1 < i:
260253
return True
261254
return w1[i] != w2[i]
262255

263-
def words_diff_num(self, other_attacked_text):
256+
def words_diff_num(self, other_attacked_text: AttackedText) -> int:
257+
"""The number of words different between two AttackedText objects."""
264258
# using edit distance to calculate words diff num
265259
def generate_tokens(words):
266260
result = {}
@@ -306,7 +300,7 @@ def cal_dif(w1, w2):
306300
w2 = other_attacked_text.words
307301
return cal_dif(w1, w2)
308302

309-
def convert_from_original_idxs(self, idxs):
303+
def convert_from_original_idxs(self, idxs: Iterable[int]) -> List[int]:
310304
"""Takes indices of words from original string and converts them to
311305
indices of the same words in the current string.
312306
@@ -326,9 +320,16 @@ def convert_from_original_idxs(self, idxs):
326320

327321
return [self.attack_attrs["original_index_map"][i] for i in idxs]
328322

329-
def replace_words_at_indices(self, indices, new_words):
330-
"""This code returns a new AttackedText object where the word at
331-
``index`` is replaced with a new word."""
323+
def get_deletion_indices(self) -> Iterable[int]:
324+
return self.attack_attrs["original_index_map"][
325+
self.attack_attrs["original_index_map"] == -1
326+
]
327+
328+
def replace_words_at_indices(
329+
self, indices: Iterable[int], new_words: Iterable[str]
330+
) -> AttackedText:
331+
"""Returns a new AttackedText object where the word at ``index`` is
332+
replaced with a new word."""
332333
if len(indices) != len(new_words):
333334
raise ValueError(
334335
f"Cannot replace {len(new_words)} words at {len(indices)} indices."
@@ -344,21 +345,21 @@ def replace_words_at_indices(self, indices, new_words):
344345
words[i] = new_word
345346
return self.generate_new_attacked_text(words)
346347

347-
def replace_word_at_index(self, index, new_word):
348-
"""This code returns a new AttackedText object where the word at
349-
``index`` is replaced with a new word."""
348+
def replace_word_at_index(self, index: int, new_word: str) -> AttackedText:
349+
"""Returns a new AttackedText object where the word at ``index`` is
350+
replaced with a new word."""
350351
if not isinstance(new_word, str):
351352
raise TypeError(
352353
f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
353354
)
354355
return self.replace_words_at_indices([index], [new_word])
355356

356-
def delete_word_at_index(self, index):
357-
"""This code returns a new AttackedText object where the word at
358-
``index`` is removed."""
357+
def delete_word_at_index(self, index: int) -> AttackedText:
358+
"""Returns a new AttackedText object where the word at ``index`` is
359+
removed."""
359360
return self.replace_word_at_index(index, "")
360361

361-
def insert_text_after_word_index(self, index, text):
362+
def insert_text_after_word_index(self, index: int, text: str) -> AttackedText:
362363
"""Inserts a string before word at index ``index`` and attempts to add
363364
appropriate spacing."""
364365
if not isinstance(text, str):
@@ -367,7 +368,7 @@ def insert_text_after_word_index(self, index, text):
367368
new_text = " ".join((word_at_index, text))
368369
return self.replace_word_at_index(index, new_text)
369370

370-
def insert_text_before_word_index(self, index, text):
371+
def insert_text_before_word_index(self, index: int, text: str) -> AttackedText:
371372
"""Inserts a string before word at index ``index`` and attempts to add
372373
appropriate spacing."""
373374
if not isinstance(text, str):
@@ -378,12 +379,7 @@ def insert_text_before_word_index(self, index, text):
378379
new_text = " ".join((text, word_at_index))
379380
return self.replace_word_at_index(index, new_text)
380381

381-
def get_deletion_indices(self):
382-
return self.attack_attrs["original_index_map"][
383-
self.attack_attrs["original_index_map"] == -1
384-
]
385-
386-
def generate_new_attacked_text(self, new_words):
382+
def generate_new_attacked_text(self, new_words: Iterable[str]) -> AttackedText:
387383
"""Returns a new AttackedText object and replaces old list of words
388384
with a new list of words, but preserves the punctuation and spacing of
389385
the original message.
@@ -480,15 +476,17 @@ def generate_new_attacked_text(self, new_words):
480476
)
481477
return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)
482478

483-
def words_diff_ratio(self, x):
479+
def words_diff_ratio(self, x: AttackedText) -> float:
484480
"""Get the ratio of words difference between current text and `x`.
485481
486482
Note that current text and `x` must have same number of words.
487483
"""
488484
assert self.num_words == x.num_words
489485
return float(np.sum(self.words != x.words)) / self.num_words
490486

491-
def align_with_model_tokens(self, model_wrapper):
487+
def align_with_model_tokens(
488+
self, model_wrapper: textattack.models.wrappers.ModelWrapper
489+
) -> Dict[int, Iterable[int]]:
492490
"""Align AttackedText's `words` with target model's tokenization scheme
493491
(e.g. word, character, subword). Specifically, we map each word to list
494492
of indices of tokens that compose the word (e.g. embedding --> ["em",
@@ -525,7 +523,7 @@ def align_with_model_tokens(self, model_wrapper):
525523
return word2token_mapping
526524

527525
@property
528-
def tokenizer_input(self):
526+
def tokenizer_input(self) -> Tuple[str]:
529527
"""The tuple of inputs to be passed to the tokenizer."""
530528
input_tuple = tuple(self._text_input.values())
531529
# Prefer to return a string instead of a tuple with a single value.
@@ -535,15 +533,15 @@ def tokenizer_input(self):
535533
return input_tuple
536534

537535
@property
538-
def column_labels(self):
536+
def column_labels(self) -> List[str]:
539537
"""Returns the labels for this text's columns.
540538
541539
For single-sequence inputs, this simply returns ['text'].
542540
"""
543541
return list(self._text_input.keys())
544542

545543
@property
546-
def words_per_input(self):
544+
def words_per_input(self) -> List[List[str]]:
547545
"""Returns a list of lists of words corresponding to each input."""
548546
if not self._words_per_input:
549547
self._words_per_input = [
@@ -552,29 +550,29 @@ def words_per_input(self):
552550
return self._words_per_input
553551

554552
@property
555-
def words(self):
553+
def words(self) -> List[str]:
556554
if not self._words:
557555
self._words = words_from_text(self.text)
558556
return self._words
559557

560558
@property
561-
def text(self):
559+
def text(self) -> str:
562560
"""Represents full text input.
563561
564562
Multiply inputs are joined with a line break.
565563
"""
566564
return "\n".join(self._text_input.values())
567565

568566
@property
569-
def num_words(self):
567+
def num_words(self) -> int:
570568
"""Returns the number of words in the sequence."""
571569
return len(self.words)
572570

573571
@property
574-
def newly_swapped_words(self):
572+
def newly_swapped_words(self) -> List[str]:
575573
return [self.words[i] for i in self.attack_attrs["newly_modified_indices"]]
576574

577-
def printable_text(self, key_color="bold", key_color_method=None):
575+
def printable_text(self, key_color="bold", key_color_method=None) -> str:
578576
"""Represents full text input. Adds field descriptions.
579577
580578
For example, entailment inputs look like:
@@ -606,5 +604,5 @@ def ck(k):
606604
for key, value in self._text_input.items()
607605
)
608606

609-
def __repr__(self):
607+
def __repr__(self) -> str:
610608
return f'<AttackedText "{self.text}">'

textattack/shared/utils/strings.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,16 @@ def add_indent(s_, numSpaces):
3232
def words_from_text(s, words_to_ignore=[]):
3333
"""Lowercases a string, removes all non-alphanumeric characters, and splits
3434
into words."""
35-
try:
36-
isReliable, textBytesFound, details = cld2.detect(s)
37-
if details[0][0] == "Chinese" or details[0][0] == "ChineseT":
38-
seg_list = jieba.cut(s, cut_all=False)
39-
s = " ".join(seg_list)
40-
else:
41-
s = " ".join(s.split())
42-
except Exception:
43-
s = " ".join(s.split())
35+
# try:
36+
# isReliable, textBytesFound, details = cld2.detect(s)
37+
# if details[0][0] == "Chinese" or details[0][0] == "ChineseT":
38+
# seg_list = jieba.cut(s, cut_all=False)
39+
# s = " ".join(seg_list)
40+
# else:
41+
# s = " ".join(s.split())
42+
# except Exception:
43+
# s = " ".join(s.split())
44+
s = " ".join(s.split())
4445

4546
homos = """˗৭Ȣ𝟕бƼᏎƷᒿlO`ɑЬϲԁе𝚏ɡհіϳ𝒌ⅼmոорԛⲅѕ𝚝սѵԝ×уᴢ"""
4647
exceptions = """'-_*@"""

0 commit comments

Comments
 (0)