66A helper class that represents a string that can be attacked.
77"""
88
9+ from __future__ import annotations
10+
911from collections import OrderedDict
1012import math
from typing import Dict, Iterable, List, Optional, Set, Tuple
1114
1215import flair
1316from flair .data import Sentence
@@ -71,31 +74,17 @@ def __init__(self, text_input, attack_attrs=None):
7174 # A list of all indices in *this* text that have been modified.
7275 self .attack_attrs .setdefault ("modified_indices" , set ())
7376
def __eq__(self, other: object) -> bool:
    """Compares two AttackedText instances by their text only.

    Note: Does not compute true equality across attack attributes.
    Comparing ``attack_attrs`` caused large performance issues with
    caching, and it is much faster (cache-wise) to just compare by the
    text, which works for lots of use cases.

    Returns:
        ``True`` if both objects expose the same ``text``, ``False`` if the
        texts differ, and ``NotImplemented`` when ``other`` has no ``text``
        attribute, so Python falls back to its default comparison instead
        of raising ``AttributeError``.
    """
    try:
        other_text = other.text
    except AttributeError:
        # Not text-like (e.g. ``None`` or a plain ``str``): defer to Python.
        return NotImplemented
    return self.text == other_text
9786
def __hash__(self) -> int:
    """Hashes the instance by its ``text`` alone."""
    text = self.text
    return hash(text)
10089
10190 def free_memory (self ):
@@ -113,7 +102,7 @@ def free_memory(self):
113102 if isinstance (self .attack_attrs [key ], torch .Tensor ):
114103 self .attack_attrs .pop (key , None )
115104
116- def text_window_around_index (self , index , window_size ) :
105+ def text_window_around_index (self , index : int , window_size : int ) -> str :
117106 """The text window of ``window_size`` words centered around
118107 ``index``."""
119108 length = self .num_words
@@ -131,10 +120,12 @@ def text_window_around_index(self, index, window_size):
131120 text_idx_end = self ._text_index_of_word_index (end ) + len (self .words [end ])
132121 return self .text [text_idx_start :text_idx_end ]
133122
134- def pos_of_word_index (self , desired_word_idx ) :
123+ def pos_of_word_index (self , desired_word_idx : int ) -> str :
135124 """Returns the part-of-speech of the word at index `word_idx`.
136125
137126 Uses FLAIR part-of-speech tagger.
127+
128+ Throws: ValueError, if no POS tag found for index.
138129 """
139130 if not self ._pos_tags :
140131 sentence = Sentence (
@@ -162,10 +153,12 @@ def pos_of_word_index(self, desired_word_idx):
162153 f"Did not find word from index { desired_word_idx } in flair POS tag"
163154 )
164155
165- def ner_of_word_index (self , desired_word_idx , model_name = "ner" ):
156+ def ner_of_word_index (self , desired_word_idx : int , model_name = "ner" ) -> str :
166157 """Returns the ner tag of the word at index `word_idx`.
167158
168159 Uses FLAIR ner tagger.
160+
161+ Throws: ValueError, if not NER tag found for index.
169162 """
170163 if not self ._ner_tags :
171164 sentence = Sentence (
@@ -190,7 +183,7 @@ def ner_of_word_index(self, desired_word_idx, model_name="ner"):
190183 f"Did not find word from index { desired_word_idx } in flair POS tag"
191184 )
192185
193- def _text_index_of_word_index (self , i ) :
186+ def _text_index_of_word_index (self , i : int ) -> int :
194187 """Returns the index of word ``i`` in self.text."""
195188 pre_words = self .words [: i + 1 ]
196189 lower_text = self .text .lower ()
@@ -203,20 +196,20 @@ def _text_index_of_word_index(self, i):
203196 look_after_index -= len (self .words [i ])
204197 return look_after_index
205198
def text_until_word_index(self, i: int) -> str:
    """Returns the portion of ``self.text`` that precedes word ``i``.

    The cut point is the character offset reported by
    ``self._text_index_of_word_index(i)``.
    """
    word_start = self._text_index_of_word_index(i)
    prefix = self.text[:word_start]
    return prefix
210203
def text_after_word_index(self, i: int) -> str:
    """Returns the portion of ``self.text`` that follows word ``i``."""
    # Start offset of the word, then skip past its characters.
    word_start = self._text_index_of_word_index(i)
    word_end = word_start + len(self.words[i])
    suffix = self.text[word_end:]
    return suffix
216209
217- def first_word_diff (self , other_attacked_text ) :
210+ def first_word_diff (self , other_attacked_text : AttackedText ) -> Optional [ str ] :
218211 """Returns the first word in self.words that differs from
219- other_attacked_text.
212+ other_attacked_text, or None if all words are the same .
220213
221214 Useful for word swap strategies.
222215 """
@@ -227,7 +220,7 @@ def first_word_diff(self, other_attacked_text):
227220 return w1 [i ]
228221 return None
229222
230- def first_word_diff_index (self , other_attacked_text ) :
223+ def first_word_diff_index (self , other_attacked_text : AttackedText ) -> Optional [ int ] :
231224 """Returns the index of the first word in self.words that differs from
232225 other_attacked_text.
233226
@@ -240,7 +233,7 @@ def first_word_diff_index(self, other_attacked_text):
240233 return i
241234 return None
242235
243- def all_words_diff (self , other_attacked_text ) :
236+ def all_words_diff (self , other_attacked_text : AttackedText ) -> Set [ int ] :
244237 """Returns the set of indices for which this and other_attacked_text
245238 have different words."""
246239 indices = set ()
@@ -251,16 +244,17 @@ def all_words_diff(self, other_attacked_text):
251244 indices .add (i )
252245 return indices
253246
def ith_word_diff(self, other_attacked_text: AttackedText, i: int) -> bool:
    """Returns bool representing whether the word at index ``i`` differs
    from the word at the same index in ``other_attacked_text``.

    An index past the end of either word list counts as a difference.
    """
    mine = self.words
    theirs = other_attacked_text.words
    shorter = min(len(mine), len(theirs))
    if i > shorter - 1:
        return True
    return mine[i] != theirs[i]
