-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathTouchNet_utils.py
More file actions
117 lines (105 loc) · 4.47 KB
/
TouchNet_utils.py
File metadata and controls
117 lines (105 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from scipy.io import wavfile
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from TouchNet_model import *
import os
from collections import OrderedDict
from torch._six import container_abcs, string_classes, int_classes
def strip_prefix_if_present(state_dict, prefix):
    """
    Remove a leading 'prefix' from every key of 'state_dict'.

    The mapping is only rewritten when *all* keys start with 'prefix';
    otherwise the original object is returned untouched.

    Args:
        state_dict: mapping whose keys are strings.
        prefix: string to strip from the front of each key.

    Returns:
        The same 'state_dict' object if at least one key lacks the prefix,
        else a new OrderedDict with the prefix removed from each key and the
        original iteration order preserved.
    """
    if not all(key.startswith(prefix) for key in state_dict):
        return state_dict
    # Slice instead of str.replace(): replace() would also delete later
    # occurrences of the prefix inside the key (e.g. "module.a.module.b").
    prefix_len = len(prefix)
    return OrderedDict(
        (key[prefix_len:], value) for key, value in state_dict.items()
    )
def mkdirs(path, remove=False):
    """
    Create directory 'path' (including missing parents).

    Args:
        path: directory to create.
        remove: when True and 'path' already is a directory, delete it
            (with all of its contents) and recreate it empty. When False an
            existing directory is left as it is.

    Raises:
        OSError: propagated from os.makedirs / shutil.rmtree, e.g. when
            'path' exists but is not a directory.
    """
    # Function-scope import: the module-level imports visible in this file do
    # not include shutil, so the remove=True path would otherwise fail with a
    # NameError.
    import shutil

    if os.path.isdir(path):
        if not remove:
            return
        shutil.rmtree(path)
    os.makedirs(path)
def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False):
    """
    Compute the STFT of 'audio' and return its magnitude (and optionally phase).

    Args:
        audio: signal passed straight to librosa.core.stft.
        stft_frame: forwarded as 'win_length'.
        stft_hop: forwarded as 'hop_length'.
        n_fft: forwarded as 'n_fft'.
        with_phase: when True, also return the phase angle.

    Returns:
        The magnitude with a new leading axis of size 1 when 'with_phase' is
        False; otherwise a tuple (magnitude, phase_angle), where phase_angle
        is np.angle() of the phase term from librosa.core.magphase, also with
        a new leading axis.
    """
    stft_result = librosa.core.stft(
        audio,
        hop_length=stft_hop,
        n_fft=n_fft,
        win_length=stft_frame,
        center=True,
    )
    magnitude, phase_term = librosa.core.magphase(stft_result)
    magnitude = magnitude[np.newaxis, ...]
    if not with_phase:
        return magnitude
    # magphase yields a complex phase term; convert it to an angle.
    phase_angle = np.angle(phase_term)[np.newaxis, ...]
    return magnitude, phase_angle
def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft):
    """
    Compute the STFT of 'audio' and return it as a two-channel real array.

    Args:
        audio: signal passed straight to librosa.core.stft.
        stft_frame: forwarded as 'win_length'.
        stft_hop: forwarded as 'hop_length'.
        n_fft: forwarded as 'n_fft'.

    Returns:
        Array whose leading axis has size 2: index 0 holds the real part of
        the STFT and index 1 the imaginary part.
    """
    stft_result = librosa.core.stft(
        audio,
        hop_length=stft_hop,
        n_fft=n_fft,
        win_length=stft_frame,
        center=True,
    )
    # Stacking on a new leading axis equals expand_dims + concatenate.
    return np.stack((np.real(stft_result), np.imag(stft_result)), axis=0)
