Skip to content

Commit fe8e642

Browse files
black-elevenyihuiwen
andauthored
add server sync method script (#1036)
Co-authored-by: yihuiwen <yihuiwen@sensetime.com>
1 parent d7ec87f commit fe8e642

7 files changed

Lines changed: 1089 additions & 0 deletions
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
import argparse
2+
import json
3+
import os
4+
import statistics
5+
import time
6+
import uuid
7+
from typing import Any, Dict, List
8+
9+
import requests
10+
11+
try:
12+
import boto3 # pyright: ignore[reportMissingImports]
13+
from botocore.config import Config as BotoConfig # pyright: ignore[reportMissingImports]
14+
except ImportError: # pragma: no cover - runtime dependency check
15+
boto3 = None
16+
BotoConfig = None
17+
18+
19+
def percentile(values: List[float], p: float) -> float:
20+
if not values:
21+
return 0.0
22+
if len(values) == 1:
23+
return values[0]
24+
sorted_values = sorted(values)
25+
rank = (len(sorted_values) - 1) * p
26+
low = int(rank)
27+
high = min(low + 1, len(sorted_values) - 1)
28+
weight = rank - low
29+
return sorted_values[low] * (1 - weight) + sorted_values[high] * weight
30+
31+
32+
def build_s3_client(args: argparse.Namespace):
33+
if boto3 is None or BotoConfig is None:
34+
raise RuntimeError("boto3/botocore is required. Please install boto3 or aioboto3.")
35+
36+
region = args.s3_region or os.getenv("AWS_DEFAULT_REGION", "us-east-1")
37+
access_key = args.s3_access_key or os.getenv("AWS_ACCESS_KEY_ID")
38+
secret_key = args.s3_secret_key or os.getenv("AWS_SECRET_ACCESS_KEY")
39+
session_token = args.s3_session_token or os.getenv("AWS_SESSION_TOKEN")
40+
addressing_style = (args.s3_addressing_style or os.getenv("S3_ADDRESSING_STYLE", "auto")).strip().lower()
41+
signature_version = (args.s3_signature_version or os.getenv("S3_SIGNATURE_VERSION", "s3v4")).strip()
42+
43+
if addressing_style not in {"auto", "path", "virtual"}:
44+
raise ValueError("--s3_addressing_style must be one of: auto, path, virtual")
45+
46+
client_kwargs: Dict[str, Any] = {
47+
"service_name": "s3",
48+
"region_name": region,
49+
"config": BotoConfig(
50+
signature_version=signature_version,
51+
s3={"addressing_style": addressing_style},
52+
),
53+
}
54+
if args.s3_endpoint_url:
55+
client_kwargs["endpoint_url"] = args.s3_endpoint_url
56+
if access_key and secret_key:
57+
client_kwargs["aws_access_key_id"] = access_key
58+
client_kwargs["aws_secret_access_key"] = secret_key
59+
if session_token:
60+
client_kwargs["aws_session_token"] = session_token
61+
return boto3.client(**client_kwargs)
62+
63+
64+
def generate_presigned_pair(s3_client, bucket: str, object_key: str, expires_in: int) -> tuple[str, str]:
65+
put_url = s3_client.generate_presigned_url(
66+
ClientMethod="put_object",
67+
Params={"Bucket": bucket, "Key": object_key},
68+
ExpiresIn=expires_in,
69+
HttpMethod="PUT",
70+
)
71+
get_url = s3_client.generate_presigned_url(
72+
ClientMethod="get_object",
73+
Params={"Bucket": bucket, "Key": object_key},
74+
ExpiresIn=expires_in,
75+
HttpMethod="GET",
76+
)
77+
return put_url, get_url
78+
79+
80+
def build_sync_payload(args: argparse.Namespace, presigned_url: str = "") -> Dict[str, Any]:
81+
payload: Dict[str, Any] = {
82+
"prompt": args.prompt,
83+
"negative_prompt": args.negative_prompt,
84+
"infer_steps": args.infer_steps,
85+
"seed": args.seed,
86+
"aspect_ratio": args.aspect_ratio,
87+
"save_result_path": args.save_result_path,
88+
"use_prompt_enhancer": args.use_prompt_enhancer,
89+
}
90+
if args.target_shape:
91+
payload["target_shape"] = args.target_shape
92+
if presigned_url:
93+
payload["presigned_url"] = presigned_url
94+
return payload
95+
96+
97+
def call_sync(base_url: str, payload: Dict[str, Any], timeout_seconds: int, poll_interval_seconds: float) -> requests.Response:
98+
endpoint = f"{base_url.rstrip('/')}/v1/tasks/image/sync?timeout_seconds={timeout_seconds}&poll_interval_seconds={poll_interval_seconds}"
99+
return requests.post(endpoint, json=payload, timeout=timeout_seconds + 30)
100+
101+
102+
def run_client_upload_flow(
103+
args: argparse.Namespace,
104+
s3_client,
105+
bucket: str,
106+
object_key: str,
107+
) -> Dict[str, float]:
108+
put_url, get_url = generate_presigned_pair(s3_client, bucket, object_key, args.presign_expires)
109+
110+
payload = build_sync_payload(args, presigned_url="")
111+
t0 = time.perf_counter()
112+
response = call_sync(args.url, payload, args.timeout_seconds, args.poll_interval_seconds)
113+
t1 = time.perf_counter()
114+
if response.status_code != 200:
115+
raise RuntimeError(f"[client_upload] sync failed ({response.status_code}): {response.text}")
116+
image_bytes = response.content
117+
t2 = time.perf_counter()
118+
upload_resp = requests.put(put_url, data=image_bytes, timeout=args.upload_timeout_seconds)
119+
t3 = time.perf_counter()
120+
if upload_resp.status_code not in (200, 201, 204):
121+
raise RuntimeError(f"[client_upload] upload failed ({upload_resp.status_code}): {upload_resp.text}")
122+
123+
if args.verify_download:
124+
check_resp = requests.get(get_url, timeout=args.download_timeout_seconds)
125+
if check_resp.status_code != 200:
126+
raise RuntimeError(f"[client_upload] download verify failed ({check_resp.status_code}): {check_resp.text}")
127+
128+
return {
129+
"sync_ms": (t1 - t0) * 1000.0,
130+
"upload_ms": (t3 - t2) * 1000.0,
131+
"total_ms": (t3 - t0) * 1000.0,
132+
"bytes": float(len(image_bytes)),
133+
}
134+
135+
136+
def run_server_upload_flow(
137+
args: argparse.Namespace,
138+
s3_client,
139+
bucket: str,
140+
object_key: str,
141+
) -> Dict[str, float]:
142+
put_url, get_url = generate_presigned_pair(s3_client, bucket, object_key, args.presign_expires)
143+
payload = build_sync_payload(args, presigned_url=put_url)
144+
145+
t0 = time.perf_counter()
146+
response = call_sync(args.url, payload, args.timeout_seconds, args.poll_interval_seconds)
147+
t1 = time.perf_counter()
148+
if response.status_code != 200:
149+
raise RuntimeError(f"[server_upload] sync failed ({response.status_code}): {response.text}")
150+
151+
content_type = response.headers.get("content-type", "")
152+
if "application/json" not in content_type:
153+
raise RuntimeError(f"[server_upload] expected JSON but got content-type={content_type!r}")
154+
body = response.json()
155+
if not body.get("uploaded_to_presigned_url"):
156+
raise RuntimeError(f"[server_upload] upload flag is false, response={json.dumps(body, ensure_ascii=False)}")
157+
158+
if args.verify_download:
159+
check_resp = requests.get(get_url, timeout=args.download_timeout_seconds)
160+
if check_resp.status_code != 200:
161+
raise RuntimeError(f"[server_upload] download verify failed ({check_resp.status_code}): {check_resp.text}")
162+
163+
return {
164+
"sync_total_ms": (t1 - t0) * 1000.0,
165+
}
166+
167+
168+
def print_summary(label: str, values: List[float]) -> None:
169+
print(f"{label}: avg={statistics.mean(values):.2f} ms, p50={percentile(values, 0.5):.2f} ms, p90={percentile(values, 0.9):.2f} ms, min={min(values):.2f} ms, max={max(values):.2f} ms")
170+
171+
172+
def main() -> None:
173+
parser = argparse.ArgumentParser(description="Benchmark sync latency: client upload to S3 vs x2v server upload to S3.")
174+
parser.add_argument("--url", type=str, default="http://127.0.0.1:8000", help="x2v server base url")
175+
parser.add_argument("--runs", type=int, default=5, help="Benchmark rounds")
176+
parser.add_argument("--warmup_runs", type=int, default=0, help="Warmup rounds before measurement")
177+
parser.add_argument("--order", type=str, default="alternate", choices=["alternate", "client_first", "server_first"])
178+
179+
parser.add_argument("--prompt", type=str, required=True, help="Prompt text")
180+
parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt text")
181+
parser.add_argument("--infer_steps", type=int, default=30, help="Inference steps")
182+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
183+
parser.add_argument("--aspect_ratio", type=str, default="16:9", help="Aspect ratio")
184+
parser.add_argument("--target_shape", type=int, nargs="+", default=None, help="Target shape, e.g. 1536 2752")
185+
parser.add_argument("--save_result_path", type=str, default="", help="Server-side save_result_path")
186+
parser.add_argument("--use_prompt_enhancer", action="store_true")
187+
parser.add_argument("--timeout_seconds", type=int, default=600)
188+
parser.add_argument("--poll_interval_seconds", type=float, default=0.5)
189+
190+
parser.add_argument("--s3_endpoint_url", type=str, default="")
191+
parser.add_argument("--s3_region", type=str, default="")
192+
parser.add_argument("--s3_bucket", type=str, default="")
193+
parser.add_argument("--s3_access_key", type=str, default="")
194+
parser.add_argument("--s3_secret_key", type=str, default="")
195+
parser.add_argument("--s3_session_token", type=str, default="")
196+
parser.add_argument("--s3_addressing_style", type=str, default="")
197+
parser.add_argument("--s3_signature_version", type=str, default="")
198+
parser.add_argument("--presign_expires", type=int, default=3600)
199+
parser.add_argument(
200+
"--client_key_prefix",
201+
type=str,
202+
default="",
203+
help="S3 key prefix for client-upload flow, defaults to $S3_BASE_PATH/benchmark/client",
204+
)
205+
parser.add_argument(
206+
"--server_key_prefix",
207+
type=str,
208+
default="",
209+
help="S3 key prefix for server-upload flow, defaults to $S3_BASE_PATH/benchmark/server",
210+
)
211+
parser.add_argument("--upload_timeout_seconds", type=int, default=120)
212+
parser.add_argument("--download_timeout_seconds", type=int, default=120)
213+
parser.add_argument("--verify_download", action="store_true", help="Verify uploaded object with GET presigned URL")
214+
args = parser.parse_args()
215+
216+
if args.runs <= 0:
217+
raise ValueError("--runs must be > 0")
218+
if args.warmup_runs < 0:
219+
raise ValueError("--warmup_runs must be >= 0")
220+
if args.presign_expires <= 0:
221+
raise ValueError("--presign_expires must be > 0")
222+
if args.target_shape is not None and len(args.target_shape) < 2:
223+
raise ValueError("--target_shape must provide at least 2 integers")
224+
225+
bucket = args.s3_bucket or os.getenv("S3_BUCKET", "")
226+
if not bucket:
227+
raise ValueError("Missing S3 bucket. Set --s3_bucket or env S3_BUCKET.")
228+
base_path = os.getenv("S3_BASE_PATH", "lightx2v/sync").strip("/")
229+
client_prefix = (args.client_key_prefix or f"{base_path}/benchmark/client").strip("/")
230+
server_prefix = (args.server_key_prefix or f"{base_path}/benchmark/server").strip("/")
231+
232+
s3_client = build_s3_client(args)
233+
234+
total_rounds = args.warmup_runs + args.runs
235+
client_results: List[Dict[str, float]] = []
236+
server_results: List[Dict[str, float]] = []
237+
238+
print(f"Start benchmark: warmup={args.warmup_runs}, runs={args.runs}, order={args.order}")
239+
for idx in range(total_rounds):
240+
is_warmup = idx < args.warmup_runs
241+
round_no = idx + 1
242+
tag = "warmup" if is_warmup else "measure"
243+
print(f"\nRound {round_no}/{total_rounds} [{tag}]")
244+
245+
client_key = f"{client_prefix}/{uuid.uuid4().hex}.png"
246+
server_key = f"{server_prefix}/{uuid.uuid4().hex}.png"
247+
248+
if args.order == "client_first":
249+
flow_order = ["client", "server"]
250+
elif args.order == "server_first":
251+
flow_order = ["server", "client"]
252+
else:
253+
flow_order = ["client", "server"] if (idx % 2 == 0) else ["server", "client"]
254+
255+
one_round_client = None
256+
one_round_server = None
257+
258+
for flow in flow_order:
259+
if flow == "client":
260+
one_round_client = run_client_upload_flow(args, s3_client, bucket, client_key)
261+
print(
262+
"[client_upload] "
263+
f"sync={one_round_client['sync_ms']:.2f} ms, "
264+
f"upload={one_round_client['upload_ms']:.2f} ms, "
265+
f"total={one_round_client['total_ms']:.2f} ms, "
266+
f"bytes={int(one_round_client['bytes'])}"
267+
)
268+
else:
269+
one_round_server = run_server_upload_flow(args, s3_client, bucket, server_key)
270+
print(f"[server_upload] total={one_round_server['sync_total_ms']:.2f} ms")
271+
272+
if not is_warmup:
273+
assert one_round_client is not None and one_round_server is not None
274+
client_results.append(one_round_client)
275+
server_results.append(one_round_server)
276+
delta_ms = one_round_server["sync_total_ms"] - one_round_client["total_ms"]
277+
print(f"[delta] server_total - client_total = {delta_ms:.2f} ms")
278+
279+
client_sync_list = [x["sync_ms"] for x in client_results]
280+
client_upload_list = [x["upload_ms"] for x in client_results]
281+
client_total_list = [x["total_ms"] for x in client_results]
282+
server_total_list = [x["sync_total_ms"] for x in server_results]
283+
delta_list = [s - c for s, c in zip(server_total_list, client_total_list)]
284+
285+
print("\n=== Benchmark Summary ===")
286+
print_summary("client_upload.sync_ms", client_sync_list)
287+
print_summary("client_upload.upload_ms", client_upload_list)
288+
print_summary("client_upload.total_ms", client_total_list)
289+
print_summary("server_upload.total_ms", server_total_list)
290+
print_summary("delta(server-client).ms", delta_list)
291+
292+
293+
if __name__ == "__main__":
294+
main()

0 commit comments

Comments
 (0)