Skip to content

Commit 5ac125a

Browse files
authored
Merge pull request #689 from QData/at-typing-refactor
Add typing to AttackedText class
2 parents 5043fb6 + 86e51cf commit 5ac125a

1 file changed

Lines changed: 61 additions & 48 deletions

File tree

textattack/shared/attacked_text.py

Lines changed: 61 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
A helper class that represents a string that can be attacked.
77
"""
88

9+
from __future__ import annotations
10+
911
from collections import OrderedDict
1012
import math
13+
from typing import Dict, Iterable, List, Optional, Set, Tuple
1114

1215
import flair
1316
from flair.data import Sentence
@@ -71,20 +74,21 @@ def __init__(self, text_input, attack_attrs=None):
7174
# A list of all indices in *this* text that have been modified.
7275
self.attack_attrs.setdefault("modified_indices", set())
7376

74-
def __eq__(self, other):
75-
"""Compares two text instances to make sure they have the same attack
76-
attributes.
77+
def __eq__(self, other: AttackedText) -> bool:
78+
"""Compares two AttackedText instances.
7779
78-
Since some elements stored in ``self.attack_attrs`` may be numpy
79-
arrays, we have to take special care when comparing them.
80+
Note: Does not compute true equality across attack attributes.
81+
We found this caused large performance issues with caching,
82+
and it's actually much faster (cache-wise) to just compare
83+
by the text, and this works for lots of use cases.
8084
"""
8185
if not (self.text == other.text):
8286
return False
8387
if len(self.attack_attrs) != len(other.attack_attrs):
8488
return False
8589
return True
8690

87-
def __hash__(self):
91+
def __hash__(self) -> int:
8892
return hash(self.text)
8993

9094
def free_memory(self):
@@ -102,7 +106,7 @@ def free_memory(self):
102106
if isinstance(self.attack_attrs[key], torch.Tensor):
103107
self.attack_attrs.pop(key, None)
104108

105-
def text_window_around_index(self, index, window_size):
109+
def text_window_around_index(self, index: int, window_size: int) -> str:
106110
"""The text window of ``window_size`` words centered around
107111
``index``."""
108112
length = self.num_words
@@ -120,10 +124,12 @@ def text_window_around_index(self, index, window_size):
120124
text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
121125
return self.text[text_idx_start:text_idx_end]
122126

123-
def pos_of_word_index(self, desired_word_idx):
127+
def pos_of_word_index(self, desired_word_idx: int) -> str:
124128
"""Returns the part-of-speech of the word at index `word_idx`.
125129
126130
Uses FLAIR part-of-speech tagger.
131+
132+
Throws: ValueError, if no POS tag found for index.
127133
"""
128134
if not self._pos_tags:
129135
sentence = Sentence(
@@ -151,10 +157,12 @@ def pos_of_word_index(self, desired_word_idx):
151157
f"Did not find word from index {desired_word_idx} in flair POS tag"
152158
)
153159

154-
def ner_of_word_index(self, desired_word_idx, model_name="ner"):
160+
def ner_of_word_index(self, desired_word_idx: int, model_name="ner") -> str:
155161
"""Returns the ner tag of the word at index `word_idx`.
156162
157163
Uses FLAIR ner tagger.
164+
165+
Throws: ValueError, if not NER tag found for index.
158166
"""
159167
if not self._ner_tags:
160168
sentence = Sentence(
@@ -179,7 +187,7 @@ def ner_of_word_index(self, desired_word_idx, model_name="ner"):
179187
f"Did not find word from index {desired_word_idx} in flair POS tag"
180188
)
181189

182-
def _text_index_of_word_index(self, i):
190+
def _text_index_of_word_index(self, i: int) -> int:
183191
"""Returns the index of word ``i`` in self.text."""
184192
pre_words = self.words[: i + 1]
185193
lower_text = self.text.lower()
@@ -192,20 +200,20 @@ def _text_index_of_word_index(self, i):
192200
look_after_index -= len(self.words[i])
193201
return look_after_index
194202

195-
def text_until_word_index(self, i):
203+
def text_until_word_index(self, i: int) -> str:
196204
"""Returns the text before the beginning of word at index ``i``."""
197205
look_after_index = self._text_index_of_word_index(i)
198206
return self.text[:look_after_index]
199207

200-
def text_after_word_index(self, i):
208+
def text_after_word_index(self, i: int) -> str:
201209
"""Returns the text after the end of word at index ``i``."""
202210
# Get index of beginning of word then jump to end of word.
203211
look_after_index = self._text_index_of_word_index(i) + len(self.words[i])
204212
return self.text[look_after_index:]
205213

