Skip to content

Commit eeb7964

Browse files
committed
Tests: better automation and filtering for cloud tests
1 parent 7731420 commit eeb7964

5 files changed

Lines changed: 151 additions & 94 deletions

File tree

tests/conftest.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
import asyncio
2+
import aiohttp
13
import sys
4+
import psutil
25
import pytest
6+
import os
37
import shutil
48
import subprocess
59
import dotenv
610
from pathlib import Path
11+
from typing import Any
712
from PyQt5.QtCore import QCoreApplication
813

914
sys.path.append(str(Path(__file__).parent.parent))
@@ -15,7 +20,7 @@
1520

1621
def pytest_addoption(parser):
1722
parser.addoption("--test-install", action="store_true")
18-
parser.addoption("--pod-process", action="store_true")
23+
parser.addoption("--cloud", action="store_true")
1924
parser.addoption("--ci", action="store_true")
2025
parser.addoption("--benchmark", action="store_true")
2126

@@ -95,3 +100,110 @@ def local_download_server():
95100

96101
if has_local_cloud:
97102
dotenv.load_dotenv(root_dir / "service" / "web" / ".env.local")
103+
104+
105+
class CloudService:
106+
def __init__(self, loop: QtTestApp, enabled=True):
107+
self.loop = loop
108+
self.dir = root_dir / "service"
109+
self.log_dir = result_dir / "logs"
110+
self.log_dir.mkdir(exist_ok=True)
111+
self.url = os.environ["TEST_SERVICE_URL"]
112+
self.coord_proc: asyncio.subprocess.Process | None = None
113+
self.coord_log = None
114+
self.worker_proc: asyncio.subprocess.Process | None = None
115+
self.worker_task: asyncio.Task | None = None
116+
self.worker_log = None
117+
self.enabled = enabled
118+
119+
async def serve(self, process: asyncio.subprocess.Process, log_file):
120+
try:
121+
async for line in util.ensure(process.stdout):
122+
print(line.decode("utf-8"), end="", file=log_file, flush=True)
123+
except asyncio.CancelledError:
124+
pass
125+
126+
async def launch_coordinator(self):
127+
assert self.coord_proc is None, "Coordinator already running"
128+
self.coord_log = open(self.log_dir / "api.log", "w", encoding="utf-8")
129+
npm = shutil.which("npm")
130+
assert npm is not None, "npm not found in PATH"
131+
args = [npm, "run", "dev"]
132+
self.coord_proc = await asyncio.create_subprocess_exec(
133+
*args,
134+
cwd=self.dir / "api",
135+
stdout=self.coord_log,
136+
stderr=asyncio.subprocess.STDOUT,
137+
)
138+
139+
async def launch_worker(self):
140+
assert self.worker_proc is None, "Worker already running"
141+
self.worker_log = open(self.log_dir / "worker.log", "w", encoding="utf-8")
142+
workerpy = str(self.dir / "pod" / "worker.py")
143+
config = str(self.dir / "pod" / "_var" / "worker.json")
144+
args = ["-u", "-Xutf8", workerpy, config]
145+
self.worker_proc = await asyncio.create_subprocess_exec(
146+
sys.executable,
147+
*args,
148+
cwd=self.dir / "pod",
149+
stdout=subprocess.PIPE,
150+
stderr=subprocess.STDOUT,
151+
)
152+
assert self.worker_proc.stdout is not None
153+
async for line in self.worker_proc.stdout:
154+
text = line.decode("utf-8")
155+
print(text[:80], end="", file=self.worker_log, flush=True)
156+
if "Uvicorn running" in text:
157+
break
158+
159+
self.worker_task = asyncio.create_task(self.serve(self.worker_proc, self.worker_log))
160+
161+
async def start(self):
162+
if not self.enabled or not has_local_cloud:
163+
return
164+
try:
165+
await self.launch_coordinator()
166+
await self.launch_worker()
167+
except Exception as e:
168+
await self.stop()
169+
raise e
170+
171+
async def stop(self):
172+
if self.worker_task:
173+
self.worker_task.cancel()
174+
await self.worker_task
175+
if self.worker_proc:
176+
self.worker_proc.terminate()
177+
await self.worker_proc.wait()
178+
if self.coord_proc:
179+
children = psutil.Process(self.coord_proc.pid).children(recursive=True)
180+
for child in children:
181+
child.terminate()
182+
self.coord_proc.terminate()
183+
await self.coord_proc.wait()
184+
185+
async def create_user(self, username: str) -> dict[str, Any]:
186+
assert self.enabled, "Cloud service is not enabled"
187+
async with aiohttp.ClientSession() as session:
188+
async with session.post(
189+
f"{self.url}/admin/user/create",
190+
json={"name": username},
191+
) as response:
192+
response.raise_for_status()
193+
result = await response.json()
194+
if "error" in result:
195+
raise Exception(result["error"])
196+
return result
197+
198+
def __enter__(self):
199+
self.loop.run(self.start())
200+
return self
201+
202+
def __exit__(self, exc_type, exc, tb):
203+
self.loop.run(self.stop())
204+
205+
206+
@pytest.fixture(scope="session")
207+
def cloud_service(qtapp, pytestconfig):
208+
with CloudService(qtapp, pytestconfig.getoption("--cloud")) as service:
209+
yield service

