66A helper class that represents a string that can be attacked.
77"""
88
9+ from __future__ import annotations
10+
911from collections import OrderedDict
1012import math
13+ from typing import Dict , Iterable , List , Optional , Set , Tuple
1114
1215import flair
1316from flair .data import Sentence
@@ -71,20 +74,21 @@ def __init__(self, text_input, attack_attrs=None):
7174 # A list of all indices in *this* text that have been modified.
7275 self .attack_attrs .setdefault ("modified_indices" , set ())
7376
def __eq__(self, other: AttackedText) -> bool:
    """Compares two AttackedText instances.

    Two instances compare equal when their ``text`` matches and their
    ``attack_attrs`` dictionaries hold the same number of entries.

    Note: Does not compute true equality across attack attributes.
    We found this caused large performance issues with caching,
    and it's actually much faster (cache-wise) to just compare
    by the text, and this works for lots of use cases.

    Returns ``NotImplemented`` when ``other`` is not an instance of this
    object's class, so that ``==`` against unrelated objects (``None``,
    plain strings, ...) evaluates to ``False`` instead of failing on a
    missing ``text`` attribute.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    if self.text != other.text:
        return False
    # Only the *number* of attack attributes is compared, never their values.
    return len(self.attack_attrs) == len(other.attack_attrs)
8690
def __hash__(self) -> int:
    """Hashes the instance by its full text only.

    Attack attributes do not take part in the hash.
    """
    text_hash = hash(self.text)
    return text_hash
8993
9094 def free_memory (self ):
@@ -102,7 +106,7 @@ def free_memory(self):
102106 if isinstance (self .attack_attrs [key ], torch .Tensor ):
103107 self .attack_attrs .pop (key , None )
104108
105- def text_window_around_index (self , index , window_size ) :
109+ def text_window_around_index (self , index : int , window_size : int ) -> str :
106110 """The text window of ``window_size`` words centered around
107111 ``index``."""
108112 length = self .num_words
@@ -120,10 +124,12 @@ def text_window_around_index(self, index, window_size):
120124 text_idx_end = self ._text_index_of_word_index (end ) + len (self .words [end ])
121125 return self .text [text_idx_start :text_idx_end ]
122126
123- def pos_of_word_index (self , desired_word_idx ) :
127+ def pos_of_word_index (self , desired_word_idx : int ) -> str :
124128 """Returns the part-of-speech of the word at index `word_idx`.
125129
126130 Uses FLAIR part-of-speech tagger.
131+
132+ Throws: ValueError, if no POS tag found for index.
127133 """
128134 if not self ._pos_tags :
129135 sentence = Sentence (
@@ -151,10 +157,12 @@ def pos_of_word_index(self, desired_word_idx):
151157 f"Did not find word from index { desired_word_idx } in flair POS tag"
152158 )
153159
154- def ner_of_word_index (self , desired_word_idx , model_name = "ner" ):
160+ def ner_of_word_index (self , desired_word_idx : int , model_name = "ner" ) -> str :
155161 """Returns the ner tag of the word at index `word_idx`.
156162
157163 Uses FLAIR ner tagger.
164+
165+ Throws: ValueError, if no NER tag found for index.
158166 """
159167 if not self ._ner_tags :
160168 sentence = Sentence (
@@ -179,7 +187,7 @@ def ner_of_word_index(self, desired_word_idx, model_name="ner"):
179187 f"Did not find word from index { desired_word_idx } in flair POS tag"
180188 )
181189
182- def _text_index_of_word_index (self , i ) :
190+ def _text_index_of_word_index (self , i : int ) -> int :
183191 """Returns the index of word ``i`` in self.text."""
184192 pre_words = self .words [: i + 1 ]
185193 lower_text = self .text .lower ()
@@ -192,20 +200,20 @@ def _text_index_of_word_index(self, i):
192200 look_after_index -= len (self .words [i ])
193201 return look_after_index
194202
def text_until_word_index(self, i: int) -> str:
    """Returns the part of ``self.text`` that precedes word ``i``.

    The cut point is the character offset that
    ``_text_index_of_word_index`` reports for that word.
    """
    word_start = self._text_index_of_word_index(i)
    leading_text = self.text[0:word_start]
    return leading_text
