Skip to content

Commit c30728b

Browse files
authored
Merge pull request #505 from QData/s3-model-fix
[FixBug] Fix bug with loading pretrained lstm and cnn models
2 parents 9d7b3b9 + d92203c commit c30728b

4 files changed

Lines changed: 22 additions & 14 deletions

File tree

textattack/attacker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ def _attack_parallel(self):
397397
def attack_dataset(self):
398398
"""Attack the dataset.
399399
400-
401400
Returns:
402401
:obj:`list[AttackResult]` - List of :class:`~textattack.attack_results.AttackResult` obtained after attacking the given dataset..
403402
"""

textattack/model_args.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,24 +96,27 @@
9696

9797
#
9898
# Models hosted by textattack.
99+
# `models` vs `models_v2`: `models_v2` is simply a new dir in S3 that contains models' `config.json`.
100+
# Fixes issue https://github.com/QData/TextAttack/issues/485
101+
# Model parameters has not changed.
99102
#
100103
TEXTATTACK_MODELS = {
101104
#
102105
# LSTMs
103106
#
104-
"lstm-ag-news": "models/classification/lstm/ag-news",
105-
"lstm-imdb": "models/classification/lstm/imdb",
106-
"lstm-mr": "models/classification/lstm/mr",
107-
"lstm-sst2": "models/classification/lstm/sst2",
108-
"lstm-yelp": "models/classification/lstm/yelp",
107+
"lstm-ag-news": "models_v2/classification/lstm/ag-news",
108+
"lstm-imdb": "models_v2/classification/lstm/imdb",
109+
"lstm-mr": "models_v2/classification/lstm/mr",
110+
"lstm-sst2": "models_v2/classification/lstm/sst2",
111+
"lstm-yelp": "models_v2/classification/lstm/yelp",
109112
#
110113
# CNNs
111114
#
112-
"cnn-ag-news": "models/classification/cnn/ag-news",
113-
"cnn-imdb": "models/classification/cnn/imdb",
114-
"cnn-mr": "models/classification/cnn/rotten-tomatoes",
115-
"cnn-sst2": "models/classification/cnn/sst",
116-
"cnn-yelp": "models/classification/cnn/yelp",
115+
"cnn-ag-news": "models_v2/classification/cnn/ag-news",
116+
"cnn-imdb": "models_v2/classification/cnn/imdb",
117+
"cnn-mr": "models_v2/classification/cnn/rotten-tomatoes",
118+
"cnn-sst2": "models_v2/classification/cnn/sst",
119+
"cnn-yelp": "models_v2/classification/cnn/yelp",
117120
#
118121
# T5 for translation
119122
#

textattack/models/helpers/lstm_for_classification.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def from_pretrained(cls, name_or_path):
101101
"""Load trained LSTM model by name or from path.
102102
103103
Args:
104-
name_or_path (str): Name of the model (e.g. "lstm-imdb") or model saved via `save_pretrained`.
104+
name_or_path (:obj:`str`): Name of the model (e.g. "lstm-imdb") or model saved via :meth:`save_pretrained`.
105+
Returns:
106+
:class:`~textattack.models.helpers.LSTMForClassification` model
105107
"""
106108
if name_or_path in TEXTATTACK_MODELS:
107109
# path = utils.download_if_needed(TEXTATTACK_MODELS[name_or_path])
@@ -110,6 +112,7 @@ def from_pretrained(cls, name_or_path):
110112
path = name_or_path
111113

112114
config_path = os.path.join(path, "config.json")
115+
113116
if os.path.exists(config_path):
114117
with open(config_path, "r") as f:
115118
config = json.load(f)

textattack/models/helpers/word_cnn_for_classification.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,17 @@ def from_pretrained(cls, name_or_path):
8787
"""Load trained LSTM model by name or from path.
8888
8989
Args:
90-
name_or_path (str): Name of the model (e.g. "cnn-imdb") or model saved via `save_pretrained`.
90+
name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or model saved via :meth:`save_pretrained`.
91+
Returns:
92+
:class:`~textattack.models.helpers.WordCNNForClassification` model
9193
"""
92-
if name_or_path != "cnn" and name_or_path in TEXTATTACK_MODELS:
94+
if name_or_path in TEXTATTACK_MODELS:
9395
path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
9496
else:
9597
path = name_or_path
9698

9799
config_path = os.path.join(path, "config.json")
100+
98101
if os.path.exists(config_path):
99102
with open(config_path, "r") as f:
100103
config = json.load(f)

0 commit comments

Comments
 (0)