# mimi.py — from a fork of pytorch/executorch (GitHub file view: 453 lines, 376 loc, 14.3 KB); page UI text and line-number gutter removed.
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# import argparse
import io
import json
import os
import random
from multiprocessing.connection import Client
import numpy as np
import requests
import sphn
import torch
import torch.nn as nn
import torchaudio
from executorch.backends.qualcomm.quantizer.custom_annotation import (
annotate_mimi_decoder,
)
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import (
build_executorch_binary,
make_output_dir,
make_quantizer,
parse_skip_delegation_node,
setup_common_args_and_variables,
SimpleADB,
)
from huggingface_hub import hf_hub_download
from moshi.models import loaders
from torch.ao.quantization.observer import MinMaxObserver
def seed_all(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (plus every
    CUDA device when CUDA is available). Also switches cuDNN to deterministic,
    non-benchmarking mode so repeated runs produce the same numbers.

    Args:
        seed: Integer seed applied to all generators.
    """
    # cuDNN autotuning may choose different kernels run-to-run; pin it down.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Current device first, then all devices for multi-GPU setups.
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)
def read_mp3_from_url(url, timeout=60):
    """Download an MP3 file over HTTP and decode it with torchaudio.

    Args:
        url: HTTP(S) URL of the MP3 file.
        timeout: Seconds ``requests`` waits for the server before giving up.
            Pass ``None`` to wait indefinitely.

    Returns:
        ``(waveform, sample_rate)``: ``waveform`` is the tensor returned by
        ``torchaudio.load``, converted to a NumPy array, and ``sample_rate``
        is the rate torchaudio reports for the file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Without a timeout, requests can block forever on a stalled connection.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # torchaudio.load expects a path or file-like object, so wrap the bytes.
    audio_stream = io.BytesIO(response.content)
    waveform, sample_rate = torchaudio.load(audio_stream, format="mp3")
    return waveform.numpy(), sample_rate
def compute_scores(
    cpu_decode_res: torch.Tensor, htp_decode_res: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Print, and return, how closely ``htp_decode_res`` matches ``cpu_decode_res``.

    ``cpu_decode_res`` is treated as the reference signal. Two metrics are
    computed:

    * Atol: the maximum absolute element-wise difference.
    * SQNR in dB: ``10 * log10(mean(ref**2) / mean((ref - other)**2))``.
      Identical inputs give ``inf``.

    Args:
        cpu_decode_res: Reference tensor.
        htp_decode_res: Tensor to compare. Must have the same shape.

    Returns:
        ``(atol, sqnr)`` as 0-d tensors.

    Raises:
        ValueError: If the two tensors have different shapes.
    """
    # An `assert` would be stripped under `python -O`; validate explicitly.
    if cpu_decode_res.shape != htp_decode_res.shape:
        raise ValueError(
            "Tensor shapes do not match: "
            f"{tuple(cpu_decode_res.shape)} vs {tuple(htp_decode_res.shape)}"
        )
    atol = torch.max(torch.abs(cpu_decode_res - htp_decode_res))
    print("Atol: ", atol)

    # Promote to float so integer inputs (e.g. codes) can be squared and averaged.
    reference = cpu_decode_res.float()
    candidate = htp_decode_res.float()
    error = reference - candidate
    original_power = torch.mean(torch.pow(reference, 2))
    error_power = torch.mean(torch.pow(error, 2))
    sqnr = 10 * torch.log10(original_power / error_power)
    print("SQNR: ", sqnr)
    return atol, sqnr
def test_decoder_with_emb_input(mimi, args):
    """Compare CPU and on-device (QNN) outputs of Mimi's decoder fed a random embedding.

    The function:

    1. Wraps ``mimi``'s ``upsample``, ``decoder_transformer`` and ``decoder``
       in a module.
    2. Runs that module on CPU with a random ``(1, 1, 512)`` input.
    3. Compiles it to ``<args.artifact>/mimi_decoder_emb_qnn.pte`` with 16a8w
       quantization (per-channel conv/linear, MinMaxObserver activations, plus
       ``annotate_mimi_decoder``).
    4. Runs it on the device through ``SimpleADB``.
    5. Prints Atol/SQNR of the device output against the CPU output.

    Args:
        mimi: The loaded Mimi model.
        args: Parsed CLI arguments. Reads ``model``, ``artifact``,
            ``shared_buffer``, ``build_folder``, ``device`` and ``host``.
    """

    class MimiDecode(nn.Module):
        """Run Mimi's upsample -> decoder transformer -> decoder on an embedding."""

        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            # Swap the last two axes, e.g. (1, 1, 512) -> (1, 512, 1), before upsampling.
            x = x.transpose(1, 2)
            x = self.mimi_model.upsample(x)
            (emb,) = self.mimi_model.decoder_transformer(x)
            # The transformer output goes to the decoder without re-transposing.
            with self.mimi_model._context_for_encoder_decoder:
                out = self.mimi_model.decoder(emb)
            return out

    emb_input = torch.rand(1, 1, 512, device="cpu")
    decoder_module = MimiDecode(mimi).eval()
    cpu_res = decoder_module(emb_input)

    pte_filename = "mimi_decoder_emb_qnn"
    quantizer = make_quantizer(
        quant_dtype=QuantDtype.use_16a8w,
        per_channel_conv=True,
        per_channel_linear=True,
        act_observer=MinMaxObserver,
    )
    quantizer.add_custom_quant_annotations((annotate_mimi_decoder,))
    emb_inputs = [(emb_input,)]
    build_executorch_binary(
        decoder_module,
        emb_inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        emb_inputs,
        custom_quantizer=quantizer,
        quant_dtype=QuantDtype.use_16a8w,
        shared_buffer=args.shared_buffer,
    )
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    # One "input_<i>_0.raw" line per entry of emb_inputs.
    input_list = "".join(f"input_{i}_0.raw\n" for i in range(len(emb_inputs)))
    adb.push(inputs=emb_inputs, input_list=input_list)
    adb.execute()

    # Collect the raw float32 outputs written by the device run.
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)
    # Reshape to the CPU result's shape instead of a hard-coded (1, 1, 1920).
    emb_predictions = [
        torch.from_numpy(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"),
                dtype=np.float32,
            )
        ).view(cpu_res.shape)
        for i in range(len(emb_inputs))
    ]
    print("Emb input test results")
    compute_scores(cpu_res, emb_predictions[0])
