from pathlib import Path
import importlib.util
import torch
import torchaudio
from transformers import AutoModel, AutoProcessor
# Disable the broken cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)
# Keep these enabled as fallbacks
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-v1.5"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
def resolve_attn_implementation() -> str:
    """Choose the attention implementation name for model loading.

    Reads the module-level ``device`` and ``dtype``.

    Returns:
        ``"flash_attention_2"`` when running on CUDA, the ``flash_attn``
        package is importable, ``dtype`` is float16 or bfloat16, and the
        current GPU reports a compute-capability major version of at least 8;
        ``"sdpa"`` for any other CUDA configuration; ``"eager"`` on CPU.
    """
    # CPU fallback.
    if device != "cuda":
        return "eager"

    # Prefer FlashAttention 2 when package + dtype + device conditions are met.
    flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
    half_precision = dtype in (torch.float16, torch.bfloat16)
    if flash_attn_installed and half_precision:
        capability_major = torch.cuda.get_device_capability()[0]
        if capability_major >= 8:
            return "flash_attention_2"

    # CUDA fallback: use PyTorch SDPA kernels.
    return "sdpa"
# ---------------------------------------------------------------------------
# Inference script: synthesize `text_1` once as plain TTS and once with a
# voice-cloning reference, writing each result to inference_root/sample<N>.wav.
# ---------------------------------------------------------------------------
attn_implementation = resolve_attn_implementation()
print(f"[INFO] Using attn_implementation={attn_implementation}")

# Text to synthesize.
text_1 = "法国天文学家希望与巴黎连接的重要地点之一是乌拉尼堡,即16世纪先驱天文学家第谷·布拉赫的旧天文台。"
# Local reference recording for the voice-cloning example (a local file
# avoids downloading from the cloud).
# NOTE(review): this is a machine-specific absolute path; point it at a
# reference clip that exists on your system.
ref_audio_1 = "/data/code/python/django/dub_sys/media/878/index-tts/101.wav"
# Fail fast with a clear message before any weights are loaded. Without this
# check a missing file would presumably only surface later, while the
# processor prepares the voice-cloning batch.
if not Path(ref_audio_1).is_file():
    raise FileNotFoundError(f"Reference audio not found: {ref_audio_1}")

processor = AutoProcessor.from_pretrained(
    pretrained_model_name_or_path,
    trust_remote_code=True,
)
# Place the audio tokenizer on the same device the model is moved to below.
processor.audio_tokenizer = processor.audio_tokenizer.to(device)

conversations = [
    # Direct TTS (no reference). Language tags are recommended in v1.5.
    [processor.build_user_message(text=text_1)],
    # Voice cloning (with reference)
    [processor.build_user_message(text=text_1, reference=[ref_audio_1])],
]

model = AutoModel.from_pretrained(
    pretrained_model_name_or_path,
    trust_remote_code=True,
    attn_implementation=attn_implementation,
    torch_dtype=dtype,
).to(device)
model.eval()

batch_size = 1
save_dir = Path("inference_root")
save_dir.mkdir(exist_ok=True, parents=True)
sample_idx = 0  # Running index used to name the output files.

with torch.no_grad():
    for start in range(0, len(conversations), batch_size):
        batch_conversations = conversations[start : start + batch_size]
        batch = processor(batch_conversations, mode="generation")
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=4096,
        )

        for message in processor.decode(outputs):
            # Only the first audio entry of each decoded message is saved.
            audio = message.audio_codes_list[0]
            out_path = save_dir / f"sample{sample_idx}.wav"
            sample_idx += 1
            # Make sure the tensor lives on the host before writing it out;
            # this is a no-op if the processor already returned a CPU tensor.
            # unsqueeze(0) adds the leading dimension for the 2-D input that
            # torchaudio.save requires (assumes `audio` is 1-D — TODO confirm).
            waveform = audio.detach().cpu().unsqueeze(0)
            torchaudio.save(out_path, waveform, processor.model_config.sampling_rate)