262255
263- def words_diff_num (self , other_attacked_text ):
256+ def words_diff_num (self , other_attacked_text : AttackedText ) -> int :
257+ """The number of words different between two AttackedText objects."""
264258 # using edit distance to calculate words diff num
265259 def generate_tokens (words ):
266260 result = {}
@@ -306,7 +300,7 @@ def cal_dif(w1, w2):
306300 w2 = other_attacked_text .words
307301 return cal_dif (w1 , w2 )
308302
309- def convert_from_original_idxs (self , idxs ) :
303+ def convert_from_original_idxs (self , idxs : Iterable [ int ]) -> List [ int ] :
310304 """Takes indices of words from original string and converts them to
311305 indices of the same words in the current string.
312306
@@ -326,9 +320,16 @@ def convert_from_original_idxs(self, idxs):
326320
327321 return [self .attack_attrs ["original_index_map" ][i ] for i in idxs ]
328322
329- def replace_words_at_indices (self , indices , new_words ):
330- """This code returns a new AttackedText object where the word at
331- ``index`` is replaced with a new word."""
def get_deletion_indices(self) -> Iterable[int]:
    """Returns the positions in ``attack_attrs["original_index_map"]`` whose
    entry equals ``-1``.

    Indexing the map with its own ``== -1`` mask would return the masked
    *values* (an array of ``-1``), not their positions, so the positions
    are computed explicitly. The number of returned elements is the same
    either way.

    NOTE(review): ``-1`` presumably marks an original word that has been
    deleted -- confirm against the code that builds ``original_index_map``.
    """
    index_map = np.asarray(self.attack_attrs["original_index_map"])
    return np.flatnonzero(index_map == -1)
327+
328+ def replace_words_at_indices (
329+ self , indices : Iterable [int ], new_words : Iterable [str ]
330+ ) -> AttackedText :
331+ """Returns a new AttackedText object where the word at ``index`` is
332+ replaced with a new word."""
332333 if len (indices ) != len (new_words ):
333334 raise ValueError (
334335 f"Cannot replace { len (new_words )} words at { len (indices )} indices."
@@ -344,21 +345,21 @@ def replace_words_at_indices(self, indices, new_words):
344345 words [i ] = new_word
345346 return self .generate_new_attacked_text (words )
346347
def replace_word_at_index(self, index: int, new_word: str) -> AttackedText:
    """Returns a new AttackedText object where the word at ``index`` is
    replaced with ``new_word``.

    Delegates to ``replace_words_at_indices`` with single-element lists.

    Raises:
        TypeError: if ``new_word`` is not a ``str``.
    """
    if isinstance(new_word, str):
        return self.replace_words_at_indices([index], [new_word])
    raise TypeError(
        f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
    )
355356
def delete_word_at_index(self, index: int) -> AttackedText:
    """Returns a new AttackedText object where the word at ``index`` is
    removed.

    Removal is expressed as replacing the word with the empty string.
    """
    empty_word = ""
    return self.replace_word_at_index(index, empty_word)
360361
361- def insert_text_after_word_index (self , index , text ) :
362+ def insert_text_after_word_index (self , index : int , text : str ) -> AttackedText :
362363 """Inserts a string before word at index ``index`` and attempts to add
363364 appropriate spacing."""
364365 if not isinstance (text , str ):
@@ -367,7 +368,7 @@ def insert_text_after_word_index(self, index, text):
367368 new_text = " " .join ((word_at_index , text ))
368369 return self .replace_word_at_index (index , new_text )
369370
370- def insert_text_before_word_index (self , index , text ) :
371+ def insert_text_before_word_index (self , index : int , text : str ) -> AttackedText :
371372 """Inserts a string before word at index ``index`` and attempts to add
372373 appropriate spacing."""
373374 if not isinstance (text , str ):
@@ -378,12 +379,7 @@ def insert_text_before_word_index(self, index, text):
378379 new_text = " " .join ((text , word_at_index ))
379380 return self .replace_word_at_index (index , new_text )
380381
381- def get_deletion_indices (self ):
382- return self .attack_attrs ["original_index_map" ][
383- self .attack_attrs ["original_index_map" ] == - 1
384- ]
385-
386- def generate_new_attacked_text (self , new_words ):
382+ def generate_new_attacked_text (self , new_words : Iterable [str ]) -> AttackedText :
387383 """Returns a new AttackedText object and replaces old list of words
388384 with a new list of words, but preserves the punctuation and spacing of
389385 the original message.
@@ -480,15 +476,17 @@ def generate_new_attacked_text(self, new_words):
480476 )
481477 return AttackedText (perturbed_input , attack_attrs = new_attack_attrs )
482478
def words_diff_ratio(self, x: AttackedText) -> float:
    """Get the ratio of words that differ between current text and `x`.

    Words are compared position by position. Comparing the two word lists
    with a single ``!=`` would give one boolean for the whole list, so the
    differing positions are counted individually.

    Note that current text and `x` must have same number of words.

    Returns:
        The number of differing positions divided by ``self.num_words``,
        or ``0.0`` when there are no words to compare.
    """
    assert self.num_words == x.num_words
    total = self.num_words
    if total == 0:
        # Nothing to compare; avoid dividing by zero.
        return 0.0
    differing = sum(1 for mine, theirs in zip(self.words, x.words) if mine != theirs)
    return float(differing) / total
