From cfbeaab7b6d9c9bdc5a46d999d414998f63c513a Mon Sep 17 00:00:00 2001
From: Theodore Aptekarev
Date: Tue, 8 Oct 2024 15:09:41 +0300
Subject: [PATCH] Add support for cpu and mps

---
 stable_audio_tools/inference/generation.py |  9 +++-
 stable_audio_tools/inference/sampling.py   | 41 +++++++++++-------
 stable_audio_tools/interface/gradio.py     |  9 +++-
 stable_audio_tools/models/blocks.py        | 14 +++++-
 stable_audio_tools/models/conditioners.py  | 42 +++++++++++-------
 stable_audio_tools/models/pretransforms.py | 16 +++++--
 stable_audio_tools/models/transformer.py   | 50 +++++++++++++++-------
 7 files changed, 127 insertions(+), 54 deletions(-)

diff --git a/stable_audio_tools/inference/generation.py b/stable_audio_tools/inference/generation.py
index 843ab4b7..c511d505 100644
--- a/stable_audio_tools/inference/generation.py
+++ b/stable_audio_tools/inference/generation.py
@@ -8,13 +8,20 @@
 from .sampling import sample, sample_k, sample_rf
 from ..data.utils import PadCrop
 
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
 def generate_diffusion_uncond(
         model,
         steps: int = 250,
         batch_size: int = 1,
         sample_size: int = 2097152,
         seed: int = -1,
-        device: str = "cuda",
+        device: str = device.type,
         init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
         init_noise_level: float = 1.0,
         return_latents = False,
diff --git a/stable_audio_tools/inference/sampling.py b/stable_audio_tools/inference/sampling.py
index 2229e508..ddfb39fa 100644
--- a/stable_audio_tools/inference/sampling.py
+++ b/stable_audio_tools/inference/sampling.py
@@ -4,6 +4,16 @@
 
 import k_diffusion as K
 
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+valid_autocast_device_types = {"cuda", "cpu"}
+autocast_device_type = device.type if device.type in valid_autocast_device_types else "cpu"
+
 # Define the noise schedule and sampling loop
 def get_alphas_sigmas(t):
     """Returns the scaling factors for the clean image (alpha) and for the
@@ -58,7 +68,7 @@ def sample(model, x, steps, eta, **extra_args):
     for i in trange(steps):
 
         # Get the model output (v, the predicted velocity)
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast(autocast_device_type):
             v = model(x, ts * t[i], **extra_args).float()
 
         # Predict the noise and the denoised image
@@ -109,16 +119,17 @@ def cond_model_fn(x, sigma, **kwargs):
 # For variations, set init_data
 # For inpainting, set both init_data & mask
 def sample_k(
-        model_fn, 
-        noise, 
+        model_fn,
+        noise,
         init_data=None,
         mask=None,
-        steps=100, 
-        sampler_type="dpmpp-2m-sde", 
-        sigma_min=0.5, 
-        sigma_max=50, 
-        rho=1.0, device="cuda", 
-        callback=None, 
+        steps=100,
+        sampler_type="dpmpp-2m-sde",
+        sigma_min=0.5,
+        sigma_max=50,
+        rho=1.0,
+        device=device.type,
+        callback=None,
         cond_fn=None,
         **extra_args
     ):
@@ -174,7 +185,7 @@ def inpainting_callback(args):
 
     x = noise
 
-    with torch.cuda.amp.autocast():
+    with torch.amp.autocast(autocast_device_type):
         if sampler_type == "k-heun":
             return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
         elif sampler_type == "k-lms":
@@ -198,13 +209,13 @@
 # For variations, set init_data
 # For inpainting, set both init_data & mask
 def sample_rf(
-        model_fn, 
-        noise, 
+        model_fn,
+        noise,
         init_data=None,
-        steps=100, 
+        steps=100,
         sigma_max=1,
-        device="cuda", 
-        callback=None, 
+        device=device.type,
+        callback=None,
         cond_fn=None,
         **extra_args
    ):
diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py
index f38468bc..dddfff7b 100644
--- a/stable_audio_tools/interface/gradio.py
+++ b/stable_audio_tools/interface/gradio.py
@@ -24,7 +24,14 @@
 sample_rate = 32000
 sample_size = 1920000
 
-def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device=device, model_half=False):
     global model, sample_rate, sample_size
 
     if pretrained_name is not None:
diff --git a/stable_audio_tools/models/blocks.py b/stable_audio_tools/models/blocks.py
index 3c827fd2..0ef5a8b2 100644
--- a/stable_audio_tools/models/blocks.py
+++ b/stable_audio_tools/models/blocks.py
@@ -5,11 +5,21 @@
 from torch import nn
 from torch.nn import functional as F
-from torch.backends.cuda import sdp_kernel
 from packaging import version
 
 from dac.nn.layers import Snake1d
 
+# Determine the device to use
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+if device.type == 'cuda':
+    from torch.backends.cuda import sdp_kernel
+
 class ResidualBlock(nn.Module):
     def __init__(self, main, skip=None):
         super().__init__()
@@ -41,7 +51,7 @@ def __init__(self, c_in, n_head=1, dropout_rate=0.):
         self.out_proj = nn.Conv1d(c_in, c_in, 1)
         self.dropout = nn.Dropout(dropout_rate, inplace=True)
 
-        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+        self.use_flash = True if device.type == 'cuda' and version.parse(torch.__version__) >= version.parse('2.0.0') else False
 
         if not self.use_flash:
             return
diff --git a/stable_audio_tools/models/conditioners.py b/stable_audio_tools/models/conditioners.py
index e998ab10..a088c241 100644
--- a/stable_audio_tools/models/conditioners.py
+++ b/stable_audio_tools/models/conditioners.py
@@ -15,6 +15,16 @@
 
 from torch import nn
 
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+valid_autocast_device_types = {"cuda", "cpu"}
+autocast_device_type = device.type if device.type in valid_autocast_device_types else "cpu"
+
 class Conditioner(nn.Module):
     def __init__(
         self,
@@ -71,8 +81,8 @@ def __init__(self,
 
         self.embedder = NumberEmbedder(features=output_dim)
 
-    def forward(self, floats: tp.List[float], device=None) -> tp.Any:
-    
+    def forward(self, floats: tp.List[float], device=device) -> tp.Any:
+
         # Cast the inputs to floats
         floats = [float(x) for x in floats]
 
@@ -138,9 +148,10 @@ def __init__(self,
             del self.model.model.audio_branch
 
             gc.collect()
-            torch.cuda.empty_cache()
+            if device.type == 'cuda':
+                torch.cuda.empty_cache()
 
-    def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
+    def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = device):
         prompt_tokens = self.model.tokenizer(prompts)
         attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
         prompt_features = self.model.model.text_branch(
@@ -151,7 +162,7 @@ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
 
         return prompt_features, attention_mask
 
-    def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
+    def forward(self, texts: tp.List[str], device: tp.Any = device) -> tp.Any:
         self.model.to(device)
 
         if self.use_text_features:
@@ -182,8 +193,6 @@ def __init__(self,
                  project_out: bool = False):
         super().__init__(512, output_dim, project_out=project_out)
 
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
         # Suppress logging from transformers
         previous_level = logging.root.manager.disable
         logging.disable(logging.ERROR)
@@ -192,8 +201,8 @@ def __init__(self,
         try:
             import laion_clap
             from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
-            
-            model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
+
+            model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device=device)
 
             if self.finetune:
                 self.model = model
@@ -216,9 +225,10 @@ def __init__(self,
             del self.model.model.text_branch
 
             gc.collect()
-            torch.cuda.empty_cache()
+            if device.type == 'cuda':
+                torch.cuda.empty_cache()
 
-    def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
+    def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = device) -> tp.Any:
 
         self.model.to(device)
 
@@ -228,7 +238,7 @@ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple
 
             # Convert to mono
             mono_audios = audios.mean(dim=1)
-            with torch.cuda.amp.autocast(enabled=False):
+            with torch.amp.autocast(autocast_device_type, enabled=False):
                 audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
 
         audio_embedding = audio_embedding.unsqueeze(1).to(device)
@@ -310,12 +320,12 @@ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> t
             attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
 
         self.model.eval()
-            
-        with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
+
+        with torch.amp.autocast(autocast_device_type, dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
             embeddings = self.model(
                 input_ids=input_ids, attention_mask=attention_mask
-            )["last_hidden_state"]    
-        
+            )["last_hidden_state"]
+
         embeddings = self.proj_out(embeddings.float())
 
         embeddings = embeddings * attention_mask.unsqueeze(-1).float()
diff --git a/stable_audio_tools/models/pretransforms.py b/stable_audio_tools/models/pretransforms.py
index c9942db5..6982708b 100644
--- a/stable_audio_tools/models/pretransforms.py
+++ b/stable_audio_tools/models/pretransforms.py
@@ -2,6 +2,16 @@
 from einops import rearrange
 from torch import nn
 
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+valid_autocast_device_types = {"cuda", "cpu"}
+autocast_device_type = device.type if device.type in valid_autocast_device_types else "cpu"
+
 class Pretransform(nn.Module):
     def __init__(self, enable_grad, io_channels, is_discrete):
         super().__init__()
@@ -250,9 +260,9 @@ def decode(self, z):
     #     return self.model.decode(z)
 
     def tokenize(self, x):
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(autocast_device_type, enabled=False):
             return self.model.encode(x.to(torch.float16))[0]
-    
+
     def decode_tokens(self, tokens):
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(autocast_device_type, enabled=False):
             return self.model.decode(tokens)
diff --git a/stable_audio_tools/models/transformer.py b/stable_audio_tools/models/transformer.py
index 65965b49..20c982d5 100644
--- a/stable_audio_tools/models/transformer.py
+++ b/stable_audio_tools/models/transformer.py
@@ -6,9 +6,21 @@
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from typing import Callable, Literal
 
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
+
+# Ensure device.type is valid for autocast
+valid_autocast_device_types = {"cuda", "cpu"}
+autocast_device_type = device.type if device.type in valid_autocast_device_types else "cpu"
+
+
 try:
     from flash_attn import flash_attn_func, flash_attn_kvpacked_func
 except ImportError as e:
@@ -123,7 +135,7 @@ def forward_from_seq_len(self, seq_len):
         t = torch.arange(seq_len, device = device)
         return self.forward(t)
 
-    @autocast(enabled = False)
+    @autocast(device_type=autocast_device_type, enabled=False)
     def forward(self, t):
         device = self.inv_freq.device
 
@@ -148,8 +160,9 @@ def rotate_half(x):
     x1, x2 = x.unbind(dim = -2)
     return torch.cat((-x2, x1), dim = -1)
 
-@autocast(enabled = False)
-def apply_rotary_pos_emb(t, freqs, scale = 1):
+
+@autocast(device_type=autocast_device_type, enabled=False)
+def apply_rotary_pos_emb(t, freqs, scale=1):
     out_dtype = t.dtype
 
     # cast to float32 if necessary for numerical stability
@@ -311,15 +324,17 @@ def __init__(
         if natten_kernel_size is not None:
             return
 
-        self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+        self.use_pt_flash = device.type == "cuda" and version.parse(
+            torch.__version__
+        ) >= version.parse("2.0.0")
 
-        self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
+        self.use_fa_flash = device.type == "cuda" and flash_attn_func is not None
 
-        self.sdp_kwargs = dict(
-            enable_flash = True,
-            enable_math = True,
-            enable_mem_efficient = True
-        )
+        self.sdp_backends = [
+            torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            torch.nn.attention.SDPBackend.MATH,
+            torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+        ]
 
     def flash_attn(
         self,
@@ -378,12 +393,15 @@ def flash_attn(
                 mask[..., 0] = mask[..., 0] | row_is_entirely_masked
 
                 causal = False
-        
-        with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
+
+        if device.type == "cuda":
+            with torch.nn.attention.sdpa_kernel(self.sdp_backends):
+                out = F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=mask, is_causal=causal
+                )
+        else:
             out = F.scaled_dot_product_attention(
-                q, k, v,
-                attn_mask = mask,
-                is_causal = causal
+                q, k, v, attn_mask=mask, is_causal=causal
             )
 
         # for a row that is entirely masked out, should zero out the output of that row token
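
The device-selection and autocast-fallback lines repeated at the top of each
file above follow one pattern: prefer CUDA, then Apple-silicon MPS, then CPU,
and treat only "cuda" and "cpu" as valid autocast device types, so on MPS the
autocast context runs as a CPU one. A minimal standalone sketch of that
pattern; the helper names here are illustrative, not part of the patch:

    import torch

    def pick_device() -> torch.device:
        # Same preference order as the patch: CUDA, then MPS, then CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def pick_autocast_device_type(device: torch.device) -> str:
        # The patch passes only "cuda" or "cpu" to torch.amp.autocast; on MPS
        # it falls back to "cpu", which disables mixed precision there.
        return device.type if device.type in {"cuda", "cpu"} else "cpu"

    device = pick_device()
    with torch.amp.autocast(pick_autocast_device_type(device)):
        pass  # a model forward pass would go here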
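transformer.py also moves from the deprecated torch.backends.cuda.sdp_kernel()
context manager, configured with boolean flags, to torch.nn.attention.sdpa_kernel(),
which takes a list of SDPBackend values, and only enters that context on CUDA.
A standalone sketch of the new call (PyTorch 2.3+; tensor shapes and the
device guard are illustrative, mirroring the patch):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # (batch, heads, seq_len, head_dim) -- illustrative shapes
    q = torch.randn(2, 8, 128, 64)
    k = torch.randn(2, 8, 128, 64)
    v = torch.randn(2, 8, 128, 64)

    backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH,
                SDPBackend.EFFICIENT_ATTENTION]

    if torch.cuda.is_available():
        q, k, v = (t.cuda() for t in (q, k, v))
        # Restrict which kernels SDPA may dispatch to on CUDA.
        with sdpa_kernel(backends):
            out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    else:
        # On CPU/MPS the patch calls SDPA directly, with no kernel hint.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=False)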