Skip to content

Commit 95cf68e

Browse files
committed
Fix 7 test failures: compatibility bugs and missing skip markers
Code fixes:
- Replace the removed transformers.optimization.AdamW with torch.optim.AdamW in trainer.py (removed in transformers>=4.x)
- Use AutoTokenizer/AutoModelForMaskedLM instead of BertTokenizer/BertForMaskedLM in ChineseWordSwapMaskedLM, since the default xlm-roberta-base model requires its own tokenizer
- Replace the hardcoded CUDA device in ChineseWordSwapMaskedLM with automatic device detection (CUDA when available, otherwise CPU)

Test fixes:
- Update the stale expected output for list_augmentation_recipes to include BackTranscriptionAugmenter
- Skip tests that require tensorflow_hub when it is not installed: call pytest.skip inside the attack test for interactive_mode and the adv_metrics attack cases, and add a pytest.mark.skipif marker to the train test
- Add a pytest.mark.skipif marker to test_embedding_gensim for when gensim is not installed
- Replace the deprecated gensim Word2VecKeyedVectors API with KeyedVectors
1 parent a3e64f5 commit 95cf68e

6 files changed

Lines changed: 27 additions & 10 deletions

File tree

tests/sample_outputs/list_augmentation_recipes.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
back_trans (textattack.augmentation.BackTranslationAugmenter)
2+
back_transcription (textattack.augmentation.BackTranscriptionAugmenter)
23
charswap (textattack.augmentation.CharSwapAugmenter)
34
checklist (textattack.augmentation.CheckListAugmenter)
45
clare (textattack.augmentation.CLAREAugmenter)

tests/test_command_line/test_attack.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import importlib
12
import pdb
23
import re
34

45
from helpers import run_command_and_get_result
56
import pytest
67

8+
_tensorflow_hub_available = importlib.util.find_spec("tensorflow_hub") is not None
9+
710
DEBUG = False
811
"""Attack command-line tests in the format (name, args, sample_output_file)"""
912

@@ -171,6 +174,9 @@
171174
@pytest.mark.slow
172175
def test_command_line_attack(name, command, sample_output_file):
173176
"""Runs attack tests and compares their outputs to a reference file."""
177+
_tf_hub_tests = {"interactive_mode", "attack_from_transformers_adv_metrics", "run_attack_hotflip_lstm_mr_4_adv_metrics"}
178+
if name in _tf_hub_tests and not _tensorflow_hub_available:
179+
pytest.skip("tensorflow_hub is not installed")
174180
# read in file and create regex
175181
desired_output = open(sample_output_file, "r").read().strip()
176182
print("desired_output.encoded =>", desired_output.encode())

tests/test_command_line/test_train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
import importlib
12
import os
23
import re
34

45
from helpers import run_command_and_get_result
6+
import pytest
57

8+
_tensorflow_hub_available = importlib.util.find_spec("tensorflow_hub") is not None
69

10+
11+
@pytest.mark.skipif(not _tensorflow_hub_available, reason="tensorflow_hub is not installed")
712
def test_train_tiny():
813
command = "textattack train --model distilbert-base-uncased --attack textfooler --dataset rotten_tomatoes --model-max-length 64 --num-epochs 1 --num-clean-epochs 0 --num-train-adv-examples 2"
914

tests/test_word_embedding.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import importlib
12
import os
23

34
import numpy as np
45
import pytest
56

67
from textattack.shared import GensimWordEmbedding, WordEmbedding
78

9+
_gensim_available = importlib.util.find_spec("gensim") is not None
10+
811

912
def test_embedding_paragramcf():
1013
word_embedding = WordEmbedding.counterfitted_GLOVE_embedding()
@@ -13,6 +16,7 @@ def test_embedding_paragramcf():
1316
assert word_embedding[10**9] is None
1417

1518

19+
@pytest.mark.skipif(not _gensim_available, reason="gensim is not installed")
1620
def test_embedding_gensim():
1721
# download a trained word2vec model
1822
from textattack.shared.utils import LazyLoader
@@ -30,10 +34,9 @@ def test_embedding_gensim():
3034
)
3135
f.close()
3236

33-
gensim = LazyLoader("gensim", globals(), "gensim")
34-
keyed_vectors = (
35-
gensim.models.keyedvectors.Word2VecKeyedVectors.load_word2vec_format(path)
36-
)
37+
from gensim.models import KeyedVectors
38+
39+
keyed_vectors = KeyedVectors.load_word2vec_format(path)
3740
word_embedding = GensimWordEmbedding(keyed_vectors)
3841
assert pytest.approx(word_embedding[0][0]) == 1
3942
assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2)

textattack/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def get_optimizer_and_scheduler(self, model, num_training_steps):
361361
},
362362
]
363363

364-
optimizer = transformers.optimization.AdamW(
364+
optimizer = torch.optim.AdamW(
365365
optimizer_grouped_parameters, lr=self.training_args.learning_rate
366366
)
367367
if isinstance(self.training_args.num_warmup_steps, float):

textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ class ChineseWordSwapMaskedLM(WordSwap):
1313
model."""
1414

1515
def __init__(self, task="fill-mask", model="xlm-roberta-base", **kwargs):
16-
from transformers import BertForMaskedLM, BertTokenizer
16+
from transformers import AutoModelForMaskedLM, AutoTokenizer
1717

18-
self.tt = BertTokenizer.from_pretrained(model)
19-
self.mm = BertForMaskedLM.from_pretrained(model)
20-
self.mm.to("cuda")
18+
self.tt = AutoTokenizer.from_pretrained(model)
19+
self.mm = AutoModelForMaskedLM.from_pretrained(model)
20+
device = "cuda" if torch.cuda.is_available() else "cpu"
21+
self.mm.to(device)
22+
self._device = device
2123
super().__init__(**kwargs)
2224

2325
def get_replacement_words(self, current_text, indice_to_modify):
@@ -26,7 +28,7 @@ def get_replacement_words(self, current_text, indice_to_modify):
2628
) # 修改前<mask>,xlmrberta的模型
2729
tokens = self.tt.tokenize(masked_text.text)
2830
input_ids = self.tt.convert_tokens_to_ids(tokens)
29-
input_tensor = torch.tensor([input_ids]).to("cuda")
31+
input_tensor = torch.tensor([input_ids]).to(self._device)
3032
with torch.no_grad():
3133
outputs = self.mm(input_tensor)
3234
predictions = outputs.logits

0 commit comments

Comments
 (0)