Skip to content

Commit da87e7f

Browse files
authored
add weights_only=False to torch.load (#1984)
1 parent 89728dd commit da87e7f

141 files changed

Lines changed: 205 additions & 205 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

docs/source/for-dummies/model-export.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the
4141
.. code-block:: python3
4242
4343
>>> import torch
44-
>>> m = torch.load("tdnn/exp/pretrained.pt")
44+
>>> m = torch.load("tdnn/exp/pretrained.pt", weights_only=False)
4545
>>> list(m.keys())
4646
['model']
4747
>>> list(m["model"].keys())

egs/aidatatang_200zh/ASR/local/prepare_lang.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
2929
4. Generate L.pt, in k2 format. It can be loaded by
3030
31-
d = torch.load("L.pt")
31+
d = torch.load("L.pt", weights_only=False)
3232
lexicon = k2.Fsa.from_dict(d)
3333
3434
5. Generate L_disambig.pt, in k2 format.

egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def main():
224224
logging.info("Creating model")
225225
model = get_transducer_model(params)
226226

227-
checkpoint = torch.load(args.checkpoint, map_location="cpu")
227+
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
228228
model.load_state_dict(checkpoint["model"], strict=False)
229229
model.to(device)
230230
model.eval()

egs/aishell/ASR/conformer_ctc/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def main():
503503
else:
504504
H = None
505505
HLG = k2.Fsa.from_dict(
506-
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
506+
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
507507
)
508508
assert HLG.requires_grad is False
509509

egs/aishell/ASR/conformer_ctc/pretrained.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def main():
249249
use_feat_batchnorm=params.use_feat_batchnorm,
250250
)
251251

252-
checkpoint = torch.load(args.checkpoint, map_location="cpu")
252+
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
253253
model.load_state_dict(checkpoint["model"], strict=False)
254254
model.to(device)
255255
model.eval()
@@ -315,7 +315,7 @@ def main():
315315
hyps = [[token_sym_table[i] for i in ids] for ids in token_ids]
316316
elif params.method in ["1best", "attention-decoder"]:
317317
logging.info(f"Loading HLG from {params.HLG}")
318-
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
318+
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
319319
HLG = HLG.to(device)
320320
if not hasattr(HLG, "lm_scores"):
321321
# For whole-lattice-rescoring and attention-decoder

egs/aishell/ASR/conformer_mmi/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def main():
516516
else:
517517
H = None
518518
HLG = k2.Fsa.from_dict(
519-
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
519+
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
520520
)
521521
assert HLG.requires_grad is False
522522

egs/aishell/ASR/local/prepare_lang.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
2929
4. Generate L.pt, in k2 format. It can be loaded by
3030
31-
d = torch.load("L.pt")
31+
d = torch.load("L.pt", weights_only=False)
3232
lexicon = k2.Fsa.from_dict(d)
3333
3434
5. Generate L_disambig.pt, in k2 format.

egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def main():
227227
logging.info("About to create model")
228228
model = get_transducer_model(params)
229229

230-
checkpoint = torch.load(args.checkpoint, map_location="cpu")
230+
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
231231
model.load_state_dict(checkpoint["model"], strict=False)
232232
model.to(device)
233233
model.eval()

egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def main():
228228
logging.info("About to create model")
229229
model = get_transducer_model(params)
230230

231-
checkpoint = torch.load(args.checkpoint, map_location="cpu")
231+
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
232232
model.load_state_dict(checkpoint["model"], strict=False)
233233
model.to(device)
234234
model.eval()

egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ def main():
773773
lg_filename = params.lang_dir / "LG.pt"
774774
logging.info(f"Loading {lg_filename}")
775775
decoding_graph = k2.Fsa.from_dict(
776-
torch.load(lg_filename, map_location=device)
776+
torch.load(lg_filename, map_location=device, weights_only=False)
777777
)
778778
decoding_graph.scores *= params.ngram_lm_scale
779779
else:

0 commit comments

Comments
 (0)