def mimi_encode(
    mimi,
    encode_inputs,
    encoder_input_list,
    pcm_chunk_size,
    skip_node_id_set,
    skip_node_op_set,
    args=None,
) -> list[torch.Tensor]:
    """Compile Mimi's encoder for QNN (8a8w), run it on the device and collect codes.

    Args:
        mimi: The loaded Mimi model. Its ``encode`` method is exported.
        encode_inputs: List of ``(pcm_chunk,)`` tuples. The first tuple is used
            as the example input for export.
        encoder_input_list: Newline-separated ``input_<i>_0.raw`` entries
            matching ``encode_inputs``.
        pcm_chunk_size: Samples per code frame. Not used for reshaping: the
            frame count is inferred from each output file.
        skip_node_id_set, skip_node_op_set: Nodes/ops to keep off the
            delegate, passed to ``build_executorch_binary``.
        args: Parsed CLI arguments. If omitted, falls back to the module-level
            ``args`` that exists when this file is run as a script.

    Returns:
        One int64 tensor of shape ``(1, 8, frames)`` per entry of
        ``encode_inputs``.

    Raises:
        ValueError: If no ``args`` is passed and no module-level ``args`` exists.
    """
    if args is None:
        # Backward compatibility: this function used to read the script-level
        # `args` global implicitly.
        args = globals().get("args")
        if args is None:
            raise ValueError("mimi_encode requires the parsed CLI args (args=...)")

    class MimiEncode(nn.Module):
        """Expose ``mimi.encode`` as ``forward`` so it can be exported."""

        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.encode(x)

    mimi_encode_model = MimiEncode(mimi)
    pte_filename = "mimi_encoder_qnn"
    build_executorch_binary(
        mimi_encode_model.eval(),
        encode_inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        encode_inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=encode_inputs, input_list=encoder_input_list)
    adb.execute()

    # Collect the raw int64 code outputs written by the device run.
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)
    # The frame count should match args.chunks_per_batch. It is inferred per
    # file (-1) so a shorter final batch does not break the reshape.
    return [
        torch.from_numpy(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"),
                dtype=np.int64,
            )
        ).view(1, 8, -1)
        for i in range(len(encode_inputs))
    ]
