22from __future__ import annotations
33
44import asyncio
5+ import os
56import socket
67import time
78from contextlib import asynccontextmanager
@@ -29,6 +30,10 @@ class ConnectRequest(BaseModel):
2930 pair_name : str
3031 train_worker_urls : list [str ]
3132 inference_worker_urls : list [str ]
33+ mode : str = "awex" # "awex" or "disk"
34+ save_path : str = ""
35+ use_lora : bool = False
36+ lora_name : str = ""
3237
3338
3439class UpdateWeightsRequest (BaseModel ):
@@ -191,6 +196,39 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse:
191196 train_urls = body .train_worker_urls
192197 inference_urls = body .inference_worker_urls
193198
199+ if body .mode == "disk" :
200+ if not body .save_path :
201+ return JSONResponse (
202+ status_code = 400 ,
203+ content = {"error" : "save_path is required when mode='disk'" },
204+ )
205+ if not os .path .isabs (body .save_path ):
206+ return JSONResponse (
207+ status_code = 400 ,
208+ content = {
209+ "error" : "save_path must be an absolute path when mode='disk'"
210+ },
211+ )
212+ if body .use_lora and not body .lora_name :
213+ return JSONResponse (
214+ status_code = 400 ,
215+ content = {"error" : "lora_name is required when use_lora=True" },
216+ )
217+ pair_info = PairInfo (
218+ pair_name = pair_name ,
219+ train_worker_urls = train_urls ,
220+ inference_worker_urls = inference_urls ,
221+ mode = "disk" ,
222+ save_path = body .save_path ,
223+ use_lora = body .use_lora ,
224+ lora_name = body .lora_name ,
225+ )
226+ registry .register (pair_info )
227+ logger .info (
228+ "Connected disk pair '%s' (save_path=%s)" , pair_name , body .save_path
229+ )
230+ return ConnectResponse (pair_name = pair_name )
231+
194232 session = request .app .state .http_session
195233 init_timeout_s = config .init_timeout_s
196234
@@ -325,6 +363,114 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse:
325363 logger .info ("Connected pair '%s'" , pair_name )
326364 return ConnectResponse (pair_name = pair_name )
327365
@asynccontextmanager
async def _inference_paused(
    session: aiohttp.ClientSession,
    inference_urls: list[str],
    timeout_s: float,
    pair_name: str,
):
    """Keep generation paused on the inference workers while the body runs.

    On entry, ``/pause_generation`` is posted to every URL in
    ``inference_urls``; on exit, ``/continue_generation`` is posted to every
    URL, whether the body succeeded or raised.

    Args:
        session: HTTP session passed through to ``_post``.
        inference_urls: Base URLs of the inference workers.
        timeout_s: Timeout handed to each ``_post`` call.
        pair_name: Pair identifier, used only in the resume-failure log line.

    Raises:
        BaseException: The first failure reported by a pause request. It is
            raised only after every pause request has finished and a resume
            has been attempted, so a partial pause does not leave some
            workers paused.
    """
    try:
        # return_exceptions=True makes gather wait for *all* pause requests.
        # Failing fast would leave sibling pause requests in flight, and they
        # could land after the resume issued in the ``finally`` below.
        pause_results = await asyncio.gather(
            *[
                _post(session, f"{url}/pause_generation", timeout_s, json_data={})
                for url in inference_urls
            ],
            return_exceptions=True,
        )
        for outcome in pause_results:
            if isinstance(outcome, BaseException):
                # Raised inside the try so the resume still runs for the
                # workers that did pause.
                raise outcome
        yield
    finally:
        # Resuming is best effort: a failure here is logged and must not mask
        # whatever the body (or the pause step) raised.
        try:
            await asyncio.gather(
                *[
                    _post(
                        session,
                        f"{url}/continue_generation",
                        timeout_s,
                        json_data={},
                    )
                    for url in inference_urls
                ]
            )
        except Exception:
            logger.warning(
                "Failed to resume inference for pair '%s'",
                pair_name,
                exc_info=True,
            )
400+
async def _awex_transfer_weights(
    pair_info: PairInfo,
    version: int,
    session: aiohttp.ClientSession,
    timeout_s: float,
) -> None:
    """Post ``/awex/update_weights`` to every worker of the pair concurrently.

    Train workers and inference workers all receive the same
    ``{"version": version}`` body; the call returns once every request has
    completed, and any request failure propagates out of ``asyncio.gather``.

    Args:
        pair_info: Pair whose ``train_worker_urls`` and
            ``inference_worker_urls`` are contacted.
        version: Weight version placed in each request body.
        session: HTTP session passed through to ``_post``.
        timeout_s: Timeout handed to each ``_post`` call.
    """
    # Train workers first, then inference workers -- same order as the
    # concatenated URL list.
    worker_urls = pair_info.train_worker_urls + pair_info.inference_worker_urls

    pending = []
    for worker_url in worker_urls:
        pending.append(
            _post(
                session,
                f"{worker_url}/awex/update_weights",
                timeout_s,
                json_data={"version": version},
            )
        )
    await asyncio.gather(*pending)
