Skip to content

Commit ad21e97

Browse files
committed
Refactor architecture: Decouple infrastructure, fix RQ usage, and remove obsolete tests
1 parent 4ba9af1 commit ad21e97

5 files changed

Lines changed: 75 additions & 64 deletions

File tree

src/api/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from src.utils.logger import logger
77
from src.db.session import SessionLocal, get_db
88
from src.models.db import IngestionJob
9-
from src.worker import queue
9+
from src.infra.queue import get_queue
1010
import os
1111
from contextlib import asynccontextmanager
1212
from sqlalchemy.orm import Session
@@ -63,6 +63,7 @@ async def ingest_docs(db: Session = Depends(get_db)):
6363
db.refresh(job)
6464

6565
# Enqueue the background task
66+
queue = get_queue()
6667
queue.enqueue("src.jobs.ingestion.process_ingestion", str(job.id))
6768

6869
logger.info(f"Ingestion job {job.id} enqueued.")

src/infra/queue.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import redis
2+
from rq import Queue
3+
from src.config.settings import settings
4+
5+
# We don't initialize connection here directly to avoid side effects on import
6+
# instead we provide a way to get the queue when needed
7+
8+
def get_redis_connection():
9+
return redis.Redis(host=settings.redis_host, port=settings.redis_port)
10+
11+
def get_queue(name="default"):
12+
return Queue(name, connection=get_redis_connection())

src/worker.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
import redis
2-
from rq import Worker, Queue, Connection
3-
from src.config.settings import settings
1+
from rq import Worker
2+
from src.infra.queue import get_queue, get_redis_connection
43
from src.utils.logger import logger
54

6-
redis_conn = redis.Redis(host=settings.redis_host, port=settings.redis_port)
7-
queue = Queue("default", connection=redis_conn)
8-
95
def run_worker():
10-
logger.info(f"Starting worker connected to Redis at {settings.redis_host}:{settings.redis_port}")
11-
with Connection(redis_conn):
12-
worker = Worker([queue])
13-
worker.work()
6+
redis_conn = get_redis_connection()
7+
queue = get_queue()
8+
9+
logger.info(f"Starting worker connected to Redis")
10+
11+
# In recent RQ versions, connection is passed to the Worker constructor
12+
worker = Worker([queue], connection=redis_conn)
13+
worker.work()
1414

1515
if __name__ == "__main__":
1616
run_worker()

tests/test_api.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
11
import pytest
22
from fastapi.testclient import TestClient
33
from unittest.mock import MagicMock, AsyncMock
4+
import uuid
5+
6+
# Mock dependencies before importing app
7+
import src.infra.queue
8+
import src.db.session
9+
10+
# We mock the connection and session at the module level or within fixtures
411
from src.api.app import app
5-
import src.api.app as api_app
612

713
client = TestClient(app)
814

15+
@pytest.fixture
16+
def mock_db(mocker):
17+
mock_session = MagicMock()
18+
mocker.patch("src.db.session.SessionLocal", return_value=mock_session)
19+
# Mock Depends(get_db)
20+
app.dependency_overrides[src.db.session.get_db] = lambda: mock_session
21+
yield mock_session
22+
app.dependency_overrides.clear()
23+
24+
@pytest.fixture
25+
def mock_queue(mocker):
26+
mock_q = MagicMock()
27+
mocker.patch("src.api.app.get_queue", return_value=mock_q)
28+
return mock_q
29+
930
@pytest.fixture
1031
def mock_retrieval_wf(mocker):
1132
mock_wf = MagicMock()
@@ -23,7 +44,10 @@ def test_health_check():
2344
assert response.json()["status"] == "ok"
2445

2546
def test_query_endpoint_uninitialized():
26-
# Before initialization, it should return 503
47+
# Set global retrieval_wf to None for this test
48+
import src.api.app as api_app
49+
api_app.retrieval_wf = None
50+
2751
response = client.post("/query", json={"query": "test"})
2852
assert response.status_code == 503
2953
assert "not initialized" in response.json()["detail"]
@@ -51,3 +75,28 @@ def test_query_endpoint_error(mock_retrieval_wf, mocker):
5175

5276
assert response.status_code == 500
5377
assert "Retrieval failed" in response.json()["detail"]
78+
79+
def test_ingest_endpoint(mock_db, mock_queue):
80+
# Mock job creation
81+
job_id = uuid.uuid4()
82+
mock_db.add.side_effect = lambda job: setattr(job, 'id', job_id)
83+
84+
response = client.post("/ingest")
85+
86+
assert response.status_code == 202
87+
assert response.json()["job_id"] == str(job_id)
88+
mock_queue.enqueue.assert_called_once()
89+
90+
def test_get_job_status(mock_db):
91+
job_id = str(uuid.uuid4())
92+
mock_job = MagicMock()
93+
mock_job.id = job_id
94+
mock_job.status = "COMPLETED"
95+
96+
mock_db.query.return_value.filter.return_value.first.return_value = mock_job
97+
98+
response = client.get(f"/jobs/{job_id}")
99+
100+
assert response.status_code == 200
101+
assert response.json()["status"] == "COMPLETED"
102+
assert response.json()["id"] == job_id

tests/test_splitter.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

0 commit comments

Comments
 (0)