@@ -82,7 +82,11 @@ def __eq__(self, other: AttackedText) -> bool:
8282 and it's actually much faster (cache-wise) to just compare
8383 by the text, and this works for lots of use cases.
8484 """
85- return self .text == other .text
85+ if not (self .text == other .text ):
86+ return False
87+ if len (self .attack_attrs ) != len (other .attack_attrs ):
88+ return False
89+ return True
8690
8791 def __hash__ (self ) -> int :
8892 return hash (self .text )
@@ -466,9 +470,6 @@ def generate_new_attacked_text(self, new_words: Iterable[str]) -> AttackedText:
466470 perturbed_text += adv_word_seq
467471 perturbed_text += original_text # Add all of the ending punctuation.
468472
469- # Add pointer to self so chain of replacements can be reconstructed.
470- new_attack_attrs ["prev_attacked_text" ] = self
471-
472473 # Reform perturbed_text into an OrderedDict.
473474 perturbed_input_texts = perturbed_text .split (AttackedText .SPLIT_TOKEN )
474475 perturbed_input = OrderedDict (
@@ -570,7 +571,10 @@ def num_words(self) -> int:
570571
571572 @property
572573 def newly_swapped_words (self ) -> List [str ]:
573- return [self .words [i ] for i in self .attack_attrs ["newly_modified_indices" ]]
574+ return [
575+ self .attack_attrs ["prev_attacked_text" ].words [i ]
576+ for i in self .attack_attrs ["newly_modified_indices" ]
577+ ]
574578
575579 def printable_text (self , key_color = "bold" , key_color_method = None ) -> str :
576580 """Represents full text input. Adds field descriptions.
0 commit comments