def batchify(fn, chunk):
    """
    Wrap 'fn' so that it is evaluated on slices of at most 'chunk' rows.

    Args:
        fn: callable applied to each slice of the input along dim 0.
        chunk: maximum slice length, or None to disable chunking.

    Returns:
        'fn' itself when 'chunk' is None; otherwise a function that splits
        its argument along dim 0, calls 'fn' on each piece in order and
        concatenates the results along dim 0.
    """
    if chunk is None:
        return fn

    def chunked_fn(inputs):
        pieces = []
        for start in range(0, inputs.shape[0], chunk):
            pieces.append(fn(inputs[start:start + chunk]))
        return torch.cat(pieces, 0)

    return chunked_fn
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """
    Embed 'inputs' (and optionally 'viewdirs') and evaluate network 'fn'.

    Args:
        inputs: tensor; all leading dims are flattened before embedding.
        viewdirs: tensor or None. When given, viewdirs[:, None] is expanded
            to inputs.shape, so it must be broadcastable to that shape.
        fn: network applied to the embedded features.
        embed_fn: embedding applied to the flattened inputs.
        embeddirs_fn: embedding applied to the flattened view directions;
            only called when 'viewdirs' is not None.
        netchunk: rows per call of 'fn' (handed to batchify).

    Returns:
        The output of 'fn' reshaped to inputs.shape[:-1] plus the network's
        last output dimension.
    """
    flat_inputs = inputs.reshape(-1, inputs.shape[-1])
    features = embed_fn(flat_inputs)

    if viewdirs is not None:
        # Give every input row its own copy of the matching view direction.
        dirs = viewdirs[:, None].expand(inputs.shape)
        flat_dirs = dirs.reshape(-1, dirs.shape[-1])
        features = torch.cat([features, embeddirs_fn(flat_dirs)], -1)

    chunked_fn = batchify(fn, netchunk)
    flat_outputs = chunked_fn(features)
    out_shape = [*inputs.shape[:-1], flat_outputs.shape[-1]]
    return torch.reshape(flat_outputs, out_shape)
def create_nerf(args):
    """
    Instantiate NeRF's MLP model.

    Reads multires, i_embed, use_viewdirs, multires_views, netdepth and
    netwidth from 'args', builds the positional (and optionally view
    direction) embedders, and wraps a NeRF model in nn.DataParallel on
    'device'.

    NOTE(review): as written the function has no return statement, so it
    returns None and the model / parameter list are discarded -- confirm
    whether callers expect them to be returned.
    """
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
    else:
        embeddirs_fn, input_ch_views = None, 0

    network = NeRF(
        D=args.netdepth,
        W=args.netwidth,
        input_ch=input_ch,
        output_ch=2,
        skips=[4],
        input_ch_views=input_ch_views,
        use_viewdirs=args.use_viewdirs,
    )
    model = nn.DataParallel(network).to(device)
    grad_vars = list(model.parameters())
def object_collate(batch):
    r"""
    Puts each data field into a tensor with outer dimension batch size.

    Dispatches on the type of the first element of 'batch':

    * torch.Tensor  -> stacked along a new dim 0.
    * numpy ndarray -> converted to tensors and *concatenated* along dim 0,
      so samples whose first dimension differs can still be batched.
    * numpy scalar  -> 1-D tensor that keeps the numpy dtype.
    * float         -> float64 tensor.
    * int           -> tensor built by torch.tensor.
    * str / bytes   -> 'batch' returned unchanged.
    * mapping       -> dict with each key collated recursively.
    * sequence      -> list with each position collated recursively.

    Args:
        batch: non-empty sequence of samples of the same kind.

    Returns:
        The collated batch, as described above.

    Raises:
        TypeError: if the first element is of none of the supported types.
    """
    # Function-scope import so the block needs only the standard library for
    # its container checks instead of the private torch._six aliases.
    from collections.abc import Mapping, Sequence

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        return torch.stack(batch, 0)
    elif elem_type.__module__ == 'numpy':
        if elem_type.__name__ == 'ndarray':
            # Concatenate even if the first dimension differs between samples.
            return torch.cat([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            # Build one numpy array first so the tensor keeps the scalar dtype;
            # this replaces the lookup in 'numpy_type_map', which is not
            # defined among this module's visible names.
            return torch.from_numpy(np.asarray(batch))
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, Mapping):
        return {key: object_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, Sequence):
        return [object_collate(samples) for samples in zip(*batch)]

    # The message is defined here because 'error_msg_fmt' is not defined among
    # this module's visible names.
    raise TypeError(
        "object_collate: batch must contain tensors, numpy arrays, numbers, "
        "dicts or lists; found {}".format(elem_type)
    )