|
import logging

import torchaudio

import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils

from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
| 8 | + |
| 9 | + |
class AudioEncoderModel():
    """Wrapper that owns an audio encoder backbone (wav2vec2 or Whisper
    large-v3), placing it on ComfyUI's text-encoder devices and exposing a
    simple encode/serialize interface.

    `config` must contain a "model_type" key ("wav2vec2" or "whisper3");
    the remaining keys are forwarded to the backbone constructor.
    """
    def __init__(self, config):
        # Audio encoders share the text-encoder device/dtype policy.
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        model_type = config.pop("model_type")
        model_config = dict(config)
        model_config.update({
            "dtype": self.dtype,
            "device": offload_device,
            "operations": comfy.ops.manual_cast
        })

        if model_type == "wav2vec2":
            self.model = Wav2Vec2Model(**model_config)
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        else:
            # Fail loudly here instead of with a confusing AttributeError
            # on self.model further down.
            raise ValueError("unknown audio encoder model_type: {}".format(model_type))
        self.model.eval()
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        # Both supported backbones consume 16 kHz mono-ish waveforms.
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        """Load weights; returns (missing_keys, unexpected_keys)."""
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        """Return the backbone's state dict for serialization."""
        return self.model.state_dict()

    def encode_audio(self, audio, sample_rate):
        """Encode a waveform tensor (presumably (batch, channels, samples) —
        TODO confirm against callers) recorded at `sample_rate` Hz.

        Returns a dict with the final hidden states, all intermediate layer
        outputs, and the (post-resample) sample count.
        """
        comfy.model_management.load_model_gpu(self.patcher)
        if sample_rate != self.model_sample_rate:
            # Only resample when needed; same result, avoids a no-op pass.
            audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        outputs = {}
        outputs["encoded_audio"] = out
        outputs["encoded_audio_all_layers"] = all_layers
        outputs["audio_samples"] = audio.shape[2]
        return outputs
| 46 | + |
| 47 | + |
def load_audio_encoder_from_sd(sd, prefix=""):
    """Detect the encoder architecture from a raw state dict, build the
    matching AudioEncoderModel and load the weights into it.

    Supports wav2vec2 (base/large, detected via embed dim) and Whisper
    large-v3 checkpoints. Raises RuntimeError for anything else.
    """
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    if "encoder.layer_norm.bias" in sd:  # wav2vec2 checkpoint
        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
        # Known wav2vec2 variants, keyed by embedding width.
        variants = {
            1024: {  # large
                "model_type": "wav2vec2",
                "embed_dim": 1024,
                "num_heads": 16,
                "num_layers": 24,
                "conv_norm": True,
                "conv_bias": True,
                "do_normalize": True,
                "do_stable_layer_norm": True
            },
            768: {  # base; chinese-wav2vec2-base has do_normalize False
                "model_type": "wav2vec2",
                "embed_dim": 768,
                "num_heads": 12,
                "num_layers": 12,
                "conv_norm": False,
                "conv_bias": False,
                "do_normalize": False,
                "do_stable_layer_norm": False
            },
        }
        config = variants.get(embed_dim)
        if config is None:
            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
    elif "model.encoder.embed_positions.weight" in sd:  # whisper checkpoint
        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
        config = {
            "model_type": "whisper3",
        }
    else:
        raise RuntimeError("ERROR: audio encoder not supported.")

    audio_encoder = AudioEncoderModel(config)
    missing, unexpected = audio_encoder.load_sd(sd)
    if missing:
        logging.warning("missing audio encoder: {}".format(missing))
    if unexpected:
        logging.warning("unexpected audio encoder: {}".format(unexpected))

    return audio_encoder
0 commit comments