def mimi_decode(
    args, mimi, encode_res_list, pcm_chunk_size, skip_node_id_set, skip_node_op_set
) -> torch.Tensor:
    """Decode Mimi codes frame by frame through a ``torch.export``-ed decoder on CPU.

    Builds a fresh ``MimiModel`` subclass whose ``forward`` is ``decode``,
    using moshi's ``_seanet_kwargs``, ``_quantizer_kwargs`` and
    ``_transformer_kwargs``. It then loads the Mimi weights, exports the model
    with ``torch.export`` inside ``streaming(1)``, and calls the exported
    module once per code frame.

    Args:
        args: Parsed CLI arguments. ``args.mimi_weight`` is read.
        mimi: The already-loaded Mimi model. Its weights are copied when
            ``args.mimi_weight`` is not a safetensors file.
        encode_res_list: Codes tensor. Its last dimension is sliced one frame
            at a time.
        pcm_chunk_size, skip_node_id_set, skip_node_op_set: Not used here. They
            are kept for the not-yet-wired on-device path
            (``_mimi_decode_on_htp``).

    Returns:
        The decoded PCM chunks concatenated along the last dimension.
    """
    from pathlib import Path

    from moshi.models.compression import MimiModel
    from moshi.models.loaders import (
        _quantizer_kwargs,
        _seanet_kwargs,
        _transformer_kwargs,
    )
    from moshi.modules import transformer
    from moshi.modules.seanet import SEANetDecoder, SEANetEncoder
    from moshi.quantization.vq import SplitResidualVectorQuantizer
    from safetensors.torch import load_model

    def _is_safetensors(path: Path | str) -> bool:
        return Path(path).suffix in (".safetensors", ".sft", ".sfts")

    class MimiDecode(MimiModel):
        """MimiModel whose ``forward`` is ``decode`` so it can be exported."""

        def forward(self, x):
            return super().decode(x)

    encoder = SEANetEncoder(**_seanet_kwargs)
    decoder = SEANetDecoder(**_seanet_kwargs)
    encoder_transformer = transformer.ProjectedTransformer(
        device="cpu", **_transformer_kwargs
    )
    decoder_transformer = transformer.ProjectedTransformer(
        device="cpu", **_transformer_kwargs
    )
    quantizer = SplitResidualVectorQuantizer(**_quantizer_kwargs)
    mimi_decode_model = MimiDecode(
        encoder,
        decoder,
        quantizer,
        channels=1,
        sample_rate=24000,
        frame_rate=12.5,
        encoder_frame_rate=24000 / encoder.hop_length,
        causal=True,
        resample_method="conv",
        encoder_transformer=encoder_transformer,
        decoder_transformer=decoder_transformer,
    )
    mimi_decode_model.eval()
    if _is_safetensors(args.mimi_weight):
        load_model(mimi_decode_model, args.mimi_weight, strict=False)
    else:
        # Other checkpoint formats: copy the weights the caller already loaded
        # rather than leaving this model randomly initialised.
        mimi_decode_model.load_state_dict(mimi.state_dict(), strict=False)

    pcm_chunks = []
    sample_input = encode_res_list[..., 0:1]
    with mimi_decode_model.streaming(1):
        # NOTE(review): calling mimi_decode_model directly in this loop matches
        # the reference, but SQNR drops to ~8.5 dB once the model goes through
        # torch.export. Presumably the streaming state is captured at export
        # time rather than carried between calls; confirm before relying on
        # this path. Eager equivalent:
        #     pcm_chunks.append(mimi_decode_model(encode_res_list[..., i : i + 1]))
        captured_model = torch.export.export(
            mimi_decode_model, (sample_input,), strict=False
        ).module()
        for i in range(encode_res_list.shape[-1]):
            pcm_chunks.append(captured_model(encode_res_list[..., i : i + 1]))
    return torch.cat(pcm_chunks, dim=-1)


