import argparse
import json
import os
import statistics
import time
import uuid
from typing import Any, Dict, List

import requests

try:
    import boto3  # pyright: ignore[reportMissingImports]
    from botocore.config import Config as BotoConfig  # pyright: ignore[reportMissingImports]
except ImportError:  # pragma: no cover - runtime dependency check
    # Defer the hard failure to build_s3_client() so the CLI can still render
    # --help and argument errors without boto3/botocore installed.
    boto3 = None
    BotoConfig = None


def percentile(values: List[float], p: float) -> float:
    """Linearly interpolated percentile of *values* at fraction ``p`` (0..1).

    Returns 0.0 for an empty list; a single-element list returns that element.
    """
    if not values:
        return 0.0
    if len(values) == 1:
        return values[0]
    ordered = sorted(values)
    position = (len(ordered) - 1) * p
    lower_idx = int(position)
    # Clamp so p == 1.0 does not index past the end.
    upper_idx = min(lower_idx + 1, len(ordered) - 1)
    frac = position - lower_idx
    return ordered[lower_idx] * (1 - frac) + ordered[upper_idx] * frac


def build_s3_client(args: argparse.Namespace):
    """Create a boto3 S3 client from CLI args, falling back to env vars.

    Credentials are only passed explicitly when both key halves are present;
    otherwise boto3's default credential chain applies.

    Raises:
        RuntimeError: boto3/botocore is not importable.
        ValueError: unsupported addressing style.
    """
    if boto3 is None or BotoConfig is None:
        raise RuntimeError("boto3/botocore is required. Please install boto3 or aioboto3.")

    addressing_style = (args.s3_addressing_style or os.getenv("S3_ADDRESSING_STYLE", "auto")).strip().lower()
    if addressing_style not in {"auto", "path", "virtual"}:
        raise ValueError("--s3_addressing_style must be one of: auto, path, virtual")
    signature_version = (args.s3_signature_version or os.getenv("S3_SIGNATURE_VERSION", "s3v4")).strip()

    client_kwargs: Dict[str, Any] = {
        "service_name": "s3",
        "region_name": args.s3_region or os.getenv("AWS_DEFAULT_REGION", "us-east-1"),
        "config": BotoConfig(
            signature_version=signature_version,
            s3={"addressing_style": addressing_style},
        ),
    }
    if args.s3_endpoint_url:
        client_kwargs["endpoint_url"] = args.s3_endpoint_url

    access_key = args.s3_access_key or os.getenv("AWS_ACCESS_KEY_ID")
    secret_key = args.s3_secret_key or os.getenv("AWS_SECRET_ACCESS_KEY")
    if access_key and secret_key:
        client_kwargs["aws_access_key_id"] = access_key
        client_kwargs["aws_secret_access_key"] = secret_key
        session_token = args.s3_session_token or os.getenv("AWS_SESSION_TOKEN")
        if session_token:
            client_kwargs["aws_session_token"] = session_token
    return boto3.client(**client_kwargs)


def generate_presigned_pair(s3_client, bucket: str, object_key: str, expires_in: int) -> tuple[str, str]:
    """Return a (PUT url, GET url) presigned pair for the same bucket/key."""

    def _presign(client_method: str, http_method: str) -> str:
        return s3_client.generate_presigned_url(
            ClientMethod=client_method,
            Params={"Bucket": bucket, "Key": object_key},
            ExpiresIn=expires_in,
            HttpMethod=http_method,
        )

    return _presign("put_object", "PUT"), _presign("get_object", "GET")


def build_sync_payload(args: argparse.Namespace, presigned_url: str = "") -> Dict[str, Any]:
    """Assemble the JSON body for the /v1/tasks/image/sync endpoint.

    ``target_shape`` and ``presigned_url`` are included only when truthy.
    """
    payload: Dict[str, Any] = dict(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        infer_steps=args.infer_steps,
        seed=args.seed,
        aspect_ratio=args.aspect_ratio,
        save_result_path=args.save_result_path,
        use_prompt_enhancer=args.use_prompt_enhancer,
    )
    if args.target_shape:
        payload["target_shape"] = args.target_shape
    if presigned_url:
        payload["presigned_url"] = presigned_url
    return payload


