-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathinfer.py
More file actions
68 lines (54 loc) · 1.72 KB
/
infer.py
File metadata and controls
68 lines (54 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from src.audio_io import load_audio
from src.modeling_moss_audio import MossAudioModel
from src.processing_moss_audio import MossAudioProcessor
# Local directory containing the pretrained model and processor weights.
MODEL_PATH = "weights/MOSS-Audio-4B-Thinking"
# Audio clip the script runs inference on.
AUDIO_PATH = "test/test_kr.mp3"
# Sampling hyperparameters forwarded to `model.generate` (1.0/1.0 leave
# the distribution unmodified; top-k restricts sampling to the 50 most
# likely tokens per step).
TEMPERATURE = 1.0
TOP_P = 1.0
TOP_K = 50
def _select_device() -> str:
    """Pick the best available accelerator: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda:0"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def main(
    model_path: str = MODEL_PATH,
    audio_path: str = AUDIO_PATH,
    prompt: str = "Describe this audio.",
    max_new_tokens: int = 1024,
) -> None:
    """Load the MOSS audio model, run sampled generation over one audio file,
    and print the decoded output.

    Args:
        model_path: Directory holding the pretrained model/processor weights.
        audio_path: Path of the audio file to describe.
        prompt: Text instruction paired with the audio input.
        max_new_tokens: Upper bound on the number of generated tokens.
    """
    device_map = _select_device()

    model = MossAudioModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        dtype="auto",
        device_map=device_map,
    )
    model.eval()

    processor = MossAudioProcessor.from_pretrained(
        model_path,
        trust_remote_code=True,
        enable_time_marker=True,
    )

    # Load the clip at the sample rate the processor's mel frontend expects.
    raw_audio = load_audio(audio_path, sample_rate=processor.config.mel_sr)

    inputs = processor(text=prompt, audios=[raw_audio], return_tensors="pt")
    inputs = inputs.to(model.device)
    # Audio features must match the model's compute dtype (picked by
    # dtype="auto" above), while token ids stay integral.
    if inputs.get("audio_data") is not None:
        inputs["audio_data"] = inputs["audio_data"].to(model.dtype)

    # Boolean mask marking which input positions are audio placeholder tokens.
    inputs["audio_input_mask"] = inputs["input_ids"] == processor.audio_token_id

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            num_beams=1,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            top_k=TOP_K,
            use_cache=True,
        )

    # `generate` returns prompt + continuation; decode only the new tokens.
    prompt_len = inputs["input_ids"].shape[1]
    transcription = processor.decode(
        generated_ids[0, prompt_len:], skip_special_tokens=True
    )
    print(transcription)
# Script entry point: run inference with the module-level defaults.
if __name__ == "__main__":
    main()