def _mimi_decode_on_htp(
    args,
    mimi_decode_model,
    decode_inputs,
    decode_input_list,
    pcm_chunk_size,
    skip_node_id_set,
    skip_node_op_set,
) -> torch.Tensor:
    """Compile the Mimi decoder for QNN (16a8w), run it on the device and collect PCM.

    Not called yet. ``mimi_decode`` does not build the inputs this needs:

    * ``decode_inputs``: a list of ``(codes,)`` tuples.
    * ``decode_input_list``: one ``input_<i>_0.raw`` line per input.

    Returns:
        Device outputs, each reshaped to ``(1, 1, frames * pcm_chunk_size)``
        and concatenated along the last dimension.
    """
    pte_filename = "mimi_decoder_qnn"
    quantizer = make_quantizer(
        quant_dtype=QuantDtype.use_16a8w,
        per_channel_conv=True,
        per_channel_linear=True,
        act_observer=MinMaxObserver,
    )
    quantizer.add_custom_quant_annotations((annotate_mimi_decoder,))
    build_executorch_binary(
        mimi_decode_model.eval(),
        decode_inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        decode_inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        custom_quantizer=quantizer,
        quant_dtype=QuantDtype.use_16a8w,
        shared_buffer=args.shared_buffer,
    )
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=decode_inputs, input_list=decode_input_list)
    adb.execute()

    # Collect the raw float32 outputs written by the device run.
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)
    # Each frame decodes to pcm_chunk_size samples; the frame count should
    # match args.chunks_per_batch.
    num_chunks = decode_inputs[0][0].shape[-1]
    shape = num_chunks * pcm_chunk_size
    decoder_predictions = [
        torch.from_numpy(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"),
                dtype=np.float32,
            )
        ).view(1, 1, shape)
        for i in range(len(decode_inputs))
    ]
    return torch.cat(decoder_predictions, dim=-1)
