Skip to content

Commit 96b9e7d

Browse files
Harden legacy converter checkpoint loading (#2036)
* Harden legacy converter checkpoint loading
* Format OpenNMT-py checkpoint load
* Add legacy deserialization opt-in for OpenNMT-py
* Add legacy deserialization opt-in for Fairseq
* Retry CI

---------

Co-authored-by: Jordi Mas <jmas@softcatala.org>
1 parent d9b991e commit 96b9e7d

2 files changed

Lines changed: 34 additions & 4 deletions

File tree

python/ctranslate2/converters/fairseq.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(
113113
fixed_dictionary: Optional[str] = None,
114114
no_default_special_tokens: bool = False,
115115
user_dir: Optional[str] = None,
116+
unsafe_deserialization: bool = False,
116117
):
117118
"""Initializes the Fairseq converter.
118119
@@ -125,6 +126,8 @@ def __init__(
125126
no_default_special_tokens: Require all special tokens to be provided by the user
126127
(e.g. encoder end token, decoder start token).
127128
user_dir: Path to the user directory containing custom extensions.
129+
unsafe_deserialization: Allow unsafe pickle deserialization when loading
130+
trusted legacy checkpoints.
128131
"""
129132
self._model_path = model_path
130133
self._data_dir = data_dir
@@ -133,6 +136,7 @@ def __init__(
133136
self._target_lang = target_lang
134137
self._no_default_special_tokens = no_default_special_tokens
135138
self._user_dir = user_dir
139+
self._unsafe_deserialization = unsafe_deserialization
136140

137141
def _load(self):
138142
import fairseq
@@ -147,7 +151,9 @@ def _load(self):
147151

148152
with torch.no_grad():
149153
checkpoint = torch.load(
150-
self._model_path, map_location=torch.device("cpu"), weights_only=False
154+
self._model_path,
155+
map_location=torch.device("cpu"),
156+
weights_only=not self._unsafe_deserialization,
151157
)
152158
args = checkpoint["args"] or checkpoint["cfg"]["model"]
153159

@@ -329,6 +335,14 @@ def main():
329335
"including the decoder start token."
330336
),
331337
)
338+
parser.add_argument(
339+
"--unsafe_deserialization",
340+
action="store_true",
341+
help=(
342+
"Allow loading legacy checkpoints with unsafe pickle deserialization. "
343+
"Only enable this option for trusted checkpoints."
344+
),
345+
)
332346
Converter.declare_arguments(parser)
333347
args = parser.parse_args()
334348
converter = FairseqConverter(
@@ -339,6 +353,7 @@ def main():
339353
fixed_dictionary=args.fixed_dictionary,
340354
no_default_special_tokens=args.no_default_special_tokens,
341355
user_dir=args.user_dir,
356+
unsafe_deserialization=args.unsafe_deserialization,
342357
)
343358
converter.convert_from_args(args)
344359

python/ctranslate2/converters/opennmt_py.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,19 +163,24 @@ def get_vocabs(vocab):
163163
class OpenNMTPyConverter(Converter):
164164
"""Converts models generated by OpenNMT-py."""
165165

166-
def __init__(self, model_path: str):
166+
def __init__(self, model_path: str, unsafe_deserialization: bool = False):
167167
"""Initializes the OpenNMT-py converter.
168168
169169
Arguments:
170170
model_path: Path to the OpenNMT-py PyTorch model (.pt file).
171+
unsafe_deserialization: Allow unsafe pickle deserialization when loading
172+
trusted legacy checkpoints.
171173
"""
172174
self._model_path = model_path
175+
self._unsafe_deserialization = unsafe_deserialization
173176

174177
def _load(self):
175178
import torch
176179

177180
checkpoint = torch.load(
178-
self._model_path, map_location="cpu", weights_only=False
181+
self._model_path,
182+
map_location="cpu",
183+
weights_only=not self._unsafe_deserialization,
179184
)
180185

181186
src_vocabs, tgt_vocabs = get_vocabs(checkpoint["vocab"])
@@ -352,9 +357,19 @@ def main():
352357
formatter_class=argparse.ArgumentDefaultsHelpFormatter
353358
)
354359
parser.add_argument("--model_path", required=True, help="Model path.")
360+
parser.add_argument(
361+
"--unsafe_deserialization",
362+
action="store_true",
363+
help=(
364+
"Allow loading legacy checkpoints with unsafe pickle deserialization. "
365+
"Only enable this option for trusted checkpoints."
366+
),
367+
)
355368
Converter.declare_arguments(parser)
356369
args = parser.parse_args()
357-
OpenNMTPyConverter(args.model_path).convert_from_args(args)
370+
OpenNMTPyConverter(
371+
args.model_path, unsafe_deserialization=args.unsafe_deserialization
372+
).convert_from_args(args)
358373

359374

360375
if __name__ == "__main__":

0 commit comments

Comments
 (0)