tests/test_image_transfer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from base64 import b64decode
22
from PIL import Image
33
from datetime import datetime
4-
import os
54
import pytest
65

76
from ai_diffusion.image import ImageCollection, Image as ImageWrapper
87
from ai_diffusion.cloud_client import CloudClient
8+
from tests.conftest import CloudService
99
from .config import root_dir, test_dir
1010

1111
if (root_dir / "service" / "pod" / "lib").exists():
@@ -56,16 +56,18 @@ async def main():
5656
qtapp.run(main())
5757

5858
@pytest.mark.parametrize("mode", ["b64", "transfer"])
59-
def test_receive(qtapp, mode: str):
59+
def test_receive(qtapp, cloud_service: CloudService, mode: str):
60+
if not cloud_service.enabled:
61+
pytest.skip("Cloud service not running")
62+
6063
max_b64_size = max_b64_size_config[mode]
6164
images = [ImageWrapper.load(test_dir / "images" / f) for f in ("cat.webp", "pegonia.webp")]
6265
bytes, offsets = ImageCollection(images).to_bytes()
6366
input = {"image_data": {"bytes": bytes, "offsets": offsets}}
6467

6568
async def main():
66-
url = os.environ["TEST_SERVICE_URL"]
67-
token = os.environ.get("TEST_SERVICE_TOKEN", "")
68-
client = await CloudClient.connect(url, token)
69+
user = await cloud_service.create_user("image-transfer-test")
70+
client = await CloudClient.connect(cloud_service.url, user["token"])
6971
await client.send_images(input, max_inline_size=max_b64_size)
7072

7173
if mode == "transfer":

tests/test_service.py

Lines changed: 11 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
from pathlib import Path
22
import pytest
3-
import subprocess
4-
import os
5-
import sys
6-
import asyncio
7-
import dotenv
83

94
from ai_diffusion.api import WorkflowInput, WorkflowKind, ControlInput, ImageInput, CheckpointInput
105
from ai_diffusion.api import SamplingInput, ConditioningInput, ExtentInput, RegionInput
@@ -13,52 +8,8 @@
138
from ai_diffusion.image import Extent, Image, Bounds
149
from ai_diffusion.resources import ControlMode, Arch
1510
from ai_diffusion.util import ensure
16-
from .conftest import has_local_cloud
17-
from .config import root_dir, test_dir, result_dir
18-
19-
pod_main = root_dir / "service" / "pod" / "pod.py"
20-
run_dir = test_dir / "pod"
21-
22-
23-
@pytest.fixture(scope="module")
24-
def pod_server(qtapp, pytestconfig):
25-
async def serve(process: asyncio.subprocess.Process):
26-
try:
27-
async for line in ensure(process.stdout):
28-
print(line.decode("utf-8"), end="")
29-
except asyncio.CancelledError:
30-
process.terminate()
31-
await process.wait()
32-
33-
async def start():
34-
env = os.environ.copy()
35-
args = ["-u", "-Xutf8", str(pod_main), "--rp_serve_api"]
36-
process = await asyncio.create_subprocess_exec(
37-
sys.executable,
38-
*args,
39-
env=env,
40-
stdout=subprocess.PIPE,
41-
stderr=subprocess.STDOUT,
42-
)
43-
async for line in ensure(process.stdout):
44-
text = line.decode("utf-8")
45-
print(text[:80], end="")
46-
if "Uvicorn running" in text:
47-
break
48-
49-
return process, asyncio.create_task(serve(process))
50-
51-
async def stop(process, task):
52-
process.terminate()
53-
task.cancel()
54-
await process.communicate()
55-
56-
if not pytestconfig.getoption("--pod-process") or pytestconfig.getoption("--ci"):
57-
yield None # For using local docker image or deployed serverless endpoint
58-
else:
59-
process, task = qtapp.run(start())
60-
yield process
61-
qtapp.run(stop(process, task))
11+
from .conftest import CloudService
12+
from .config import test_dir, result_dir
6213

6314

6415
async def receive_images(client: Client, work: WorkflowInput):
@@ -74,16 +25,18 @@ async def receive_images(client: Client, work: WorkflowInput):
7425
assert False, "Connection closed without receiving images"
7526

7627

28+
async def connect_cloud(service: CloudService):
29+
user = await service.create_user("workflow-tester")
30+
return await CloudClient.connect(service.url, user["token"])
31+
32+
7733
@pytest.fixture()
78-
def cloud_client(pytestconfig, qtapp, pod_server):
34+
def cloud_client(pytestconfig, qtapp, cloud_service: CloudService):
7935
if pytestconfig.getoption("--ci"):
8036
pytest.skip("Diffusion is disabled on CI")
81-
if not has_local_cloud:
82-
pytest.skip("Local cloud service not found")
83-
dotenv.load_dotenv(root_dir / "service" / "web" / ".env.local")
84-
url = os.environ["TEST_SERVICE_URL"]
85-
token = os.environ["TEST_SERVICE_TOKEN"]
86-
return qtapp.run(CloudClient.connect(url, token))
37+
if not cloud_service.enabled:
38+
pytest.skip("Cloud service not running")
39+
return qtapp.run(connect_cloud(cloud_service))
8740

