# AudioNet_utils.py
from scipy.io import wavfile
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from AudioNet_model import *
import os
import shutil
import torch
import torch.nn as nn
from collections import OrderedDict

def strip_prefix_if_present(state_dict, prefix):
    """Remove `prefix` from every key of `state_dict`, but only if all keys carry it."""
    keys = sorted(state_dict.keys())
    if not all(key.startswith(prefix) for key in keys):
        return state_dict
    stripped_state_dict = OrderedDict()
    for key, value in state_dict.items():
        stripped_state_dict[key.replace(prefix, "")] = value
    return stripped_state_dict
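
# Hedged usage sketch (not part of the original file): checkpoints saved from an
# nn.DataParallel model typically prefix every key with "module.", so stripping that
# prefix lets the weights load into a plain, non-parallel model. The checkpoint path
# and the "model_state_dict" key below are hypothetical.
#
#   ckpt = torch.load("checkpoints/audionet_latest.pth", map_location="cpu")
#   state_dict = strip_prefix_if_present(ckpt["model_state_dict"], prefix="module.")
#   model.load_state_dict(state_dict)
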
def mkdirs(path, remove=False):
    """Create `path` if it does not exist; with `remove=True`, wipe and recreate it."""
    if os.path.isdir(path):
        if remove:
            shutil.rmtree(path)
        else:
            return
    os.makedirs(path)
def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False):
    """Compute an STFT and return its magnitude (and optionally its phase angle)."""
    spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
    spectro_mag, spectro_phase = librosa.core.magphase(spectro)
    spectro_mag = np.expand_dims(spectro_mag, axis=0)
    if with_phase:
        spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0)
        return spectro_mag, spectro_phase
    else:
        return spectro_mag
def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft):
    """Compute an STFT and return its real and imaginary parts as a 2-channel array."""
    spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
    real = np.expand_dims(np.real(spectro), axis=0)
    imag = np.expand_dims(np.imag(spectro), axis=0)
    spectro_two_channel = np.concatenate((real, imag), axis=0)
    return spectro_two_channel
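
# Hedged usage sketch (not part of the original file): example STFT settings for
# 16 kHz audio; the exact frame/hop sizes used by the project may differ, and the
# input file name is hypothetical.
#
#   sr, audio = wavfile.read("example.wav")
#   audio = audio.astype(np.float32) / 32768.0   # int16 samples -> floats in [-1, 1]
#   spec = generate_spectrogram_complex(audio, stft_frame=512, stft_hop=160, n_fft=512)
#   # spec.shape == (2, n_fft // 2 + 1, n_frames): real and imaginary channels
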
def batchify(fn, chunk):
    """Constructs a version of 'fn' that applies to smaller batches."""
    if chunk is None:
        return fn

    def ret(inputs):
        return torch.cat([fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    return ret
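
# Hedged usage sketch (not part of the original file): batchify wraps a network so a
# large batch of inputs is processed in fixed-size chunks, bounding peak memory while
# returning the same concatenated result.
#
#   mlp = torch.nn.Linear(64, 2)
#   inputs = torch.randn(100000, 64)
#   outputs = batchify(mlp, chunk=1024 * 64)(inputs)   # equivalent to mlp(inputs)
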
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024 * 64):
    """Prepares inputs and applies network 'fn'."""
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)
    if viewdirs is not None:
        input_dirs = viewdirs[:, None].expand(inputs.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        embedded = torch.cat([embedded, embedded_dirs], -1)
    outputs_flat = batchify(fn, netchunk)(embedded)
    outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs
def create_nerf(args):
    """Instantiate NeRF's MLP model."""
    # NeRF, get_embedder, and `device` are expected to be provided by the
    # wildcard import from AudioNet_model above.
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
    output_ch = 2
    skips = [4]
    model = NeRF(D=args.netdepth, W=args.netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
    model = nn.DataParallel(model).to(device)
    grad_vars = list(model.parameters())