def call_sync(base_url: str, payload: Dict[str, Any], timeout_seconds: int, poll_interval_seconds: float) -> requests.Response:
    """POST the payload to the sync endpoint; HTTP timeout gets 30s of slack over the server-side timeout."""
    root = base_url.rstrip("/")
    endpoint = f"{root}/v1/tasks/image/sync?timeout_seconds={timeout_seconds}&poll_interval_seconds={poll_interval_seconds}"
    return requests.post(endpoint, json=payload, timeout=timeout_seconds + 30)


def run_client_upload_flow(
    args: argparse.Namespace,
    s3_client,
    bucket: str,
    object_key: str,
) -> Dict[str, float]:
    """One round where THIS script receives the image and uploads it to S3.

    Returns per-phase latencies in milliseconds plus the image size in bytes.
    Raises RuntimeError on any non-success HTTP status.
    """
    put_url, get_url = generate_presigned_pair(s3_client, bucket, object_key, args.presign_expires)

    # Empty presigned_url -> the server streams the image bytes back to us.
    sync_payload = build_sync_payload(args, presigned_url="")
    sync_start = time.perf_counter()
    sync_resp = call_sync(args.url, sync_payload, args.timeout_seconds, args.poll_interval_seconds)
    sync_end = time.perf_counter()
    if sync_resp.status_code != 200:
        raise RuntimeError(f"[client_upload] sync failed ({sync_resp.status_code}): {sync_resp.text}")

    image_bytes = sync_resp.content
    upload_start = time.perf_counter()
    upload_resp = requests.put(put_url, data=image_bytes, timeout=args.upload_timeout_seconds)
    upload_end = time.perf_counter()
    if upload_resp.status_code not in (200, 201, 204):
        raise RuntimeError(f"[client_upload] upload failed ({upload_resp.status_code}): {upload_resp.text}")

    if args.verify_download:
        check_resp = requests.get(get_url, timeout=args.download_timeout_seconds)
        if check_resp.status_code != 200:
            raise RuntimeError(f"[client_upload] download verify failed ({check_resp.status_code}): {check_resp.text}")

    return {
        "sync_ms": (sync_end - sync_start) * 1000.0,
        "upload_ms": (upload_end - upload_start) * 1000.0,
        "total_ms": (upload_end - sync_start) * 1000.0,
        "bytes": float(len(image_bytes)),
    }


def run_server_upload_flow(
    args: argparse.Namespace,
    s3_client,
    bucket: str,
    object_key: str,
) -> Dict[str, float]:
    """One round where the x2v SERVER uploads the image via the presigned PUT URL.

    Returns the end-to-end sync latency in milliseconds.
    Raises RuntimeError when the call fails or the server did not upload.
    """
    put_url, get_url = generate_presigned_pair(s3_client, bucket, object_key, args.presign_expires)
    sync_payload = build_sync_payload(args, presigned_url=put_url)

    started = time.perf_counter()
    sync_resp = call_sync(args.url, sync_payload, args.timeout_seconds, args.poll_interval_seconds)
    finished = time.perf_counter()
    if sync_resp.status_code != 200:
        raise RuntimeError(f"[server_upload] sync failed ({sync_resp.status_code}): {sync_resp.text}")

    # When the server uploads on our behalf it must answer with a JSON status
    # body instead of raw image bytes.
    content_type = sync_resp.headers.get("content-type", "")
    if "application/json" not in content_type:
        raise RuntimeError(f"[server_upload] expected JSON but got content-type={content_type!r}")
    body = sync_resp.json()
    if not body.get("uploaded_to_presigned_url"):
        raise RuntimeError(f"[server_upload] upload flag is false, response={json.dumps(body, ensure_ascii=False)}")

    if args.verify_download:
        check_resp = requests.get(get_url, timeout=args.download_timeout_seconds)
        if check_resp.status_code != 200:
            raise RuntimeError(f"[server_upload] download verify failed ({check_resp.status_code}): {check_resp.text}")

    return {
        "sync_total_ms": (finished - started) * 1000.0,
    }


