forked from nii-yamagishilab/AntiDeepfake
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
108 lines (91 loc) · 3.29 KB
/
utils.py
File metadata and controls
108 lines (91 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""This module provides a collection of helper functions used in training pipelines,
including:
* Loading model weights - load_weights()
* Setting random seeds for reproducibility - set_random_seed()
* Saving and loading pickle files - pickle_dump(), pickle_load()
"""
import random
import os
import itertools
import collections
import pickle
import torch
import numpy as np
# Module metadata: authorship, contact addresses and copyright notice.
__copyright__ = "Copyright 2025, National Institute of Informatics"
__author__ = "Wanying Ge, Xin Wang"
__email__ = "gewanying@nii.ac.jp, wangxin@nii.ac.jp"
def load_weights(trg_state, path, func_name_change=lambda x: x):
    """Load trained weights into the state_dict() of a torch module on CPU.

    The checkpoint at ``path`` is loaded with every storage mapped to CPU.
    Loading is first tried with ``weights_only=True``. If that raises
    ``pickle.UnpicklingError``, it is retried with ``weights_only=False``,
    which fairseq pt models (w2v_small, w2v_large, hubert_xl) need.

    WARNING: the ``weights_only=False`` fallback unpickles arbitrary
    objects and can execute code. Only load checkpoints from trusted
    sources.

    If the loaded object has a ``'model'`` key (fairseq-style checkpoint),
    the entry under that key is used as the source state. Each source
    entry is then matched against ``trg_state``:

    * the name is used as-is if it is in ``trg_state``; otherwise
      ``func_name_change(name)`` is tried;
    * names that still do not match, and tensors whose sizes differ,
      are reported with ``print`` and skipped;
    * matching tensors are copied into the target in place with ``copy_``.

    Args:
        trg_state: dict-like mapping names to tensors (e.g. the result of
            ``model.state_dict()``). Modified in place.
        path: checkpoint location accepted by ``torch.load``
            (str or path-like).
        func_name_change: callable mapping a checkpoint key to a target key.
            It is used only for keys not already present in ``trg_state``.
            Defaults to the identity function.

    Returns:
        None.

    Raises:
        RuntimeError: if the first ``torch.load`` attempt fails with any
            error other than ``pickle.UnpicklingError`` (e.g. a missing
            file). The original exception is chained as ``__cause__``.
            Errors raised by the ``weights_only=False`` retry propagate
            unchanged.
    """
    # Load to CPU: the map_location callable keeps every storage on the
    # host, whatever device it was saved from.
    try:
        loaded_state = torch.load(
            path,
            map_location=lambda storage, loc: storage,
            weights_only=True,
        )
    except pickle.UnpicklingError:
        # weights_only must be False to load fairseq pt models:
        # w2v_small, w2v_large, hubert_xl.
        loaded_state = torch.load(
            path,
            map_location=lambda storage, loc: storage,
            weights_only=False,
        )
    except Exception as err:
        # Raise explicitly rather than `assert`, so the failure still
        # surfaces under `python -O`. Use '{}' (not '{:s}') so path-like
        # objects can be formatted.
        raise RuntimeError("Fail to load {}".format(path)) from err

    # If it is a fairseq-style checkpoint, the weights live under 'model'.
    if 'model' in loaded_state:
        loaded_state = loaded_state['model']

    # Customized loading patterns (provided by SASV baseline code).
    # no_grad() keeps the in-place copy_ legal even if trg_state holds
    # leaf tensors that require grad (e.g. state_dict(keep_vars=True)).
    with torch.no_grad():
        for name, param in loaded_state.items():
            origname = name
            if name not in trg_state:
                name = func_name_change(name)
                if name not in trg_state:
                    print("{} is not in the model.".format(origname))
                    continue
            if trg_state[name].size() != param.size():
                # '{}' is required: torch.Size is a tuple subclass and
                # raises TypeError for the '{:s}' format spec.
                print("Wrong para. length: {}, model: {}, loaded: {}".format(
                    origname, trg_state[name].size(), param.size()))
                continue
            trg_state[name].copy_(param)
    return
def set_random_seed(seed):
    """Make a training run reproducible for the given ``seed``.

    What it does:

    * seeds torch (the CPU generator, the current CUDA device and all CUDA
      devices), NumPy's global RNG and Python's ``random`` module;
    * disables the cuDNN benchmark autotuner;
    * enables ``torch.use_deterministic_algorithms(True)``;
    * exports two environment variables:
      ``CUBLAS_WORKSPACE_CONFIG=":4096:8"``, which PyTorch requires for
      deterministic cuBLAS behaviour, and ``SB_GLOBAL_SEED``, the seed as a
      string (presumably read by SpeechBrain; confirm against callers).

    Args:
        seed: integer seed applied to every generator.
    """
    # torch generators first, then NumPy, then the stdlib RNG.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # The autotuner may choose different kernels from run to run, so turn
    # it off and insist on deterministic implementations.
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

    os.environ.update({
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        "SB_GLOBAL_SEED": str(seed),
    })
############
# data IO
############
def pickle_dump(data, file_path):
    """ pickle_dump(data, file_path)
    Dump data into a pickle file, creating missing parent directories.

    inputs:
      data: python object, data to be dumped
      file_path: str or path-like, path to save the pickle file

    All missing parent directories are created with os.makedirs, not just
    the last level. Directory creation is best effort: an OSError there is
    ignored. If the directory is really unusable, the open() call below
    raises.
    """
    dir_name = os.path.dirname(file_path)
    # dirname is '' for a bare file name; there is nothing to create then.
    if dir_name:
        try:
            os.makedirs(dir_name, exist_ok=True)
        except OSError:
            pass
    with open(file_path, 'wb') as file_ptr:
        pickle.dump(data, file_ptr)
    return
def pickle_load(file_path):
    """ data = pickle_load(file_path)
    Read back a python object stored in a pickle file

    inputs:
      file_path: str, path of the pickle file
    output:
      data: python object stored in the file

    Security note: pickle.load can execute arbitrary code while
    unpickling. Only use this on files from trusted sources.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)