418+
async def _disk_transfer_weights(
    pair_info: PairInfo,
    version: int,
    session: aiohttp.ClientSession,
    timeout_s: float,
) -> None:
    """Move weights from train to inference workers through a directory.

    Step 1 posts ``/save`` to every train worker with a serialized
    ``SaveLoadMeta`` pointing at ``<pair_info.save_path>/weight_update_v<version>``.
    Step 2, once all saves have returned, posts to every inference worker:
    ``/load_lora_adapter`` when ``pair_info.use_lora`` is set, otherwise
    ``/update_weights_from_disk``.

    Args:
        pair_info: Pair supplying worker URLs, ``save_path``, ``use_lora`` and
            ``lora_name``.
        version: Weight version; part of the directory name and, for LoRA,
            of the versioned adapter name.
        session: HTTP session passed through to ``_post_json``.
        timeout_s: Timeout handed to each ``_post_json`` call.
    """
    from areal.api.io_struct import SaveLoadMeta, get_versioned_lora_name
    from areal.infra.rpc.serialization import serialize_value

    version_dir = os.path.join(pair_info.save_path, f"weight_update_v{version}")

    # --- Step 1: every train worker writes the weights into version_dir. ---
    meta = SaveLoadMeta(path=version_dir, weight_format="hf", with_optim=False)
    save_body = {
        "args": serialize_value([meta]),
        "kwargs": serialize_value({}),
    }
    save_requests = [
        _post_json(session, f"{train_url}/save", timeout_s, json_data=save_body)
        for train_url in pair_info.train_worker_urls
    ]
    await asyncio.gather(*save_requests)

    # --- Step 2: every inference worker loads from version_dir. ---
    # The two modes differ only in endpoint and request body, so pick those
    # first and share a single fan-out.
    if pair_info.use_lora:
        endpoint = "load_lora_adapter"
        adapter_name = get_versioned_lora_name(pair_info.lora_name, version)

        def build_body() -> dict:
            return {"lora_name": adapter_name, "lora_path": version_dir}

    else:
        endpoint = "update_weights_from_disk"

        def build_body() -> dict:
            return {"model_path": version_dir, "abort_all_requests": True}

    load_requests = [
        _post_json(
            session,
            f"{infer_url}/{endpoint}",
            timeout_s,
            json_data=build_body(),
        )
        for infer_url in pair_info.inference_worker_urls
    ]
    await asyncio.gather(*load_requests)
473+
328474 @app .post ("/update_weights" )
329475 async def update_weights (
330476 request : Request , body : UpdateWeightsRequest
@@ -334,27 +480,51 @@ async def update_weights(
334480 pair_info = registry .get_by_name (body .pair_name )
335481 if pair_info is None :
336482 return JSONResponse (
337- status_code = 404 , content = {"error" : f"Pair '{ body .pair_name } ' not found" }
483+ status_code = 404 ,
484+ content = {"error" : f"Pair '{ body .pair_name } ' not found" },
338485 )
339486
340487 session = request .app .state .http_session
341- update_timeout_s = config .update_timeout_s
342-
488+ timeout_s = config .update_timeout_s
343489 start = time .monotonic ()
344- tasks = [
345- _post (
490+
491+ try :
492+ async with _inference_paused (
346493 session ,
347- f"{ url } /awex/update_weights" ,
348- update_timeout_s ,
349- json_data = {"version" : body .version },
494+ pair_info .inference_worker_urls ,
495+ timeout_s ,
496+ pair_info .pair_name ,
497+ ):
498+ if pair_info .mode == "disk" :
499+ await _disk_transfer_weights (
500+ pair_info , body .version , session , timeout_s
501+ )
502+ else :
503+ await _awex_transfer_weights (
504+ pair_info , body .version , session , timeout_s
505+ )
506+ except Exception as e :
507+ duration_ms = (time .monotonic () - start ) * 1000
508+ logger .error (
509+ "Weight update failed for pair '%s': %s" ,
510+ pair_info .pair_name ,
511+ e ,
512+ )
513+ return WeightUpdateResult (
514+ status = "error" ,
515+ version = body .version ,
516+ duration_ms = duration_ms ,
517+ error = str (e ),
350518 )
351- for url in pair_info .train_worker_urls + pair_info .inference_worker_urls
352- ]
353- await asyncio .gather (* tasks )
354519
355520 duration_ms = (time .monotonic () - start ) * 1000
356521 pair_info .last_version = body .version
357-
522+ logger .info (
523+ "Weight update completed for pair '%s' v%d (%.1fms)" ,
524+ pair_info .pair_name ,
525+ body .version ,
526+ duration_ms ,
527+ )
358528 return WeightUpdateResult (
359529 status = "ok" , version = body .version , duration_ms = duration_ms
360530 )
0 commit comments