490486
491- def align_with_model_tokens (self , model_wrapper ):
487+ def align_with_model_tokens (
488+ self , model_wrapper : textattack .models .wrappers .ModelWrapper
489+ ) -> Dict [int , Iterable [int ]]:
492490 """Align AttackedText's `words` with target model's tokenization scheme
493491 (e.g. word, character, subword). Specifically, we map each word to list
494492 of indices of tokens that compose the word (e.g. embedding --> ["em",
@@ -525,7 +523,7 @@ def align_with_model_tokens(self, model_wrapper):
525523 return word2token_mapping
526524
527525 @property
528- def tokenizer_input (self ):
526+ def tokenizer_input (self ) -> Tuple [ str ] :
529527 """The tuple of inputs to be passed to the tokenizer."""
530528 input_tuple = tuple (self ._text_input .values ())
531529 # Prefer to return a string instead of a tuple with a single value.
@@ -535,15 +533,15 @@ def tokenizer_input(self):
535533 return input_tuple
536534
@property
def column_labels(self) -> List[str]:
    """Returns the labels for this text's columns, i.e. the keys of
    ``self._text_input`` in order.

    For single-sequence inputs, this simply returns ['text'].
    """
    labels = list(self._text_input)
    return labels
544542
545543 @property
546- def words_per_input (self ):
544+ def words_per_input (self ) -> List [ List [ str ]] :
547545 """Returns a list of lists of words corresponding to each input."""
548546 if not self ._words_per_input :
549547 self ._words_per_input = [
@@ -552,29 +550,29 @@ def words_per_input(self):
552550 return self ._words_per_input
553551
@property
def words(self) -> List[str]:
    """Returns the words of ``self.text``, cached on ``self._words``.

    The list is computed with ``words_from_text`` whenever ``self._words``
    is falsy, so an empty result is recomputed on each access.
    """
    cached = self._words
    if cached:
        return cached
    self._words = words_from_text(self.text)
    return self._words
559557
@property
def text(self) -> str:
    """Represents full text input.

    Multiple inputs are joined with a line break.
    """
    parts = self._text_input.values()
    return "\n".join(parts)
567565
@property
def num_words(self) -> int:
    """Returns the length of ``self.words``."""
    word_list = self.words
    return len(word_list)
572570
@property
def newly_swapped_words(self) -> List[str]:
    """Returns the words found at each index listed in
    ``self.attack_attrs["newly_modified_indices"]``, in iteration order."""
    swapped = []
    for idx in self.attack_attrs["newly_modified_indices"]:
        swapped.append(self.words[idx])
    return swapped
576574
577- def printable_text (self , key_color = "bold" , key_color_method = None ):
575+ def printable_text (self , key_color = "bold" , key_color_method = None ) -> str :
578576 """Represents full text input. Adds field descriptions.
579577
580578 For example, entailment inputs look like:
@@ -606,5 +604,5 @@ def ck(k):
606604 for key , value in self ._text_input .items ()
607605 )
608606
def __repr__(self) -> str:
    """Returns a short debug representation that embeds ``self.text``."""
    template = '<AttackedText "{}">'
    return template.format(self.text)
0 commit comments