8841

8942
def run_and_save(

tests/test_updates.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ai_diffusion.platform_tools import ZipFile
88
from ai_diffusion.updates import AutoUpdate, UpdateState
9-
from .conftest import has_local_cloud
9+
from .conftest import CloudService
1010

1111

1212
class SignalObserver:
@@ -27,15 +27,14 @@ def http_session(service_url: str):
2727
return ClientSession(service_url, headers=headers)
2828

2929

30-
def test_auto_update(qtapp, tmp_path: Path):
31-
if not has_local_cloud:
32-
pytest.skip("No local cloud service found")
33-
qtapp.run(run_auto_update_test(tmp_path))
30+
def test_auto_update(qtapp, cloud_service: CloudService, tmp_path: Path):
31+
if not cloud_service.enabled:
32+
pytest.skip("Cloud service not running")
33+
qtapp.run(run_auto_update_test(cloud_service, tmp_path))
3434

3535

36-
async def run_auto_update_test(tmp_path: Path):
37-
service_url = os.environ["TEST_SERVICE_URL"]
38-
async with http_session(service_url) as session:
36+
async def run_auto_update_test(service: CloudService, tmp_path: Path):
37+
async with http_session(service.url) as session:
3938
last_version = new_version = "666.6.6"
4039

4140
# Get the latest plugin version (set from previous test)
@@ -56,7 +55,7 @@ async def run_auto_update_test(tmp_path: Path):
5655
updater = AutoUpdate(
5756
current_version=last_version,
5857
plugin_dir=install_dir,
59-
api_url=service_url,
58+
api_url=service.url,
6059
)
6160
assert updater.state is UpdateState.unknown
6261

@@ -105,11 +104,10 @@ async def run_auto_update_test(tmp_path: Path):
105104
assert install_test_file.read_text() == "if you're feeling orange, try flying a kite"
106105

107106

108-
async def test_authorization():
109-
if not has_local_cloud:
110-
pytest.skip("No local cloud service found")
111-
service_url = os.environ["TEST_SERVICE_URL"]
112-
async with ClientSession(service_url) as session:
107+
async def test_authorization(cloud_service: CloudService):
108+
if not cloud_service.enabled:
109+
pytest.skip("Cloud service not running")
110+
async with ClientSession(cloud_service.url) as session:
113111
# Version check is public
114112
async with session.get("/plugin/latest?version=1.2.3") as response:
115113
assert response.status == 200

tests/test_workflow.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import itertools
22
import pytest
3-
import dotenv
43
import json
5-
import os
64
from datetime import datetime
75
from pathlib import Path
86
from typing import Any
@@ -25,6 +23,7 @@
2523
from ai_diffusion.workflow import detect_inpaint
2624
from . import config
2725
from .config import root_dir, test_dir, image_dir, result_dir, reference_dir, default_checkpoint
26+
from .conftest import CloudService
2827

2928
service_available = (root_dir / "service" / "web" / ".env.local").exists()
3029
client_params = ["local", "cloud"] if service_available else ["local"]
@@ -38,29 +37,22 @@ async def connect_local():
3837
return client
3938

4039

41-
async def connect_cloud():
42-
dotenv.load_dotenv(root_dir / "service" / "web" / ".env.local")
43-
url = os.environ["TEST_SERVICE_URL"]
44-
token = os.environ.get("TEST_SERVICE_TOKEN", "")
45-
if not token:
46-
client = CloudClient(url)
47-
sign_in = client.sign_in()
48-
auth_url = await anext(sign_in)
49-
print("\nSign-in required:", auth_url)
50-
token = await anext(sign_in)
51-
print("\nToken received:", token, "\n")
52-
return await CloudClient.connect(url, token)
40+
async def connect_cloud(service: CloudService):
41+
user = await service.create_user("workflow-tester")
42+
return await CloudClient.connect(service.url, user["token"])
5343

5444

5545
@pytest.fixture(params=client_params)
56-
def client(pytestconfig, request, qtapp):
46+
def client(pytestconfig, request, qtapp, cloud_service: CloudService):
5747
if pytestconfig.getoption("--ci"):
5848
pytest.skip("Diffusion is disabled on CI")
5949

6050
if request.param == "local":
6151
client = qtapp.run(connect_local())
6252
else:
63-
client = qtapp.run(connect_cloud())
53+
if not cloud_service.enabled:
54+
pytest.skip("Cloud service not running")
55+
client = qtapp.run(connect_cloud(cloud_service))
6456
files.loras.update([File.remote(m) for m in client.models.loras], FileSource.remote)
6557

6658
yield client

0 commit comments

Comments
 (0)