def print_summary(label: str, values: List[float]) -> None:
    """Print avg/p50/p90/min/max (all in ms) for *values* on one line."""
    parts = [
        f"avg={statistics.mean(values):.2f} ms",
        f"p50={percentile(values, 0.5):.2f} ms",
        f"p90={percentile(values, 0.9):.2f} ms",
        f"min={min(values):.2f} ms",
        f"max={max(values):.2f} ms",
    ]
    print(f"{label}: " + ", ".join(parts))


def main() -> None:
    """CLI entry point: benchmark client-side vs server-side S3 upload latency.

    Runs warmup + measured rounds; each round executes both flows (order
    controlled by --order) against fresh S3 object keys, then prints
    per-round deltas and an aggregate summary.
    """
    parser = argparse.ArgumentParser(description="Benchmark sync latency: client upload to S3 vs x2v server upload to S3.")
    parser.add_argument("--url", type=str, default="http://127.0.0.1:8000", help="x2v server base url")
    parser.add_argument("--runs", type=int, default=5, help="Benchmark rounds")
    parser.add_argument("--warmup_runs", type=int, default=0, help="Warmup rounds before measurement")
    parser.add_argument("--order", type=str, default="alternate", choices=["alternate", "client_first", "server_first"])

    # Generation parameters forwarded verbatim to the x2v server.
    parser.add_argument("--prompt", type=str, required=True, help="Prompt text")
    parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt text")
    parser.add_argument("--infer_steps", type=int, default=30, help="Inference steps")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--aspect_ratio", type=str, default="16:9", help="Aspect ratio")
    parser.add_argument("--target_shape", type=int, nargs="+", default=None, help="Target shape, e.g. 1536 2752")
    parser.add_argument("--save_result_path", type=str, default="", help="Server-side save_result_path")
    parser.add_argument("--use_prompt_enhancer", action="store_true")
    parser.add_argument("--timeout_seconds", type=int, default=600)
    parser.add_argument("--poll_interval_seconds", type=float, default=0.5)

    # S3 connection settings; empty strings fall back to environment
    # variables inside build_s3_client()/below.
    parser.add_argument("--s3_endpoint_url", type=str, default="")
    parser.add_argument("--s3_region", type=str, default="")
    parser.add_argument("--s3_bucket", type=str, default="")
    parser.add_argument("--s3_access_key", type=str, default="")
    parser.add_argument("--s3_secret_key", type=str, default="")
    parser.add_argument("--s3_session_token", type=str, default="")
    parser.add_argument("--s3_addressing_style", type=str, default="")
    parser.add_argument("--s3_signature_version", type=str, default="")
    parser.add_argument("--presign_expires", type=int, default=3600)
    parser.add_argument(
        "--client_key_prefix",
        type=str,
        default="",
        help="S3 key prefix for client-upload flow, defaults to $S3_BASE_PATH/benchmark/client",
    )
    parser.add_argument(
        "--server_key_prefix",
        type=str,
        default="",
        help="S3 key prefix for server-upload flow, defaults to $S3_BASE_PATH/benchmark/server",
    )
    parser.add_argument("--upload_timeout_seconds", type=int, default=120)
    parser.add_argument("--download_timeout_seconds", type=int, default=120)
    parser.add_argument("--verify_download", action="store_true", help="Verify uploaded object with GET presigned URL")
    args = parser.parse_args()

    # Basic argument sanity checks before touching the network.
    if args.runs <= 0:
        raise ValueError("--runs must be > 0")
    if args.warmup_runs < 0:
        raise ValueError("--warmup_runs must be >= 0")
    if args.presign_expires <= 0:
        raise ValueError("--presign_expires must be > 0")
    if args.target_shape is not None and len(args.target_shape) < 2:
        raise ValueError("--target_shape must provide at least 2 integers")

    bucket = args.s3_bucket or os.getenv("S3_BUCKET", "")
    if not bucket:
        raise ValueError("Missing S3 bucket. Set --s3_bucket or env S3_BUCKET.")
    base_path = os.getenv("S3_BASE_PATH", "lightx2v/sync").strip("/")
    client_prefix = (args.client_key_prefix or f"{base_path}/benchmark/client").strip("/")
    server_prefix = (args.server_key_prefix or f"{base_path}/benchmark/server").strip("/")

    s3_client = build_s3_client(args)

    total_rounds = args.warmup_runs + args.runs
    client_results: List[Dict[str, float]] = []
    server_results: List[Dict[str, float]] = []

    print(f"Start benchmark: warmup={args.warmup_runs}, runs={args.runs}, order={args.order}")
    for idx in range(total_rounds):
        is_warmup = idx < args.warmup_runs
        round_no = idx + 1
        tag = "warmup" if is_warmup else "measure"
        print(f"\nRound {round_no}/{total_rounds} [{tag}]")

        # Fresh object keys each round so rounds never overwrite each other.
        client_key = f"{client_prefix}/{uuid.uuid4().hex}.png"
        server_key = f"{server_prefix}/{uuid.uuid4().hex}.png"

        # "alternate" flips the flow order each round to cancel out
        # first-mover effects (cache warmth, connection reuse).
        if args.order == "client_first":
            flow_order = ["client", "server"]
        elif args.order == "server_first":
            flow_order = ["server", "client"]
        else:
            flow_order = ["client", "server"] if (idx % 2 == 0) else ["server", "client"]

        one_round_client = None
        one_round_server = None

        for flow in flow_order:
            if flow == "client":
                one_round_client = run_client_upload_flow(args, s3_client, bucket, client_key)
                print(
                    "[client_upload] "
                    f"sync={one_round_client['sync_ms']:.2f} ms, "
                    f"upload={one_round_client['upload_ms']:.2f} ms, "
                    f"total={one_round_client['total_ms']:.2f} ms, "
                    f"bytes={int(one_round_client['bytes'])}"
                )
            else:
                one_round_server = run_server_upload_flow(args, s3_client, bucket, server_key)
                print(f"[server_upload] total={one_round_server['sync_total_ms']:.2f} ms")

        # Warmup rounds are executed but excluded from the summary stats.
        if not is_warmup:
            assert one_round_client is not None and one_round_server is not None
            client_results.append(one_round_client)
            server_results.append(one_round_server)
            delta_ms = one_round_server["sync_total_ms"] - one_round_client["total_ms"]
            print(f"[delta] server_total - client_total = {delta_ms:.2f} ms")

    client_sync_list = [x["sync_ms"] for x in client_results]
    client_upload_list = [x["upload_ms"] for x in client_results]
    client_total_list = [x["total_ms"] for x in client_results]
    server_total_list = [x["sync_total_ms"] for x in server_results]
    delta_list = [s - c for s, c in zip(server_total_list, client_total_list)]

    print("\n=== Benchmark Summary ===")
    print_summary("client_upload.sync_ms", client_sync_list)
    print_summary("client_upload.upload_ms", client_upload_list)
    print_summary("client_upload.total_ms", client_total_list)
    print_summary("server_upload.total_ms", server_total_list)
    print_summary("delta(server-client).ms", delta_list)


# Allow running this module directly as a CLI script.
if __name__ == "__main__":
    main()