Skip to content

Commit ae8c792

Browse files
feat: add disk-mode weight update flow to gateway (#1237)
* feat: add disk-mode weight update flow to gateway

  Support disk-based weight transfer for weight-update pairs, including save-to-disk on train workers and load/update on inference workers. Ensure inference pause/continue is wrapped in a context manager so generation resumes even when updates fail.

  Key changes:
  - add disk mode fields to connect request and pair metadata
  - implement disk transfer path with save/update_weights_from_disk and LoRA loading
  - add controller and gateway tests for disk mode behavior

* fix: validate disk-mode gateway connect inputs

  Fail fast on disk-mode misconfiguration so weight updates do not silently write to local relative paths instead of shared storage.

  Key changes:
  - reject empty disk-mode save_path values
  - require absolute save_path values for disk mode
  - require lora_name when disk-mode LoRA is enabled
  - add gateway tests for rejected connect requests

---------

Co-authored-by: Wentai Zhang <rchardx@gmail.com>
1 parent 6e69226 commit ae8c792

6 files changed

Lines changed: 757 additions & 14 deletions

File tree

areal/api/io_struct.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def get_versioned_lora_name(lora_name: str, version: int) -> str:
165165

166166
@dataclass
167167
class WeightUpdateMeta:
168-
type: Literal["disk", "xccl"]
168+
type: Literal["disk", "xccl", "awex"]
169169
path: str | None = None
170170
gen_allocation: ModelAllocation | None = None
171171

@@ -267,6 +267,22 @@ def from_fsdp_xccl(
267267
base_model_name=base_model_name,
268268
)
269269

270+
@classmethod
def from_awex(
    cls,
    use_lora: bool = False,
    lora_name: str = "",
    lora_int_id: int = 1,
    base_model_name: str = "",
):
    """Alternate constructor for an ``"awex"``-typed weight update.

    Every argument is forwarded unchanged to the class constructor; the only
    value fixed here is ``type="awex"``.

    Args:
        use_lora: Forwarded as ``use_lora``.
        lora_name: Forwarded as ``lora_name``.
        lora_int_id: Forwarded as ``lora_int_id``.
        base_model_name: Forwarded as ``base_model_name``.

    Returns:
        A new instance of ``cls``.
    """
    init_kwargs = {
        "type": "awex",
        "use_lora": use_lora,
        "lora_name": lora_name,
        "lora_int_id": lora_int_id,
        "base_model_name": base_model_name,
    }
    return cls(**init_kwargs)
285+
270286

271287
@dataclass
272288
class HttpRequest:

areal/experimental/weight_update/controller/controller.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ def connect(
108108
pair_name: str,
109109
train_worker_urls: list[str],
110110
inference_worker_urls: list[str],
111+
mode: str = "awex",
112+
save_path: str = "",
113+
use_lora: bool = False,
114+
lora_name: str = "",
111115
) -> None:
112116
self._pair_name = pair_name
113117
resp = self._http.post(
@@ -116,11 +120,15 @@ def connect(
116120
"pair_name": pair_name,
117121
"train_worker_urls": train_worker_urls,
118122
"inference_worker_urls": inference_worker_urls,
123+
"mode": mode,
124+
"save_path": save_path,
125+
"use_lora": use_lora,
126+
"lora_name": lora_name,
119127
},
120128
timeout=self.config.request_timeout,
121129
)
122130
resp.raise_for_status()
123-
logger.info("Connected pair '%s'", pair_name)
131+
logger.info("Connected pair '%s' (mode=%s)", pair_name, mode)
124132

125133
def update_weights(self, version: int) -> WeightUpdateResult:
126134
if self._pair_name is None:

areal/experimental/weight_update/gateway/app.py

Lines changed: 182 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import asyncio
5+
import os
56
import socket
67
import time
78
from contextlib import asynccontextmanager
@@ -29,6 +30,10 @@ class ConnectRequest(BaseModel):
2930
pair_name: str
3031
train_worker_urls: list[str]
3132
inference_worker_urls: list[str]
33+
mode: str = "awex" # "awex" or "disk"
34+
save_path: str = ""
35+
use_lora: bool = False
36+
lora_name: str = ""
3237

3338

3439
class UpdateWeightsRequest(BaseModel):
@@ -191,6 +196,39 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse:
191196
train_urls = body.train_worker_urls
192197
inference_urls = body.inference_worker_urls
193198

199+
if body.mode == "disk":
200+
if not body.save_path:
201+
return JSONResponse(
202+
status_code=400,
203+
content={"error": "save_path is required when mode='disk'"},
204+
)
205+
if not os.path.isabs(body.save_path):
206+
return JSONResponse(
207+
status_code=400,
208+
content={
209+
"error": "save_path must be an absolute path when mode='disk'"
210+
},
211+
)
212+
if body.use_lora and not body.lora_name:
213+
return JSONResponse(
214+
status_code=400,
215+
content={"error": "lora_name is required when use_lora=True"},
216+
)
217+
pair_info = PairInfo(
218+
pair_name=pair_name,
219+
train_worker_urls=train_urls,
220+
inference_worker_urls=inference_urls,
221+
mode="disk",
222+
save_path=body.save_path,
223+
use_lora=body.use_lora,
224+
lora_name=body.lora_name,
225+
)
226+
registry.register(pair_info)
227+
logger.info(
228+
"Connected disk pair '%s' (save_path=%s)", pair_name, body.save_path
229+
)
230+
return ConnectResponse(pair_name=pair_name)
231+
194232
session = request.app.state.http_session
195233
init_timeout_s = config.init_timeout_s
196234

@@ -325,6 +363,114 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse:
325363
logger.info("Connected pair '%s'", pair_name)
326364
return ConnectResponse(pair_name=pair_name)
327365

366+
@asynccontextmanager
async def _inference_paused(
    session: aiohttp.ClientSession,
    inference_urls: list[str],
    timeout_s: float,
    pair_name: str,
):
    """Pause generation on every inference worker for the duration of the block.

    On entry, ``/pause_generation`` is posted to each URL concurrently. On
    exit -- whether the body succeeded, raised, or the pause itself failed --
    ``/continue_generation`` is posted to each URL. Resume failures are logged
    as warnings and never raised, so they cannot mask an error from the body.

    Args:
        session: HTTP session handed to ``_post``.
        inference_urls: Base URLs of the inference workers.
        timeout_s: Per-request timeout handed to ``_post``.
        pair_name: Pair name, used only in the warning log message.

    Raises:
        Exception: The first failure among the pause requests, re-raised
            after all pause requests have settled and a resume was attempted.
    """
    try:
        # The pause lives inside the ``try`` so that a partial failure (some
        # workers paused, one request failed) still reaches the resume in
        # ``finally`` instead of leaving workers paused indefinitely.
        # ``return_exceptions=True`` waits for every request to settle, so
        # the resume never races a pause request that is still in flight.
        pause_outcomes = await asyncio.gather(
            *[
                _post(session, f"{url}/pause_generation", timeout_s, json_data={})
                for url in inference_urls
            ],
            return_exceptions=True,
        )
        for outcome in pause_outcomes:
            if isinstance(outcome, BaseException):
                raise outcome
        yield
    finally:
        # Best effort: attempt every worker and collect failures rather than
        # stopping at the first one.
        resume_outcomes = await asyncio.gather(
            *[
                _post(
                    session,
                    f"{url}/continue_generation",
                    timeout_s,
                    json_data={},
                )
                for url in inference_urls
            ],
            return_exceptions=True,
        )
        for outcome in resume_outcomes:
            if isinstance(outcome, BaseException):
                logger.warning(
                    "Failed to resume inference for pair '%s'",
                    pair_name,
                    exc_info=outcome,
                )
400+
401+
async def _awex_transfer_weights(
    pair_info: PairInfo,
    version: int,
    session: aiohttp.ClientSession,
    timeout_s: float,
) -> None:
    """Post ``/awex/update_weights`` to every worker of the pair concurrently.

    Train workers and inference workers receive the same request, carrying
    the target ``version``; the call returns once all requests have finished.

    Args:
        pair_info: Pair whose ``train_worker_urls`` and
            ``inference_worker_urls`` are contacted.
        version: Weight version placed in the request body.
        session: HTTP session handed to ``_post``.
        timeout_s: Per-request timeout handed to ``_post``.
    """
    all_worker_urls = pair_info.train_worker_urls + pair_info.inference_worker_urls
    pending = []
    for base_url in all_worker_urls:
        pending.append(
            _post(
                session,
                f"{base_url}/awex/update_weights",
                timeout_s,
                json_data={"version": version},
            )
        )
    await asyncio.gather(*pending)
418+
419+
async def _disk_transfer_weights(
    pair_info: PairInfo,
    version: int,
    session: aiohttp.ClientSession,
    timeout_s: float,
) -> None:
    """Move weights from train workers to inference workers through disk.

    Two phases, each fanned out concurrently:

    1. Every train worker is asked (``/save``) to write an ``hf``-format
       checkpoint, without optimizer state, to
       ``<pair_info.save_path>/weight_update_v<version>``.
    2. Every inference worker is asked to read that directory: via
       ``/load_lora_adapter`` under the versioned LoRA name when
       ``pair_info.use_lora`` is set, otherwise via
       ``/update_weights_from_disk``.

    Args:
        pair_info: Pair supplying the worker URLs, ``save_path`` and LoRA
            settings.
        version: Weight version; selects the checkpoint directory and the
            versioned LoRA name.
        session: HTTP session handed to ``_post_json``.
        timeout_s: Per-request timeout handed to ``_post_json``.
    """
    from areal.api.io_struct import SaveLoadMeta, get_versioned_lora_name
    from areal.infra.rpc.serialization import serialize_value

    checkpoint_dir = os.path.join(pair_info.save_path, f"weight_update_v{version}")

    # Phase 1: train workers write the checkpoint.
    meta = SaveLoadMeta(path=checkpoint_dir, weight_format="hf", with_optim=False)
    save_body = {
        "args": serialize_value([meta]),
        "kwargs": serialize_value({}),
    }
    save_calls = [
        _post_json(session, f"{url}/save", timeout_s, json_data=save_body)
        for url in pair_info.train_worker_urls
    ]
    await asyncio.gather(*save_calls)

    # Phase 2: inference workers read it; only endpoint and body differ
    # between the LoRA and full-weights cases.
    if pair_info.use_lora:
        endpoint = "load_lora_adapter"
        load_body = {
            "lora_name": get_versioned_lora_name(pair_info.lora_name, version),
            "lora_path": checkpoint_dir,
        }
    else:
        endpoint = "update_weights_from_disk"
        load_body = {
            "model_path": checkpoint_dir,
            "abort_all_requests": True,
        }

    load_calls = [
        _post_json(
            session,
            f"{url}/{endpoint}",
            timeout_s,
            # Each request gets its own dict, as in the per-call literals.
            json_data=dict(load_body),
        )
        for url in pair_info.inference_worker_urls
    ]
    await asyncio.gather(*load_calls)
473+
328474
@app.post("/update_weights")
329475
async def update_weights(
330476
request: Request, body: UpdateWeightsRequest
@@ -334,27 +480,51 @@ async def update_weights(
334480
pair_info = registry.get_by_name(body.pair_name)
335481
if pair_info is None:
336482
return JSONResponse(
337-
status_code=404, content={"error": f"Pair '{body.pair_name}' not found"}
483+
status_code=404,
484+
content={"error": f"Pair '{body.pair_name}' not found"},
338485
)
339486

340487
session = request.app.state.http_session
341-
update_timeout_s = config.update_timeout_s
342-
488+
timeout_s = config.update_timeout_s
343489
start = time.monotonic()
344-
tasks = [
345-
_post(
490+
491+
try:
492+
async with _inference_paused(
346493
session,
347-
f"{url}/awex/update_weights",
348-
update_timeout_s,
349-
json_data={"version": body.version},
494+
pair_info.inference_worker_urls,
495+
timeout_s,
496+
pair_info.pair_name,
497+
):
498+
if pair_info.mode == "disk":
499+
await _disk_transfer_weights(
500+
pair_info, body.version, session, timeout_s
501+
)
502+
else:
503+
await _awex_transfer_weights(
504+
pair_info, body.version, session, timeout_s
505+
)
506+
except Exception as e:
507+
duration_ms = (time.monotonic() - start) * 1000
508+
logger.error(
509+
"Weight update failed for pair '%s': %s",
510+
pair_info.pair_name,
511+
e,
512+
)
513+
return WeightUpdateResult(
514+
status="error",
515+
version=body.version,
516+
duration_ms=duration_ms,
517+
error=str(e),
350518
)
351-
for url in pair_info.train_worker_urls + pair_info.inference_worker_urls
352-
]
353-
await asyncio.gather(*tasks)
354519

355520
duration_ms = (time.monotonic() - start) * 1000
356521
pair_info.last_version = body.version
357-
522+
logger.info(
523+
"Weight update completed for pair '%s' v%d (%.1fms)",
524+
pair_info.pair_name,
525+
body.version,
526+
duration_ms,
527+
)
358528
return WeightUpdateResult(
359529
status="ok", version=body.version, duration_ms=duration_ms
360530
)

areal/experimental/weight_update/gateway/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ class PairInfo:
5454
master_port: int = 0
5555
last_version: int = 0
5656

57+
# Disk-mode fields (used when mode="disk")
58+
mode: str = "awex" # "awex" or "disk"
59+
save_path: str = ""
60+
use_lora: bool = False
61+
lora_name: str = ""
62+
5763
def __post_init__(self):
5864
if not self.pair_name:
5965
raise ValueError("pair_name must not be empty")
66+
if self.mode not in ("awex", "disk"):
67+
raise ValueError(f"mode must be 'awex' or 'disk', got '{self.mode}'")

0 commit comments

Comments
 (0)