-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathNN_weights_saver.py
More file actions
26 lines (21 loc) · 1 KB
/
NN_weights_saver.py
File metadata and controls
26 lines (21 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import torch
import numpy as np
# Directory the exported text files are written to (relative to the CWD).
WEIGHTS_DIR = 'network_weights'

# (file suffix, state-dict key) pairs exported for each network, in the
# order the files are written. '0.*' and '2.*' are the state-dict entries
# for the two parameterized layers of the saved modules.
LAYER_FILES = (
    ('w0', '0.weight'),
    ('b0', '0.bias'),
    ('w2', '2.weight'),
    ('b2', '2.bias'),
)


def save_layer_weights(state_dict, prefix, out_dir=WEIGHTS_DIR,
                       layers=LAYER_FILES, fmt='%f'):
    """Write selected tensors of a state dict to plain-text files.

    Each tensor ``state_dict[key]`` is written with ``np.savetxt`` to
    ``<out_dir>/<prefix>_<suffix>.txt``. A 2-D tensor becomes one text row
    per tensor row, and a 1-D tensor becomes one value per line.

    Args:
        state_dict: Mapping from parameter name to ``torch.Tensor``, as
            returned by ``torch.load`` on a saved state dict.
        prefix: File-name prefix, e.g. ``'ang'`` -> ``ang_w0.txt``.
        out_dir: Output directory. It is created if missing.
        layers: Iterable of ``(suffix, key)`` pairs selecting which tensors
            to export and what to name them.
        fmt: ``np.savetxt`` format string. The default ``'%f'`` keeps only
            6 decimal places, so magnitudes below 5e-7 are written as
            zero. Pass e.g. ``'%.9e'`` for full float32 precision.

    Returns:
        List of the file paths written, in ``layers`` order.

    Raises:
        KeyError: if a requested key is missing from ``state_dict``.
    """
    # np.savetxt does not create parent directories.
    os.makedirs(out_dir, exist_ok=True)
    written = []
    for suffix, key in layers:
        # .numpy() refuses tensors that require grad or live on a GPU, so
        # detach and move to CPU first.
        values = state_dict[key].detach().cpu().numpy()
        path = os.path.join(out_dir, f'{prefix}_{suffix}.txt')
        np.savetxt(path, values, fmt=fmt)
        written.append(path)
    return written


def main():
    """Export the angle and controller network weights to text files."""
    # NOTE: torch.load unpickles its input; only load trusted files.
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    angle_net = torch.load('./good_angle_network', map_location='cpu')
    controller_net = torch.load('controller_network', map_location='cpu')
    save_layer_weights(angle_net, 'ang')
    save_layer_weights(controller_net, 'controller')


if __name__ == '__main__':
    main()