diff --git a/stable_audio_tools/models/conditioners.py b/stable_audio_tools/models/conditioners.py index e998ab10..a365a3a6 100644 --- a/stable_audio_tools/models/conditioners.py +++ b/stable_audio_tools/models/conditioners.py @@ -179,9 +179,11 @@ def __init__(self, clap_ckpt_path, audio_model_type="HTSAT-base", enable_fusion=True, - project_out: bool = False): + project_out: bool = False, + finetune: bool = False): super().__init__(512, output_dim, project_out=project_out) + self.finetune = finetune device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Suppress logging from transformers