def export_mimi(mimi, args, max_duration_sec=10.0):
    """Compare Mimi's batch decode with the exported frame-by-frame decode on a sample clip.

    If ``args.emb_input_test`` is set, only the embedding-input decoder test
    runs. Otherwise the function:

    1. Downloads a sample MP3, mixes it to mono and resamples it to
       ``mimi.sample_rate`` if needed.
    2. Truncates it to ``max_duration_sec`` seconds.
    3. Encodes it on CPU, then decodes it both with ``mimi.decode`` and with
       ``mimi_decode``.
    4. Prints the scores and writes both results as WAV files under
       ``args.artifact``.

    Args:
        mimi: The loaded Mimi model.
        args: Parsed CLI arguments.
        max_duration_sec: Maximum clip length, in seconds, to process.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
    os.makedirs(args.artifact, exist_ok=True)
    if args.emb_input_test:
        test_decoder_with_emb_input(mimi, args)
        return

    sample_rate = mimi.sample_rate
    url = "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3"
    sample_pcm, sample_sr = read_mp3_from_url(url)
    sample_pcm = torch.tensor(sample_pcm, device="cpu")
    # Dim 0 holds channels (per torchaudio.load). The model is mono, so
    # average multi-channel audio down to one channel.
    if sample_pcm.shape[0] > 1:
        sample_pcm = sample_pcm.mean(dim=0, keepdim=True)
    # Make sure the audio matches the model's rate before chunking by samples.
    if sample_sr != sample_rate:
        sample_pcm = torchaudio.functional.resample(
            sample_pcm, int(sample_sr), int(sample_rate)
        )
    max_duration_len = int(sample_rate * max_duration_sec)
    # Slicing past the end is a no-op, so shorter clips are left unchanged.
    sample_pcm = sample_pcm[..., :max_duration_len]
    sample_pcm = sample_pcm[None]

    # 1920 samples per chunk = 0.08 s.
    pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
    batch_size = pcm_chunk_size * args.chunks_per_batch
    # Inputs for the on-device encoder (see the commented-out mimi_encode call).
    encoder_inputs = [
        (sample_pcm[..., start_idx : start_idx + batch_size],)
        for start_idx in range(0, sample_pcm.shape[-1], batch_size)
    ]
    encoder_input_list = "".join(
        f"input_{i}_0.raw\n" for i in range(len(encoder_inputs))
    )

    print("streaming encoding...")
    cpu_encode_res = mimi.encode(sample_pcm)
    # htp_encode_res = mimi_encode(
    #     mimi,
    #     encoder_inputs,
    #     encoder_input_list,
    #     pcm_chunk_size,
    #     skip_node_id_set,
    #     skip_node_op_set,
    # )
    # Leave it here for now, uncomment this to check htp_encoder with cpu_decoder
    # htp_res = torch.cat(htp_encode_res, dim=-1)
    # cpu_decode_htp_encode = mimi.decode(htp_res)
    # sphn.write_wav("cpu_decode_htp_encode.wav", cpu_decode_htp_encode[0, 0].cpu().numpy(), sample_rate)

    print("streaming decoding...")
    cpu_decode_res = mimi.decode(cpu_encode_res)
    # TODO: Enable streaming mode, which is the correct way to execute 1 chunk at a time.
    # with mimi.streaming(1):
    cpu_streaming_decode_res = mimi_decode(
        args, mimi, cpu_encode_res, pcm_chunk_size, skip_node_id_set, skip_node_op_set
    )
    compute_scores(cpu_decode_res, cpu_streaming_decode_res)
    sphn.write_wav(
        f"{args.artifact}/cpu_decode_res.wav",
        cpu_decode_res[0, 0].cpu().numpy(),
        sample_rate,
    )
    sphn.write_wav(
        f"{args.artifact}/cpu_streaming_decode_res.wav",
        cpu_streaming_decode_res[0, 0].cpu().numpy(),
        sample_rate,
    )
def main(args):
    """Script entry point: seed the RNGs, load Mimi and run the export/compare flow.

    If ``args.mimi_weight`` is unset, the checkpoint ``loaders.MIMI_NAME`` is
    downloaded from ``args.hf_repo`` with ``hf_hub_download``. The local path
    is written back to ``args.mimi_weight`` because later stages (e.g.
    ``mimi_decode``) read it.

    Args:
        args: Parsed command-line arguments.
    """
    seed_all(42424242)
    print("loading mimi")
    weight_path = args.mimi_weight
    if weight_path is None:
        weight_path = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
        args.mimi_weight = weight_path
    mimi = loaders.get_mimi(weight_path, "cpu")
    print("mimi loaded")
    # Inference only: nothing downstream needs autograd.
    with torch.no_grad():
        export_mimi(mimi, args)
if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./mimi",
        default="./mimi",
        type=str,
    )
    parser.add_argument(
        "--chunks_per_batch",
        help="Number of chunks to process per time. Default is 1 chunk per batch, which equals to 0.08 second",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--emb_input_test",
        help="This is just a metrics used to compute accuracy scores, not recommended for general users.",
        action="store_true",
        default=False,
    )
    parser.add_argument("--mimi-weight", type=str)
    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO)
    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        # When a reporting endpoint is configured, send the error there as
        # JSON; otherwise re-raise.
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare `raise` keeps the original exception type and traceback;
            # `raise Exception(e)` discarded both.
            raise