206-
def first_word_diff(self, other_attacked_text):
214+
def first_word_diff(self, other_attacked_text: AttackedText) -> Optional[str]:
207215
"""Returns the first word in self.words that differs from
208-
other_attacked_text.
216+
other_attacked_text, or None if all words are the same.
209217
210218
Useful for word swap strategies.
211219
"""
@@ -216,7 +224,7 @@ def first_word_diff(self, other_attacked_text):
216224
return w1[i]
217225
return None
218226

219-
def first_word_diff_index(self, other_attacked_text):
227+
def first_word_diff_index(self, other_attacked_text: AttackedText) -> Optional[int]:
220228
"""Returns the index of the first word in self.words that differs from
221229
other_attacked_text.
222230
@@ -229,7 +237,7 @@ def first_word_diff_index(self, other_attacked_text):
229237
return i
230238
return None
231239

232-
def all_words_diff(self, other_attacked_text):
240+
def all_words_diff(self, other_attacked_text: AttackedText) -> Set[int]:
233241
"""Returns the set of indices for which this and other_attacked_text
234242
have different words."""
235243
indices = set()
@@ -240,16 +248,17 @@ def all_words_diff(self, other_attacked_text):
240248
indices.add(i)
241249
return indices
242250

243-
def ith_word_diff(self, other_attacked_text, i):
244-
"""Returns whether the word at index i differs from
251+
def ith_word_diff(self, other_attacked_text: AttackedText, i: int) -> bool:
252+
"""Returns bool representing whether the word at index i differs from
245253
other_attacked_text."""
246254
w1 = self.words
247255
w2 = other_attacked_text.words
248256
if len(w1) - 1 < i or len(w2) - 1 < i:
249257
return True
250258
return w1[i] != w2[i]
251259

252-
def words_diff_num(self, other_attacked_text):
260+
def words_diff_num(self, other_attacked_text: AttackedText) -> int:
261+
"""The number of words different between two AttackedText objects."""
253262
# using edit distance to calculate words diff num
254263
def generate_tokens(words):
255264
result = {}
@@ -295,7 +304,7 @@ def cal_dif(w1, w2):
295304
w2 = other_attacked_text.words
296305
return cal_dif(w1, w2)
297306

298-
def convert_from_original_idxs(self, idxs):
307+
def convert_from_original_idxs(self, idxs: Iterable[int]) -> List[int]:
299308
"""Takes indices of words from original string and converts them to
300309
indices of the same words in the current string.
301310
@@ -315,9 +324,16 @@ def convert_from_original_idxs(self, idxs):
315324

316325
return [self.attack_attrs["original_index_map"][i] for i in idxs]
317326

318-
def replace_words_at_indices(self, indices, new_words):
319-
"""This code returns a new AttackedText object where the word at
320-
``index`` is replaced with a new word."""
327+
def get_deletion_indices(self) -> Iterable[int]:
328+
return self.attack_attrs["original_index_map"][
329+
self.attack_attrs["original_index_map"] == -1
330+
]
331+
332+
def replace_words_at_indices(
333+
self, indices: Iterable[int], new_words: Iterable[str]
334+
) -> AttackedText:
335+
"""Returns a new AttackedText object where the word at ``index`` is
336+
replaced with a new word."""
321337
if len(indices) != len(new_words):
322338
raise ValueError(
323339
f"Cannot replace {len(new_words)} words at {len(indices)} indices."
@@ -333,21 +349,21 @@ def replace_words_at_indices(self, indices, new_words):
333349
words[i] = new_word
334350
return self.generate_new_attacked_text(words)
335351

336-
def replace_word_at_index(self, index, new_word):
337-
"""This code returns a new AttackedText object where the word at
338-
``index`` is replaced with a new word."""
352+
def replace_word_at_index(self, index: int, new_word: str) -> AttackedText:
353+
"""Returns a new AttackedText object where the word at ``index`` is
354+
replaced with a new word."""
339355
if not isinstance(new_word, str):
340356
raise TypeError(
341357
f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
342358
)
343359
return self.replace_words_at_indices([index], [new_word])
344360

345-
def delete_word_at_index(self, index):
346-
"""This code returns a new AttackedText object where the word at
347-
``index`` is removed."""
361+
def delete_word_at_index(self, index: int) -> AttackedText:
362+
"""Returns a new AttackedText object where the word at ``index`` is
363+
removed."""
348364
return self.replace_word_at_index(index, "")
349365

350-
def insert_text_after_word_index(self, index, text):
366+
def insert_text_after_word_index(self, index: int, text: str) -> AttackedText:
351367
"""Inserts a string before word at index ``index`` and attempts to add
352368
appropriate spacing."""
353369
if not isinstance(text, str):
@@ -356,7 +372,7 @@ def insert_text_after_word_index(self, index, text):
356372
new_text = " ".join((word_at_index, text))
357373
return self.replace_word_at_index(index, new_text)
358374

359-
def insert_text_before_word_index(self, index, text):
375+
def insert_text_before_word_index(self, index: int, text: str) -> AttackedText:
360376
"""Inserts a string before word at index ``index`` and attempts to add
361377
appropriate spacing."""
362378
if not isinstance(text, str):
@@ -367,12 +383,7 @@ def insert_text_before_word_index(self, index, text):
367383
new_text = " ".join((text, word_at_index))
368384
return self.replace_word_at_index(index, new_text)
369385

370-
def get_deletion_indices(self):
371-
return self.attack_attrs["original_index_map"][
372-
self.attack_attrs["original_index_map"] == -1
373-
]
374-
375-
def generate_new_attacked_text(self, new_words):
386+
def generate_new_attacked_text(self, new_words: Iterable[str]) -> AttackedText:
376387
"""Returns a new AttackedText object and replaces old list of words
377388
with a new list of words, but preserves the punctuation and spacing of
378389
the original message.
@@ -466,15 +477,17 @@ def generate_new_attacked_text(self, new_words):
466477
)
467478
return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)
468479

