Skip to content

Commit 349ba98

Browse files
beveradb and claude authored
Make audio separator async with Firestore + GCS (#280)
* feat: add Firestore-backed job status store Introduces FirestoreJobStore to replace the in-memory job_status_store dict, enabling any Cloud Run instance to read/write job status for multi-instance GPU scaling. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Add GCS output store for cross-instance file serving Implements GCSOutputStore to upload separation results to GCS so any Cloud Run instance can serve download requests, replacing local disk storage. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Make /separate endpoint async with GPU semaphore and external stores Convert the /separate endpoint from synchronous (await) to fire-and-forget pattern so it returns immediately while separation runs in background. Replace in-memory job_status_store with Firestore reads/writes and local file downloads with GCS, enabling cross-instance status polling and file serving. Add GPU semaphore to serialize concurrent separation requests on a single instance. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * cleanup: remove dead job_status_store dict Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * feat: reduce POST timeout from 300s to 60s (server is now async) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * feat: add google-cloud-firestore to Dockerfile for async job store Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: lazy imports + add GCP test deps for CI compatibility Move google.cloud imports inside __init__ methods so modules can be imported without the packages installed. Add google-cloud-firestore and google-cloud-storage to dev dependencies so CI has them available. 
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: update poetry.lock for GCP test dependencies Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: skip deploy_cloudrun tests when uvicorn not installed The TestLazyInit tests import deploy_cloudrun.py which requires uvicorn and fastapi - server-only dependencies not in the test environment. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent db43a42 commit 349ba98

10 files changed

Lines changed: 1054 additions & 83 deletions

File tree

Dockerfile.cloudrun

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ RUN cd /tmp/audio-separator-src \
6262
"python-multipart>=0.0.6" \
6363
"filetype>=1.2.0" \
6464
"google-cloud-storage>=2.0.0" \
65+
"google-cloud-firestore>=2.0.0" \
6566
&& rm -rf /tmp/audio-separator-src
6667

6768
# Set up CUDA library paths

audio_separator/remote/api_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def separate_audio(
148148
data["custom_output_names"] = json.dumps(custom_output_names)
149149

150150
try:
151-
# Increase timeout for large files (5 minutes)
151+
# Server returns immediately with task_id; 60s is generous for submission
152152
# When using gcs_uri (no file upload), we still need multipart/form-data
153153
# encoding because FastAPI requires it for endpoints with File()/Form() params.
154154
# Passing a dummy empty file field forces requests to use multipart encoding.
@@ -158,7 +158,7 @@ def separate_audio(
158158
f"{self.api_url}/separate",
159159
files=files,
160160
data=data,
161-
timeout=300,
161+
timeout=60,
162162
)
163163
response.raise_for_status()
164164
return response.json()

audio_separator/remote/deploy_cloudrun.py

Lines changed: 85 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,40 @@
4949
MODEL_BUCKET = os.environ.get("MODEL_BUCKET", "")
5050
PORT = int(os.environ.get("PORT", "8080"))
5151

52-
# In-memory job status tracking (one instance handles one job at a time on Cloud Run GPU)
53-
job_status_store: dict[str, dict] = {}
52+
5453

5554
# Track model readiness
5655
models_ready = False
5756

57+
# --- Async job infrastructure ---
58+
gpu_semaphore = threading.Semaphore(1)
59+
60+
OUTPUT_BUCKET = os.environ.get("OUTPUT_BUCKET", "nomadkaraoke-audio-separator-outputs")
61+
GCP_PROJECT = os.environ.get("GCP_PROJECT", "nomadkaraoke")
62+
63+
_job_store = None
64+
_output_store = None
65+
66+
67+
def get_job_store():
    """Return the shared Firestore job store, constructing it on first call."""
    global _job_store
    if _job_store is not None:
        return _job_store

    # Imported on first use rather than at module load.
    from audio_separator.remote.job_store import FirestoreJobStore

    _job_store = FirestoreJobStore(project=GCP_PROJECT)
    return _job_store
75+
76+
77+
def get_output_store():
    """Return the shared GCS output store, constructing it on first call."""
    global _output_store
    if _output_store is not None:
        return _output_store

    # Imported on first use rather than at module load.
    from audio_separator.remote.output_store import GCSOutputStore

    _output_store = GCSOutputStore(bucket_name=OUTPUT_BUCKET, project=GCP_PROJECT)
    return _output_store
85+
5886

5987
def generate_file_hash(filename: str) -> str:
6088
"""Generate a short, stable hash for a filename to use in download URLs."""
@@ -188,19 +216,26 @@ def separate_audio_sync(
188216

189217
def update_status(status: str, progress: int = 0, error: str = None, files: dict = None):
190218
status_data = {
191-
"task_id": task_id,
192219
"status": status,
193220
"progress": progress,
194-
"original_filename": filename,
195221
"models_used": models_used,
196222
"total_models": len(models) if models else 1,
197223
"current_model_index": 0,
198-
"files": files or {},
199224
}
225+
if files is not None:
226+
status_data["files"] = files
200227
if error:
201228
status_data["error"] = error
202-
job_status_store[task_id] = status_data
203-
229+
try:
230+
get_job_store().update(task_id, status_data)
231+
except Exception as e:
232+
logger.warning(f"[{task_id}] Failed to update Firestore status: {e}")
233+
234+
# Wait for GPU availability
235+
update_status("queued", 0)
236+
logger.info(f"[{task_id}] Waiting for GPU semaphore...")
237+
gpu_semaphore.acquire()
238+
logger.info(f"[{task_id}] GPU semaphore acquired, starting separation")
204239
try:
205240
os.makedirs(f"{STORAGE_DIR}/outputs/{task_id}", exist_ok=True)
206241
output_dir = f"{STORAGE_DIR}/outputs/{task_id}"
@@ -329,6 +364,9 @@ def update_status(status: str, progress: int = 0, error: str = None, files: dict
329364
fname = os.path.basename(f)
330365
all_output_files[generate_file_hash(fname)] = fname
331366

367+
# Upload outputs to GCS for cross-instance access
368+
get_output_store().upload_task_outputs(task_id, output_dir)
369+
332370
update_status("completed", 100, files=all_output_files)
333371
logger.info(f"Separation completed. {len(all_output_files)} output files.")
334372
return {"task_id": task_id, "status": "completed", "files": all_output_files, "models_used": models_used}
@@ -338,13 +376,16 @@ def update_status(status: str, progress: int = 0, error: str = None, files: dict
338376
traceback.print_exc()
339377
update_status("error", 0, error=str(e))
340378

341-
# Clean up on error
379+
return {"task_id": task_id, "status": "error", "error": str(e), "models_used": models_used}
380+
381+
finally:
382+
gpu_semaphore.release()
383+
logger.info(f"[{task_id}] GPU semaphore released")
384+
# Clean up local files (outputs are in GCS now)
342385
output_dir = f"{STORAGE_DIR}/outputs/{task_id}"
343386
if os.path.exists(output_dir):
344387
shutil.rmtree(output_dir, ignore_errors=True)
345388

346-
return {"task_id": task_id, "status": "error", "error": str(e), "models_used": models_used}
347-
348389

349390
# --- FastAPI Application ---
350391

@@ -451,9 +492,10 @@ async def separate_audio(
451492
filename = file.filename
452493

453494
task_id = str(uuid.uuid4())
495+
instance_id = os.environ.get("K_REVISION", "local")
454496

455-
# Set initial status
456-
job_status_store[task_id] = {
497+
# Write initial status to Firestore
498+
get_job_store().set(task_id, {
457499
"task_id": task_id,
458500
"status": "submitted",
459501
"progress": 0,
@@ -462,12 +504,12 @@ async def separate_audio(
462504
"total_models": 1 if preset else (len(models_list) if models_list else 1),
463505
"current_model_index": 0,
464506
"files": {},
465-
}
507+
"instance_id": instance_id,
508+
})
466509

467-
# Run separation in a background thread to not block the event loop
468-
# but keep the request alive (Cloud Run keeps the instance warm)
510+
# Fire-and-forget: run separation in background thread
469511
loop = asyncio.get_event_loop()
470-
await loop.run_in_executor(
512+
loop.run_in_executor(
471513
None,
472514
lambda: separate_audio_sync(
473515
audio_data,
@@ -509,8 +551,15 @@ async def separate_audio(
509551
),
510552
)
511553

512-
# Return the final status (completed or error)
513-
return job_status_store.get(task_id, {"task_id": task_id, "status": "error", "error": "Job lost"})
554+
# Return immediately — client polls /status/{task_id}
555+
return {
556+
"task_id": task_id,
557+
"status": "submitted",
558+
"progress": 0,
559+
"original_filename": filename,
560+
"models_used": [f"preset:{preset}"] if preset else (models_list or ["default"]),
561+
"total_models": 1 if preset else (len(models_list) if models_list else 1),
562+
}
514563

515564
except HTTPException:
516565
raise
@@ -521,8 +570,9 @@ async def separate_audio(
521570
@web_app.get("/status/{task_id}")
522571
async def get_job_status(task_id: str) -> dict:
523572
"""Get the status of a separation job."""
524-
if task_id in job_status_store:
525-
return job_status_store[task_id]
573+
result = get_job_store().get(task_id)
574+
if result:
575+
return result
526576
return {
527577
"task_id": task_id,
528578
"status": "not_found",
@@ -535,32 +585,20 @@ async def get_job_status(task_id: str) -> dict:
535585
async def download_file(task_id: str, file_hash: str) -> Response:
536586
"""Download a separated audio file using its hash identifier."""
537587
try:
538-
# Look up filename from job status
539-
status_data = job_status_store.get(task_id)
588+
status_data = get_job_store().get(task_id)
540589
if not status_data:
541590
raise HTTPException(status_code=404, detail="Task not found")
542591

543592
files_dict = status_data.get("files", {})
544593

545-
# Handle both dict (hash→filename) and list (legacy) formats
546594
actual_filename = None
547595
if isinstance(files_dict, dict):
548596
actual_filename = files_dict.get(file_hash)
549-
elif isinstance(files_dict, list):
550-
for fname in files_dict:
551-
if generate_file_hash(fname) == file_hash:
552-
actual_filename = fname
553-
break
554597

555598
if not actual_filename:
556599
raise HTTPException(status_code=404, detail=f"File with hash {file_hash} not found")
557600

558-
file_path = f"{STORAGE_DIR}/outputs/{task_id}/{actual_filename}"
559-
if not os.path.exists(file_path):
560-
raise HTTPException(status_code=404, detail=f"File not found on disk: {actual_filename}")
561-
562-
with open(file_path, "rb") as f:
563-
file_data = f.read()
601+
file_data = get_output_store().get_file_bytes(task_id, actual_filename)
564602

565603
detected_type = filetype.guess(file_data)
566604
content_type = detected_type.mime if detected_type and detected_type.mime else "application/octet-stream"
@@ -667,9 +705,20 @@ async def root() -> dict:
667705

668706
@web_app.on_event("startup")
669707
async def startup_event():
670-
"""Download models from GCS on startup."""
708+
"""Clean up local storage and download models from GCS on startup."""
671709
os.makedirs(MODEL_DIR, exist_ok=True)
672-
os.makedirs(f"{STORAGE_DIR}/outputs", exist_ok=True)
710+
711+
# Wipe local outputs from previous instance
712+
outputs_dir = f"{STORAGE_DIR}/outputs"
713+
if os.path.exists(outputs_dir):
714+
shutil.rmtree(outputs_dir, ignore_errors=True)
715+
os.makedirs(outputs_dir, exist_ok=True)
716+
717+
# Clean up old Firestore jobs (>1 hour)
718+
try:
719+
get_job_store().cleanup_old_jobs(max_age_seconds=3600)
720+
except Exception as e:
721+
logger.warning(f"Failed to clean up old jobs: {e}")
673722

674723
# Download models in background thread to not block startup probe
675724
thread = threading.Thread(target=download_models_from_gcs, daemon=True)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Firestore-backed job status store for audio separation jobs.
2+
3+
Replaces the in-memory dict so any Cloud Run instance can read/write job status.
4+
"""
5+
import logging
6+
import time
7+
from typing import Optional
8+
9+
logger = logging.getLogger("audio-separator-api")
10+
11+
COLLECTION = "audio_separation_jobs"
12+
13+
14+
class FirestoreJobStore:
    """Job status store backed by Firestore.

    Provides a dict-like get/set interface for job status documents. Each job
    is one document, keyed by task id, in the ``COLLECTION`` collection.
    """

    def __init__(self, project: str = "nomadkaraoke"):
        """Create a Firestore client for *project* and bind the jobs collection."""
        # Imported here (not at module level) so this module can be imported
        # without google-cloud-firestore installed.
        from google.cloud import firestore

        self._firestore = firestore
        self._db = firestore.Client(project=project)
        self._collection = self._db.collection(COLLECTION)

    def set(self, task_id: str, data: dict) -> None:
        """Create or overwrite a job status document.

        ``updated_at`` is always stamped with the server timestamp;
        ``created_at`` is stamped only when *data* does not already carry one.
        The caller's dict is not mutated.
        """
        stamp = self._firestore.SERVER_TIMESTAMP
        document = {**data, "updated_at": stamp}
        document.setdefault("created_at", stamp)
        self._collection.document(task_id).set(document)

    def get(self, task_id: str) -> Optional[dict]:
        """Get job status. Returns None if not found."""
        snapshot = self._collection.document(task_id).get()
        return snapshot.to_dict() if snapshot.exists else None

    def update(self, task_id: str, fields: dict) -> None:
        """Merge fields into an existing document and refresh ``updated_at``.

        NOTE(review): the Firestore client's ``update()`` is expected to raise
        when the document does not exist -- callers should ``set()`` first.
        """
        fields = {**fields, "updated_at": self._firestore.SERVER_TIMESTAMP}
        self._collection.document(task_id).update(fields)

    def delete(self, task_id: str) -> None:
        """Delete a job status document."""
        self._collection.document(task_id).delete()

    def __contains__(self, task_id: str) -> bool:
        """Check if a task exists."""
        return self._collection.document(task_id).get().exists

    def cleanup_old_jobs(
        self,
        max_age_seconds: int = 3600,
        statuses: tuple[str, ...] = ("completed", "error"),
    ) -> int:
        """Delete jobs in one of *statuses* not updated for *max_age_seconds*.

        Args:
            max_age_seconds: Minimum age, measured from ``updated_at``, for a
                job to be eligible for deletion.
            statuses: Job ``status`` values eligible for deletion. Defaults to
                the finished states ("completed", "error"); pass additional
                values to also reap jobs stuck in other states.

        Returns:
            The number of documents deleted.
        """
        from datetime import datetime, timedelta, timezone

        wanted = list(statuses)
        if not wanted:
            # Nothing is eligible; also avoids sending Firestore an "in"
            # filter with an empty value list, which it does not accept.
            return 0

        # Timezone-aware cutoff so it compares correctly with server timestamps.
        cutoff_dt = datetime.now(timezone.utc) - timedelta(seconds=max_age_seconds)

        query = (
            self._collection
            .where("status", "in", wanted)
            .where("updated_at", "<", cutoff_dt)
        )

        deleted = 0
        for doc in query.stream():
            doc.reference.delete()
            deleted += 1

        if deleted:
            logger.info(f"Cleaned up {deleted} old job(s) from Firestore")
        return deleted
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""GCS-backed output file store for audio separation results.
2+
3+
Uploads separation output files to GCS so any Cloud Run instance can serve downloads.
4+
"""
5+
import logging
6+
import os
7+
8+
logger = logging.getLogger("audio-separator-api")
9+
10+
11+
class GCSOutputStore:
    """Manages separation output files in GCS.

    Every object is stored under the key ``{task_id}/{filename}`` in the
    configured bucket.
    """

    def __init__(self, bucket_name: str = "nomadkaraoke-audio-separator-outputs", project: str = "nomadkaraoke"):
        """Create a storage client for *project* and bind *bucket_name*."""
        # Imported here (not at module level) so this module can be imported
        # without google-cloud-storage installed.
        from google.cloud import storage

        self._client = storage.Client(project=project)
        self._bucket = self._client.bucket(bucket_name)

    def _blob_for(self, task_id: str, filename: str):
        """Return the bucket blob handle for ``{task_id}/{filename}``."""
        return self._bucket.blob(f"{task_id}/{filename}")

    def upload_task_outputs(self, task_id: str, local_dir: str) -> list[str]:
        """Upload all files in local_dir to GCS under {task_id}/ prefix.

        Only regular files directly inside *local_dir* are uploaded;
        subdirectories are skipped.

        Returns list of uploaded filenames, in sorted order.
        """
        uploaded = []
        # Sorted so the upload order and returned list are deterministic
        # (os.listdir order is arbitrary).
        for filename in sorted(os.listdir(local_dir)):
            local_path = os.path.join(local_dir, filename)
            if not os.path.isfile(local_path):
                continue
            gcs_path = f"{task_id}/{filename}"
            self._bucket.blob(gcs_path).upload_from_filename(local_path)
            uploaded.append(filename)
            logger.info(f"Uploaded {filename} to gs://{self._bucket.name}/{gcs_path}")
        return uploaded

    def get_file_bytes(self, task_id: str, filename: str) -> bytes:
        """Download file content as bytes (for HTTP responses)."""
        return self._blob_for(task_id, filename).download_as_bytes()

    def download_file(self, task_id: str, filename: str, local_path: str) -> str:
        """Download a file from GCS to a local path.

        Missing parent directories of *local_path* are created first so the
        download does not fail on a fresh output location.

        Returns *local_path*.
        """
        parent_dir = os.path.dirname(local_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        self._blob_for(task_id, filename).download_to_filename(local_path)
        return local_path

    def delete_task_outputs(self, task_id: str) -> int:
        """Delete all output files for a task. Returns count deleted."""
        deleted = 0
        for blob in self._bucket.list_blobs(prefix=f"{task_id}/"):
            blob.delete()
            deleted += 1
        if deleted:
            logger.info(f"Deleted {deleted} output file(s) for task {task_id}")
        return deleted

0 commit comments

Comments
 (0)