-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
70 lines (57 loc) · 2.24 KB
/
inference.py
File metadata and controls
70 lines (57 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import base64
import json
import math
import os
import pickle
import librosa
import numpy as np
import constants
# https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/deploying_tensorflow_serving.html#providing-python-scripts-for-pre-post-processing
# Scratch-file location on the container's local disk.
tmpfile = '/tmp/abc'

# Calibration callable applied to the raw model predictions. It is unpickled
# once at import time from the model artifact directory.
# NOTE(review): unpickling executes arbitrary code, so this is only safe while
# the .pkl file is a trusted part of the deployed model artifact.
with open('/opt/ml/model/export_calib.pkl', 'rb') as _calib_file:
    cl = pickle.load(_calib_file)
def _clean_audio(x):
    """
    Remove silence at the beginning and end of the signal.

    A sample counts as silence when its magnitude is at most 1e-5 times the
    largest magnitude found in the signal. Silence inside the signal (between
    the first and last non-silent samples) is kept.

    Arg:
        - x (array): audio
    Returns:
        Cleaned audio (array). An empty array is returned when the input is
        empty or consists only of zeros.
    """
    # An empty signal has no peak; calling .max() on it would raise ValueError.
    if np.size(x) == 0:
        return np.array([])

    magnitude = np.abs(x)
    largest_val = magnitude.max()
    if largest_val == 0:
        # All-zero input: everything is silence.
        return np.array([])

    # The threshold is relative to the peak so the trim is independent of gain.
    silence_thresh = largest_val * 1e-5
    non_silence = np.where(magnitude > silence_thresh)[0]
    first = non_silence[0]
    last = non_silence[-1]
    return x[first:last + 1]
def input_handler(data, context, max_seconds=15*60):
    """
    Convert an incoming audio request into a TensorFlow Serving JSON payload.

    The request body is decoded to audio at `constants.model_sr`, truncated to
    `max_seconds`, stripped of leading/trailing silence, zero-padded to a
    multiple of `constants.frame_length` and split into frames.

    Args:
        - data: request body stream; its `read()` result is the audio file
          bytes. Outside of a batch transform job (detected through the
          `TRANSFORM_JOB_ARN` environment variable) the bytes are expected to
          be base64-encoded.
        - context: request context supplied by the serving container (unused).
        - max_seconds (int): maximum audio duration kept, in seconds.
    Returns:
        UTF-8 encoded JSON bytes of the form `{"instances": [[...], ...]}`,
        one inner list per frame of `constants.frame_length` samples.
    """
    # Local import keeps this block self-contained; `tempfile` is stdlib.
    import tempfile

    max_samples = constants.model_sr * max_seconds
    is_transform_job = 'TRANSFORM_JOB_ARN' in os.environ

    payload = data.read()
    if not is_transform_job:
        payload = base64.b64decode(payload)

    # librosa needs a file path. Use a unique file per request so concurrent
    # requests cannot overwrite each other's audio, and always remove it.
    fd, audio_path = tempfile.mkstemp(prefix='inference_audio_')
    try:
        with os.fdopen(fd, 'wb') as f:
            f.write(payload)
        audio, _ = librosa.load(audio_path, sr=constants.model_sr, mono=True)
    finally:
        os.remove(audio_path)

    if len(audio) > max_samples:
        print(f"Number of audio samples ({len(audio)}) is longer than the maximum ({max_samples}). Trimming audio.")
        audio = audio[:max_samples]
    audio = _clean_audio(audio)

    # Pad with trailing zeros so the signal splits evenly into frames.
    # https://stackoverflow.com/questions/62459704/np-reshape-with-padding-if-there-are-not-enough-elements
    num_samples = len(audio)
    num_frames = math.ceil(num_samples / constants.frame_length)
    num_pad = num_frames * constants.frame_length - num_samples
    audio = np.pad(audio, (0, num_pad))
    audio = audio.reshape(num_frames, constants.frame_length)
    print("AUDIO prepped:", audio.shape)
    return json.dumps({'instances': audio.tolist()}).encode('utf-8')
def output_handler(data, context):
    """
    Calibrate the model predictions and build the response for the client.

    Args:
        - data: TensorFlow Serving response; `data.content` holds a JSON
          document with a `predictions` entry.
        - context: request context; its `accept_header` is returned as the
          response content type.
    Returns:
        Tuple of (UTF-8 encoded JSON bytes `{"predictions": [...]}` holding
        the calibrated predictions, response content type).
    Raises:
        ValueError: if TensorFlow Serving did not answer with HTTP 200; the
        message carries the upstream response body.
    """
    # Surface upstream failures explicitly instead of failing later with an
    # opaque KeyError on the missing 'predictions' entry.
    if data.status_code != 200:
        raise ValueError(data.content.decode('utf-8'))

    content_json = json.loads(data.content)
    pred = np.array(content_json['predictions'])
    # `cl` is the calibration callable loaded at module import time.
    pred_calib = cl(pred)
    print("PRED calibrated:", pred_calib.shape)
    pred_calib_json = json.dumps({'predictions': pred_calib.tolist()})
    return pred_calib_json.encode('utf-8'), context.accept_header