199207
def text_after_word_index(self, i: int) -> str:
    """Returns the part of ``self.text`` that follows the end of word ``i``."""
    # Offset where the word begins, then skip over the word itself.
    word_start = self._text_index_of_word_index(i)
    word_end = word_start + len(self.words[i])
    trailing_text = self.text[word_end:]
    return trailing_text
205213
206- def first_word_diff (self , other_attacked_text ) :
214+ def first_word_diff (self , other_attacked_text : AttackedText ) -> Optional [ str ] :
207215 """Returns the first word in self.words that differs from
208- other_attacked_text.
216+ other_attacked_text, or None if all words are the same .
209217
210218 Useful for word swap strategies.
211219 """
@@ -216,7 +224,7 @@ def first_word_diff(self, other_attacked_text):
216224 return w1 [i ]
217225 return None
218226
219- def first_word_diff_index (self , other_attacked_text ) :
227+ def first_word_diff_index (self , other_attacked_text : AttackedText ) -> Optional [ int ] :
220228 """Returns the index of the first word in self.words that differs from
221229 other_attacked_text.
222230
@@ -229,7 +237,7 @@ def first_word_diff_index(self, other_attacked_text):
229237 return i
230238 return None
231239
232- def all_words_diff (self , other_attacked_text ) :
240+ def all_words_diff (self , other_attacked_text : AttackedText ) -> Set [ int ] :
233241 """Returns the set of indices for which this and other_attacked_text
234242 have different words."""
235243 indices = set ()
@@ -240,16 +248,17 @@ def all_words_diff(self, other_attacked_text):
240248 indices .add (i )
241249 return indices
242250
def ith_word_diff(self, other_attacked_text: AttackedText, i: int) -> bool:
    """Returns bool representing whether the word at index ``i`` differs
    between this text and ``other_attacked_text``.

    An index past the last word of either text is reported as a difference.
    """
    own_words = self.words
    other_words = other_attacked_text.words
    shortest = min(len(own_words), len(other_words))
    # ``i`` falls outside at least one of the two word lists.
    if i > shortest - 1:
        return True
    return own_words[i] != other_words[i]
251259
252- def words_diff_num (self , other_attacked_text ):
260+ def words_diff_num (self , other_attacked_text : AttackedText ) -> int :
261+ """The number of words different between two AttackedText objects."""
253262 # using edit distance to calculate words diff num
254263 def generate_tokens (words ):
255264 result = {}
@@ -295,7 +304,7 @@ def cal_dif(w1, w2):
295304 w2 = other_attacked_text .words
296305 return cal_dif (w1 , w2 )
297306
298- def convert_from_original_idxs (self , idxs ) :
307+ def convert_from_original_idxs (self , idxs : Iterable [ int ]) -> List [ int ] :
299308 """Takes indices of words from original string and converts them to
300309 indices of the same words in the current string.
301310
@@ -315,9 +324,16 @@ def convert_from_original_idxs(self, idxs):
315324
316325 return [self .attack_attrs ["original_index_map" ][i ] for i in idxs ]
317326
318- def replace_words_at_indices (self , indices , new_words ):
319- """This code returns a new AttackedText object where the word at
320- ``index`` is replaced with a new word."""
def get_deletion_indices(self) -> Iterable[int]:
    """Selects the entries of ``attack_attrs["original_index_map"]`` that
    equal ``-1``.

    NOTE(review): presumably the map is a numpy array, in which case the
    comparison builds a boolean mask and the result holds the matching
    ``-1`` values themselves rather than their positions -- confirm that
    this is what callers expect.
    """
    index_map = self.attack_attrs["original_index_map"]
    deleted_mask = index_map == -1
    return index_map[deleted_mask]