469-
def words_diff_ratio(self, x):
480+
def words_diff_ratio(self, x: AttackedText) -> float:
470481
"""Get the ratio of words difference between current text and `x`.
471482
472483
Note that current text and `x` must have same number of words.
473484
"""
474485
assert self.num_words == x.num_words
475486
return float(np.sum(self.words != x.words)) / self.num_words
476487

477-
def align_with_model_tokens(self, model_wrapper):
488+
def align_with_model_tokens(
489+
self, model_wrapper: textattack.models.wrappers.ModelWrapper
490+
) -> Dict[int, Iterable[int]]:
478491
"""Align AttackedText's `words` with target model's tokenization scheme
479492
(e.g. word, character, subword). Specifically, we map each word to list
480493
of indices of tokens that compose the word (e.g. embedding --> ["em",
@@ -511,7 +524,7 @@ def align_with_model_tokens(self, model_wrapper):
511524
return word2token_mapping
512525

513526
@property
514-
def tokenizer_input(self):
527+
def tokenizer_input(self) -> Tuple[str]:
515528
"""The tuple of inputs to be passed to the tokenizer."""
516529
input_tuple = tuple(self._text_input.values())
517530
# Prefer to return a string instead of a tuple with a single value.
@@ -521,15 +534,15 @@ def tokenizer_input(self):
521534
return input_tuple
522535

523536
@property
524-
def column_labels(self):
537+
def column_labels(self) -> List[str]:
525538
"""Returns the labels for this text's columns.
526539
527540
For single-sequence inputs, this simply returns ['text'].
528541
"""
529542
return list(self._text_input.keys())
530543

531544
@property
532-
def words_per_input(self):
545+
def words_per_input(self) -> List[List[str]]:
533546
"""Returns a list of lists of words corresponding to each input."""
534547
if not self._words_per_input:
535548
self._words_per_input = [
@@ -538,32 +551,32 @@ def words_per_input(self):
538551
return self._words_per_input
539552

540553
@property
541-
def words(self):
554+
def words(self) -> List[str]:
542555
if not self._words:
543556
self._words = words_from_text(self.text)
544557
return self._words
545558

546559
@property
547-
def text(self):
560+
def text(self) -> str:
548561
"""Represents full text input.
549562
550563
Multiply inputs are joined with a line break.
551564
"""
552565
return "\n".join(self._text_input.values())
553566

554567
@property
555-
def num_words(self):
568+
def num_words(self) -> int:
556569
"""Returns the number of words in the sequence."""
557570
return len(self.words)
558571

559572
@property
560-
def newly_swapped_words(self):
573+
def newly_swapped_words(self) -> List[str]:
561574
return [
562575
self.attack_attrs["prev_attacked_text"].words[i]
563576
for i in self.attack_attrs["newly_modified_indices"]
564577
]
565578

566-
def printable_text(self, key_color="bold", key_color_method=None):
579+
def printable_text(self, key_color="bold", key_color_method=None) -> str:
567580
"""Represents full text input. Adds field descriptions.
568581
569582
For example, entailment inputs look like:
@@ -595,5 +608,5 @@ def ck(k):
595608
for key, value in self._text_input.items()
596609
)
597610

598-
def __repr__(self):
611+
def __repr__(self) -> str:
599612
return f'<AttackedText "{self.text}">'

0 commit comments

Comments
 (0)