|
| 1 | +#!/usr/bin/env python |
| 2 | +import os |
| 3 | +import logging |
| 4 | +from typing import Optional |
| 5 | + |
| 6 | +import requests |
| 7 | + |
| 8 | + |
| 9 | +class AudioSeparatorAPIClient: |
| 10 | + """Client for interacting with a remotely deployed Audio Separator API.""" |
| 11 | + |
| 12 | + def __init__(self, api_url: str, logger: logging.Logger): |
| 13 | + self.api_url = api_url |
| 14 | + self.logger = logger |
| 15 | + self.session = requests.Session() |
| 16 | + |
| 17 | + def separate_audio(self, file_path: str, model: Optional[str] = None) -> dict: |
| 18 | + """Submit audio separation job (asynchronous processing).""" |
| 19 | + if not os.path.exists(file_path): |
| 20 | + raise FileNotFoundError(f"Audio file not found: {file_path}") |
| 21 | + |
| 22 | + files = {"file": (os.path.basename(file_path), open(file_path, "rb"))} |
| 23 | + data = {} |
| 24 | + |
| 25 | + if model: |
| 26 | + data["model"] = model |
| 27 | + |
| 28 | + try: |
| 29 | + # Increase timeout for large files (5 minutes) |
| 30 | + response = self.session.post(f"{self.api_url}/separate", files=files, data=data, timeout=300) |
| 31 | + response.raise_for_status() |
| 32 | + return response.json() |
| 33 | + except requests.RequestException as e: |
| 34 | + self.logger.error(f"Separation request failed: {e}") |
| 35 | + raise |
| 36 | + finally: |
| 37 | + files["file"][1].close() |
| 38 | + |
| 39 | + def separate_audio_and_wait(self, file_path: str, model: Optional[str] = None, timeout: int = 600, poll_interval: int = 10, download: bool = True, output_dir: Optional[str] = None) -> dict: |
| 40 | + """ |
| 41 | + Submit audio separation job and wait for completion (convenience method). |
| 42 | +
|
| 43 | + This method handles the full workflow: submit job, poll for completion, |
| 44 | + and optionally download the result files. |
| 45 | +
|
| 46 | + Args: |
| 47 | + file_path: Path to the audio file to separate |
| 48 | + model: Model to use for separation (optional) |
| 49 | + timeout: Maximum time to wait for completion in seconds (default: 600) |
| 50 | + poll_interval: How often to check status in seconds (default: 10) |
| 51 | + download: Whether to automatically download result files (default: True) |
| 52 | + output_dir: Directory to save downloaded files (default: current directory) |
| 53 | +
|
| 54 | + Returns: |
| 55 | + dict with keys: |
| 56 | + - task_id: The job task ID |
| 57 | + - status: "completed" or "error" |
| 58 | + - files: List of output filenames |
| 59 | + - downloaded_files: List of local file paths (if download=True) |
| 60 | + - error: Error message (if status="error") |
| 61 | + """ |
| 62 | + import time |
| 63 | + |
| 64 | + # Submit the separation job |
| 65 | + self.logger.info(f"Submitting separation job for '{file_path}'...") |
| 66 | + result = self.separate_audio(file_path, model) |
| 67 | + task_id = result["task_id"] |
| 68 | + self.logger.info(f"Job submitted! Task ID: {task_id}") |
| 69 | + |
| 70 | + # Poll for completion |
| 71 | + self.logger.info("Waiting for separation to complete...") |
| 72 | + start_time = time.time() |
| 73 | + last_progress = -1 |
| 74 | + |
| 75 | + while time.time() - start_time < timeout: |
| 76 | + try: |
| 77 | + status = self.get_job_status(task_id) |
| 78 | + current_status = status.get("status", "unknown") |
| 79 | + |
| 80 | + # Show progress if it changed |
| 81 | + if "progress" in status and status["progress"] != last_progress: |
| 82 | + self.logger.info(f"Progress: {status['progress']}%") |
| 83 | + last_progress = status["progress"] |
| 84 | + |
| 85 | + # Check if completed |
| 86 | + if current_status == "completed": |
| 87 | + self.logger.info("✅ Separation completed!") |
| 88 | + |
| 89 | + result = {"task_id": task_id, "status": "completed", "files": status.get("files", [])} |
| 90 | + |
| 91 | + # Download files if requested |
| 92 | + if download: |
| 93 | + downloaded_files = [] |
| 94 | + self.logger.info(f"📥 Downloading {len(status.get('files', []))} output files...") |
| 95 | + |
| 96 | + for filename in status.get("files", []): |
| 97 | + try: |
| 98 | + if output_dir: |
| 99 | + output_path = f"{output_dir.rstrip('/')}/{filename}" |
| 100 | + else: |
| 101 | + output_path = filename |
| 102 | + |
| 103 | + downloaded_path = self.download_file(task_id, filename, output_path) |
| 104 | + downloaded_files.append(downloaded_path) |
| 105 | + self.logger.info(f" ✅ Downloaded: {downloaded_path}") |
| 106 | + except Exception as e: |
| 107 | + self.logger.error(f" ❌ Failed to download {filename}: {e}") |
| 108 | + |
| 109 | + result["downloaded_files"] = downloaded_files |
| 110 | + self.logger.info(f"🎉 Successfully downloaded {len(downloaded_files)} files!") |
| 111 | + |
| 112 | + return result |
| 113 | + |
| 114 | + elif current_status == "error": |
| 115 | + error_msg = status.get("error", "Unknown error") |
| 116 | + self.logger.error(f"❌ Job failed: {error_msg}") |
| 117 | + return {"task_id": task_id, "status": "error", "error": error_msg, "files": []} |
| 118 | + |
| 119 | + # Wait before next poll |
| 120 | + time.sleep(poll_interval) |
| 121 | + |
| 122 | + except Exception as e: |
| 123 | + self.logger.warning(f"Error polling status: {e}") |
| 124 | + time.sleep(poll_interval) |
| 125 | + |
| 126 | + # Timeout reached |
| 127 | + self.logger.error(f"❌ Job polling timed out after {timeout} seconds") |
| 128 | + return {"task_id": task_id, "status": "timeout", "error": f"Job polling timed out after {timeout} seconds", "files": []} |
| 129 | + |
| 130 | + def get_job_status(self, task_id: str) -> dict: |
| 131 | + """Get job status.""" |
| 132 | + try: |
| 133 | + response = self.session.get(f"{self.api_url}/status/{task_id}", timeout=10) |
| 134 | + response.raise_for_status() |
| 135 | + return response.json() |
| 136 | + except requests.RequestException as e: |
| 137 | + self.logger.error(f"Status request failed: {e}") |
| 138 | + raise |
| 139 | + |
| 140 | + def download_file(self, task_id: str, filename: str, output_path: Optional[str] = None) -> str: |
| 141 | + """Download a file from a completed job.""" |
| 142 | + if output_path is None: |
| 143 | + output_path = filename |
| 144 | + |
| 145 | + try: |
| 146 | + response = self.session.get(f"{self.api_url}/download/{task_id}/{filename}", timeout=60) |
| 147 | + response.raise_for_status() |
| 148 | + |
| 149 | + with open(output_path, "wb") as f: |
| 150 | + f.write(response.content) |
| 151 | + |
| 152 | + return output_path |
| 153 | + except requests.RequestException as e: |
| 154 | + self.logger.error(f"Download failed: {e}") |
| 155 | + raise |
| 156 | + |
| 157 | + def list_models(self, format_type: str = "pretty", filter_by: Optional[str] = None) -> dict: |
| 158 | + """List available models.""" |
| 159 | + try: |
| 160 | + if format_type == "json": |
| 161 | + response = self.session.get(f"{self.api_url}/models-json", timeout=10) |
| 162 | + else: |
| 163 | + url = f"{self.api_url}/models" |
| 164 | + if filter_by: |
| 165 | + url += f"?filter_sort_by={filter_by}" |
| 166 | + response = self.session.get(url, timeout=10) |
| 167 | + |
| 168 | + response.raise_for_status() |
| 169 | + |
| 170 | + if format_type == "json": |
| 171 | + return response.json() |
| 172 | + else: |
| 173 | + return {"text": response.text} |
| 174 | + except requests.RequestException as e: |
| 175 | + self.logger.error(f"Models request failed: {e}") |
| 176 | + raise |
| 177 | + |
| 178 | + def get_server_version(self) -> str: |
| 179 | + """Get the server version.""" |
| 180 | + try: |
| 181 | + response = self.session.get(f"{self.api_url}/health", timeout=10) |
| 182 | + response.raise_for_status() |
| 183 | + health_data = response.json() |
| 184 | + return health_data.get("version", "unknown") |
| 185 | + except requests.RequestException as e: |
| 186 | + self.logger.error(f"Health check request failed: {e}") |
| 187 | + raise |
0 commit comments