-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcalibrate_export_model.py
More file actions
57 lines (44 loc) · 1.94 KB
/
calibrate_export_model.py
File metadata and controls
57 lines (44 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import pickle
import sys
import keras
import numpy as np
from birdmodeling import calib_layer, spectrogram_model
import constants
from birddata import data
from preprocessing import frame_audio, load_audio
def save_competition_classes(species_file, metadata):
    """Write the competition class common names to ``species_file``, one per line.

    Args:
        species_file: Path of the text file to create/overwrite.
        metadata: Object exposing ``competition_classes_common_names``, an
            iterable of strings (the order is preserved in the output).

    The names are joined with ``'\\n'`` (no trailing newline) and written as
    UTF-8 so that non-ASCII characters in common names are encoded the same
    way regardless of the platform's default locale.
    """
    species_list = metadata.competition_classes_common_names
    with open(species_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(species_list))
def _clip_score(spec_model, path):
    """Return the per-class maximum of ``spec_model``'s scores for one audio file."""
    audio = load_audio(inputs=path, fixed_length=None)
    frames = frame_audio(audio)
    frame_scores = spec_model(frames)
    # Reduce over axis 0 (presumably the frame axis -- TODO confirm against
    # spectrogram_model's output shape) so each clip yields one score vector.
    clip_score = keras.ops.amax(frame_scores, axis=0)
    # convert_to_numpy works on every Keras 3 backend; Tensor.numpy() is TF-only.
    return keras.ops.convert_to_numpy(clip_score)


def do_calibrate(spec_model, valid_df, num_classes):
    """Fit a score calibrator on ``spec_model``'s outputs over ``valid_df``.

    Each row's audio (``full_path`` column) is loaded, framed and scored by
    ``spec_model``; frame scores are max-reduced to one vector per clip. The
    stacked clip scores and the ``primary_index`` column are passed to
    ``calib_layer.Calibrate(num_classes).adapt``.

    Args:
        spec_model: Callable Keras model applied to the framed audio.
        valid_df: DataFrame with at least ``full_path`` and ``primary_index``
            columns.
        num_classes: Number of classes, forwarded to ``Calibrate``.

    Returns:
        The adapted ``calib_layer.Calibrate`` instance.

    Raises:
        ValueError: If ``valid_df`` has no rows (nothing to calibrate on).
    """
    if valid_df.empty:
        raise ValueError("do_calibrate requires at least one validation row")

    scores = np.stack([_clip_score(spec_model, path) for path in valid_df['full_path']])
    y = valid_df['primary_index']

    # The Keras scoring model and this calibrator are intentionally kept as
    # separate objects rather than fused into a single model.
    cl = calib_layer.Calibrate(num_classes)
    cl.adapt(scores, y)
    return cl
def main(argv=None):
    """Calibrate a trained spectrogram model and write the export artifacts.

    Usage: ``python calibrate_export_model.py WEIGHTS_PATH``

    Loads the validation fold from the metadata CSV, restores the model
    weights from ``WEIGHTS_PATH``, fits the calibrator and writes:
      * ``./export_model``            -- ``spec_model.export`` output
      * ``export_calib.pkl``          -- pickled calibrator (protocol 5)
      * ``competition_classes.txt``   -- class common names, one per line

    Args:
        argv: Command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``.
    """
    argv = sys.argv[1:] if argv is None else argv
    # Validate before any expensive setup so a missing argument fails fast.
    if not argv:
        sys.exit(f"usage: {os.path.basename(sys.argv[0])} WEIGHTS_PATH")
    weights_path = argv[0]

    metadata = data.Data(os.path.join(constants.base_dir, constants.metadata_csv))
    valid_df = metadata.get_folds()['valid']

    spec_model = spectrogram_model.setup_model(metadata.num_classes, weights=None)
    # One forward pass on a dummy batch builds the model's variables so that
    # load_weights has something to restore into.
    spec_model(np.random.random((1, constants.frame_length)))
    spec_model.load_weights(weights_path)

    cl = do_calibrate(spec_model, valid_df, metadata.num_classes)

    spec_model.export('./export_model', verbose=False)
    # Use a context manager so the pickle file is flushed and closed promptly.
    with open('export_calib.pkl', 'wb') as f:
        pickle.dump(cl, f, protocol=5)
    save_competition_classes('competition_classes.txt', metadata)


if __name__ == "__main__":
    main()