Skip to content

Commit 6e7fda5

Browse files
authored
Merge pull request #644 from QData/fixes-and-improvements
Bug fixes and tweaks
2 parents b433fcb + 9b20928 commit 6e7fda5

22 files changed

Lines changed: 84 additions & 34 deletions

File tree

.github/workflows/check-formatting.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ jobs:
2626
- name: Install dependencies
2727
run: |
2828
python -m pip install --upgrade pip setuptools wheel
29-
pip install black flake8 isort # Testing packages
3029
python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
3130
pip install -e .[dev]
31+
pip install black flake8 isort --upgrade # Testing packages
3232
- name: Check code format with black and isort
3333
run: |
3434
make lint

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ Follow these steps to start contributing:
118118
```bash
119119
$ cd TextAttack
120120
$ pip install -e . ".[dev]"
121-
$ pip install black isort pytest pytest-xdist
121+
$ pip install black docformatter isort pytest pytest-xdist
122122
```
123123

124124
This will install `textattack` in editable mode and install `black` and

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
author = "UVA QData Lab"
2222

2323
# The full version, including alpha/beta/rc tags
24-
release = "0.3.4"
24+
release = "0.3.5"
2525

2626
# Set master doc to `index.rst`.
2727
master_doc = "index"

tests/test_attacked_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_window_around_index(self, attacked_text):
7070

7171
def test_big_window_around_index(self, attacked_text):
7272
assert (
73-
attacked_text.text_window_around_index(0, 10 ** 5) + "."
73+
attacked_text.text_window_around_index(0, 10**5) + "."
7474
) == attacked_text.text
7575

7676
def test_window_around_index_start(self, attacked_text):

tests/test_word_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_embedding_paragramcf():
1010
word_embedding = WordEmbedding.counterfitted_GLOVE_embedding()
1111
assert pytest.approx(word_embedding[0][0]) == -0.022007
1212
assert pytest.approx(word_embedding["fawn"][0]) == -0.022007
13-
assert word_embedding[10 ** 9] is None
13+
assert word_embedding[10**9] is None
1414

1515

1616
def test_embedding_gensim():
@@ -37,7 +37,7 @@ def test_embedding_gensim():
3737
word_embedding = GensimWordEmbedding(keyed_vectors)
3838
assert pytest.approx(word_embedding[0][0]) == 1
3939
assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2)
40-
assert word_embedding[10 ** 9] is None
40+
assert word_embedding[10**9] is None
4141

4242
# test query functionality
4343
assert pytest.approx(word_embedding.get_cos_sim(1, 3)) == 0

textattack/attack.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def __init__(
8181
constraints: List[Union[Constraint, PreTransformationConstraint]],
8282
transformation: Transformation,
8383
search_method: SearchMethod,
84-
transformation_cache_size=2 ** 15,
85-
constraint_cache_size=2 ** 15,
84+
transformation_cache_size=2**15,
85+
constraint_cache_size=2**15,
8686
):
8787
"""Initialize an attack object.
8888
@@ -371,22 +371,23 @@ def _attack(self, initial_result):
371371
final_result = self.search_method(initial_result)
372372
self.clear_cache()
373373
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
374-
return SuccessfulAttackResult(
374+
result = SuccessfulAttackResult(
375375
initial_result,
376376
final_result,
377377
)
378378
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
379-
return FailedAttackResult(
379+
result = FailedAttackResult(
380380
initial_result,
381381
final_result,
382382
)
383383
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
384-
return MaximizedAttackResult(
384+
result = MaximizedAttackResult(
385385
initial_result,
386386
final_result,
387387
)
388388
else:
389389
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
390+
return result
390391

391392
def attack(self, example, ground_truth_output):
392393
"""Attack a single example.

textattack/attack_args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,8 +478,8 @@ class _CommandLineAttackArgs:
478478
interactive: bool = False
479479
parallel: bool = False
480480
model_batch_size: int = 32
481-
model_cache_size: int = 2 ** 18
482-
constraint_cache_size: int = 2 ** 18
481+
model_cache_size: int = 2**18
482+
constraint_cache_size: int = 2**18
483483

484484
@classmethod
485485
def _add_parser_args(cls, parser):

textattack/constraints/grammaticality/cola.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343

4444
self.max_diff = max_diff
4545
self.model_name = model_name
46-
self._reference_score_cache = lru.LRU(2 ** 10)
46+
self._reference_score_cache = lru.LRU(2**10)
4747
model = AutoModelForSequenceClassification.from_pretrained(model_name)
4848
tokenizer = AutoTokenizer.from_pretrained(model_name)
4949
self.model = HuggingFaceModelWrapper(model, tokenizer)

textattack/constraints/grammaticality/language_models/google_language_model/alzantot_goog_lm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self):
4949
self.sess, self.graph, self.PBTXT_PATH, self.CKPT_PATH
5050
)
5151

52-
self.lm_cache = lru.LRU(2 ** 18)
52+
self.lm_cache = lru.LRU(2**18)
5353

5454
def clear_cache(self):
5555
self.lm_cache.clear()

textattack/constraints/grammaticality/part_of_speech.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
self.language_nltk = language_nltk
5757
self.language_stanza = language_stanza
5858

59-
self._pos_tag_cache = lru.LRU(2 ** 14)
59+
self._pos_tag_cache = lru.LRU(2**14)
6060
if tagger_type == "flair":
6161
if tagset == "universal":
6262
self._flair_pos_tagger = SequenceTagger.load("upos-fast")
@@ -93,7 +93,8 @@ def _get_pos(self, before_ctx, word, after_ctx):
9393

9494
if self.tagger_type == "flair":
9595
context_key_sentence = Sentence(
96-
context_key, use_tokenizer=textattack.shared.utils.words_from_text
96+
context_key,
97+
use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
9798
)
9899
self._flair_pos_tagger.predict(context_key_sentence)
99100
word_list, pos_list = textattack.shared.utils.zip_flair_result(

0 commit comments

Comments
 (0)