331+
332+ def replace_words_at_indices (
333+ self , indices : Iterable [int ], new_words : Iterable [str ]
334+ ) -> AttackedText :
335+ """Returns a new AttackedText object where the word at ``index`` is
336+ replaced with a new word."""
321337 if len (indices ) != len (new_words ):
322338 raise ValueError (
323339 f"Cannot replace { len (new_words )} words at { len (indices )} indices."
@@ -333,21 +349,21 @@ def replace_words_at_indices(self, indices, new_words):
333349 words [i ] = new_word
334350 return self .generate_new_attacked_text (words )
335351
def replace_word_at_index(self, index: int, new_word: str) -> AttackedText:
    """Replaces the single word at ``index`` with ``new_word``.

    Delegates to ``replace_words_at_indices`` with one-element lists and
    returns its result.

    Throws: TypeError, if ``new_word`` is not a ``str``.
    """
    if isinstance(new_word, str):
        return self.replace_words_at_indices([index], [new_word])
    # NOTE(review): the message is reproduced exactly as it appears in this
    # view, including the space before the closing quote -- confirm whether
    # that space exists in the real file.
    raise TypeError(
        f"replace_word_at_index requires ``str`` new_word, got {type(new_word)} "
    )
344360
def delete_word_at_index(self, index: int) -> AttackedText:
    """Removes the word at ``index`` by replacing it with the empty string.

    Returns whatever ``replace_word_at_index`` returns for that replacement.
    """
    empty_replacement = ""
    return self.replace_word_at_index(index, empty_replacement)
349365
350- def insert_text_after_word_index (self , index , text ) :
366+ def insert_text_after_word_index (self , index : int , text : str ) -> AttackedText :
351367 """Inserts a string after word at index ``index`` and attempts to add
352368 appropriate spacing."""
353369 if not isinstance (text , str ):
@@ -356,7 +372,7 @@ def insert_text_after_word_index(self, index, text):
356372 new_text = " " .join ((word_at_index , text ))
357373 return self .replace_word_at_index (index , new_text )
358374
359- def insert_text_before_word_index (self , index , text ) :
375+ def insert_text_before_word_index (self , index : int , text : str ) -> AttackedText :
360376 """Inserts a string before word at index ``index`` and attempts to add
361377 appropriate spacing."""
362378 if not isinstance (text , str ):
@@ -367,12 +383,7 @@ def insert_text_before_word_index(self, index, text):
367383 new_text = " " .join ((text , word_at_index ))
368384 return self .replace_word_at_index (index , new_text )
369385
370- def get_deletion_indices (self ):
371- return self .attack_attrs ["original_index_map" ][
372- self .attack_attrs ["original_index_map" ] == - 1
373- ]
374-
375- def generate_new_attacked_text (self , new_words ):
386+ def generate_new_attacked_text (self , new_words : Iterable [str ]) -> AttackedText :
376387 """Returns a new AttackedText object and replaces old list of words
377388 with a new list of words, but preserves the punctuation and spacing of
378389 the original message.
@@ -466,15 +477,17 @@ def generate_new_attacked_text(self, new_words):
466477 )
467478 return AttackedText (perturbed_input , attack_attrs = new_attack_attrs )
468479
def words_diff_ratio(self, x: AttackedText) -> float:
    """Get the ratio of words difference between current text and `x`.

    The ratio is the number of word positions at which the two texts
    differ, divided by the number of words. Two texts without any words
    have a ratio of ``0.0``.

    Note that current text and `x` must have same number of words.
    """
    assert self.num_words == x.num_words
    if self.num_words == 0:
        # Nothing to compare; also avoids dividing by zero.
        return 0.0
    # Compare position by position: ``self.words != x.words`` on two lists
    # yields one bool for the whole list, not a per-word result.
    differing = sum(own != other for own, other in zip(self.words, x.words))
    return float(differing) / self.num_words
476487
477- def align_with_model_tokens (self , model_wrapper ):
488+ def align_with_model_tokens (
489+ self , model_wrapper : textattack .models .wrappers .ModelWrapper
490+ ) -> Dict [int , Iterable [int ]]:
478491 """Align AttackedText's `words` with target model's tokenization scheme
479492 (e.g. word, character, subword). Specifically, we map each word to list
480493 of indices of tokens that compose the word (e.g. embedding --> ["em",
@@ -511,7 +524,7 @@ def align_with_model_tokens(self, model_wrapper):
511524 return word2token_mapping
512525
513526 @property
514- def tokenizer_input (self ):
527+ def tokenizer_input (self ) -> Tuple [ str ] :
515528 """The tuple of inputs to be passed to the tokenizer."""
516529 input_tuple = tuple (self ._text_input .values ())
517530 # Prefer to return a string instead of a tuple with a single value.
@@ -521,15 +534,15 @@ def tokenizer_input(self):
521534 return input_tuple
522535
@property
def column_labels(self) -> List[str]:
    """Returns the labels for this text's columns.

    The labels are the keys of ``self._text_input``, in iteration order
    (presumably ``['text']`` for single-sequence inputs).
    """
    labels = [key for key in self._text_input]
    return labels
530543
531544 @property
532- def words_per_input (self ):
545+ def words_per_input (self ) -> List [ List [ str ]] :
533546 """Returns a list of lists of words corresponding to each input."""
534547 if not self ._words_per_input :
535548 self ._words_per_input = [
@@ -538,32 +551,32 @@ def words_per_input(self):
538551 return self ._words_per_input
539552
@property
def words(self) -> List[str]:
    """The words of ``self.text``, computed with ``words_from_text`` and
    cached in ``self._words``.

    The cached value is reused only while it is truthy; a missing or empty
    cache triggers a fresh ``words_from_text`` call.
    """
    if self._words:
        return self._words
    self._words = words_from_text(self.text)
    return self._words
545558
@property
def text(self) -> str:
    """Represents full text input.

    Multiple inputs are joined with a line break, using the values of
    ``self._text_input`` in iteration order.
    """
    # NOTE(review): the separator is reproduced exactly as it appears in this
    # view (line break followed by a space) -- confirm whether the space
    # exists in the real file.
    separator = "\n "
    return separator.join(self._text_input.values())
553566
@property
def num_words(self) -> int:
    """Returns the number of entries in ``self.words``."""
    word_list = self.words
    return len(word_list)
558571
@property
def newly_swapped_words(self) -> List[str]:
    """Returns the words of ``attack_attrs["prev_attacked_text"]`` found at
    each index in ``attack_attrs["newly_modified_indices"]``.

    The result follows the iteration order of the index collection.
    """
    swapped = []
    for idx in self.attack_attrs["newly_modified_indices"]:
        # Looked up per index on purpose: with no indices to visit, the
        # "prev_attacked_text" entry is never touched.
        swapped.append(self.attack_attrs["prev_attacked_text"].words[idx])
    return swapped
565578
566- def printable_text (self , key_color = "bold" , key_color_method = None ):
579+ def printable_text (self , key_color = "bold" , key_color_method = None ) -> str :
567580 """Represents full text input. Adds field descriptions.
568581
569582 For example, entailment inputs look like:
@@ -595,5 +608,5 @@ def ck(k):
595608 for key , value in self ._text_input .items ()
596609 )
597610
def __repr__(self) -> str:
    """Returns a short debug representation that embeds the full text."""
    # NOTE(review): the literal is reproduced exactly as it appears in this
    # view, including the space before the closing double quote -- confirm
    # whether that space exists in the real file.
    full_text = self.text
    return f'<AttackedText "{full_text} ">'
0 commit comments