Skip to content

24GB 显存出现 CUDA error: out of memory,正常吗? #165

Description

@sunmouren
from pathlib import Path
import importlib.util
import torch
import torchaudio
from transformers import AutoModel, AutoProcessor
# Disable the broken cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)
# Keep these enabled as fallbacks
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)


# Hugging Face Hub id of the checkpoint used for both the processor and the model.
pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-v1.5"

# Run on the GPU in bfloat16 when one is available; otherwise use the CPU in float32.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
else:
    device, dtype = "cpu", torch.float32

def resolve_attn_implementation(target_device=None, target_dtype=None) -> str:
    """Choose the ``attn_implementation`` string passed to ``from_pretrained``.

    Args:
        target_device: Device to plan for, e.g. ``"cuda"``, ``"cuda:1"``,
            ``"cpu"`` or a ``torch.device``. Defaults to the module-level
            ``device``.
        target_dtype: Model compute dtype. Defaults to the module-level
            ``dtype``.

    Returns:
        ``"flash_attention_2"`` when the ``flash_attn`` package is installed,
        the dtype is float16/bfloat16 and the target GPU reports compute
        capability major >= 8; ``"sdpa"`` for any other CUDA device;
        ``"eager"`` for every non-CUDA device.
    """
    if target_device is None:
        target_device = device
    if target_dtype is None:
        target_dtype = dtype

    # Compare the device *type* so indexed devices such as "cuda:1" (or a
    # torch.device object) count as CUDA instead of falling through to "eager".
    if torch.device(target_device).type != "cuda":
        return "eager"

    # FlashAttention 2 is only chosen for half-precision dtypes on GPUs with
    # compute capability major >= 8, and only if the package is importable.
    wants_flash = (
        target_dtype in {torch.float16, torch.bfloat16}
        and importlib.util.find_spec("flash_attn") is not None
    )
    if wants_flash:
        # Query the requested GPU rather than whichever device is current.
        major, _ = torch.cuda.get_device_capability(target_device)
        if major >= 8:
            return "flash_attention_2"

    # CUDA fallback: PyTorch SDPA kernels.
    return "sdpa"


# Decide the attention kernel once and log it so runs show which backend was used.
attn_implementation = resolve_attn_implementation()
print(f"[INFO] Using attn_implementation={attn_implementation}")

# Load the (remote-code) processor, then move its audio tokenizer to the
# compute device.
# NOTE(review): `.to(device)` changes only the device, not the dtype, so the
# tokenizer stays in whatever precision it was loaded with (possibly float32 —
# verify). It also occupies GPU memory alongside the model, which is worth
# checking when chasing the out-of-memory error.
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
processor.audio_tokenizer = processor.audio_tokenizer.to(device)

# Chinese sentence to synthesize in both conversations below.
text_1 = "法国天文学家希望与巴黎连接的重要地点之一是乌拉尼堡,即16世纪先驱天文学家第谷·布拉赫的旧天文台。"

# Reference clip for voice cloning. This is an absolute path on the author's
# machine; adjust it locally.
ref_audio_1 = "/data/code/python/django/dub_sys/media/878/index-tts/101.wav"

# One single-message conversation per generation run:
#   0 -> plain TTS with no reference audio (v1.5 recommends language tags;
#        none is passed here),
#   1 -> the same text cloned from the reference clip.
conversations = [
    [processor.build_user_message(text=text_1)],
    [processor.build_user_message(text=text_1, reference=[ref_audio_1])],
]

# Load the TTS model with the chosen attention kernel and dtype, move it to the
# compute device, and switch it to evaluation mode (nn.Module.to/.eval return
# the module itself, so the chain yields the model).
model = (
    AutoModel.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        trust_remote_code=True,
    )
    .to(device)
    .eval()
)

# Generation settings. `max_new_tokens` bounds how long each generated
# sequence can grow, so it is the first knob to lower if generation runs out
# of GPU memory (presumably via the KV cache — verify for this model).
batch_size = 1
max_new_tokens = 4096

save_dir = Path("inference_root")
save_dir.mkdir(exist_ok=True, parents=True)
sample_idx = 0  # running index across batches, used for output file names
with torch.no_grad():
    for start in range(0, len(conversations), batch_size):
        batch_conversations = conversations[start : start + batch_size]
        batch = processor(batch_conversations, mode="generation")
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
        )

        for message in processor.decode(outputs):
            audio = message.audio_codes_list[0]
            out_path = save_dir / f"sample{sample_idx}.wav"
            sample_idx += 1
            # Some torchaudio backends convert the tensor to NumPy, which
            # requires a CPU tensor; `.cpu()` is a no-op if it is already there.
            torchaudio.save(
                out_path,
                audio.unsqueeze(0).cpu(),
                processor.model_config.sampling_rate,
            )

        # Release this batch's tensors and hand cached GPU blocks back before
        # the next prompt (e.g. the longer voice-cloning one) is processed.
        del batch, input_ids, attention_mask, outputs
        if device == "cuda":
            torch.cuda.empty_cache()

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions