Skip to content

Commit ccc6179

Browse files
authored
Merge branch 'main' into lora
2 parents f89157e + 7cad4da commit ccc6179

6 files changed

Lines changed: 45 additions & 24 deletions

File tree

areal/engine/vllm_ext/areal_vllm_server.py

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -152,12 +152,33 @@ async def update_weight_xccl(raw_request: Request):
152152

153153

154154
@router.post("/areal_update_weights_lora_xccl")
155-
async def update_weight_lora_xccl(raw_request: Request):
155+
async def update_weight_lora_xccl(
156+
request: UpdateWeightsFromXcclRequestLora, raw_request: Request
157+
):
156158
logger.info("API server starts update_weight_lora via XCCL")
157159
llm = raw_request.app.state.engine_client
158160
ret_list = await llm.engine_core.call_utility_async(
159161
"areal_injected_update_weight_lora_xccl",
160162
)
163+
# Only touch the registry after weights are actually updated
164+
models_obj = raw_request.app.state.openai_serving_models
165+
new_name = request.lora_name
166+
lora_id = request.lora_int_id
167+
for old_name, req in list(models_obj.lora_requests.items()):
168+
if req.lora_int_id == lora_id:
169+
del models_obj.lora_requests[old_name]
170+
req.lora_name = new_name
171+
models_obj.lora_requests[new_name] = req
172+
logger.info(
173+
f"Updated LoRA name of openai_serving_models "
174+
f"from {old_name} -> {new_name}"
175+
)
176+
break
177+
else:
178+
logger.warning(
179+
f"LoRA adapter with int_id={lora_id} not found in "
180+
f"openai_serving_models.lora_requests"
181+
)
161182
return build_response(ret_list)
162183

163184

areal/engine/vllm_remote.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def build_distributed_weight_update_requests(
185185
),
186186
HttpRequest(
187187
endpoint=update_endpoint,
188-
payload={},
188+
payload={} if not meta.use_lora else payload,
189189
),
190190
]
191191
)

areal/infra/rpc/rpc_server.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -431,15 +431,15 @@ def configure():
431431
try:
432432
data = request.get_json()
433433
if data is None:
434-
return jsonify({"detail": "Invalid JSON in request body"}), 400
434+
return jsonify({"error": "Invalid JSON in request body"}), 400
435435

436436
config = data.get("config")
437437
if config is None:
438-
return jsonify({"detail": "Missing 'config' field in request"}), 400
438+
return jsonify({"error": "Missing 'config' field in request"}), 400
439439

440440
rank = data.get("rank")
441441
if rank is None:
442-
return jsonify({"detail": "Missing 'rank' field in request"}), 400
442+
return jsonify({"error": "Missing 'rank' field in request"}), 400
443443

444444
config = deserialize_value(config)
445445
config: BaseExperimentConfig

areal/infra/scheduler/local.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -866,10 +866,10 @@ def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int):
866866
logger.info(f"Configuration successfully on worker '{worker_id}'")
867867
return
868868
elif response.status_code == 400:
869-
error_detail = response.json().get("detail", "Unknown error")
869+
error_detail = response.json().get("error", "Unknown error")
870870
raise WorkerConfigurationError(worker_id, error_detail, str(400))
871871
elif response.status_code == 500:
872-
error_detail = response.json().get("detail", "Unknown error")
872+
error_detail = response.json().get("error", "Unknown error")
873873
raise WorkerConfigurationError(worker_id, error_detail, str(500))
874874
else:
875875
raise WorkerConfigurationError(
@@ -1118,7 +1118,7 @@ async def create_engine(
11181118
elif response.status == 400:
11191119
# Import error or bad request
11201120
error_detail = (await response.json()).get(
1121-
"detail", "Unknown error"
1121+
"error", "Unknown error"
11221122
)
11231123
if "Failed to import" in error_detail:
11241124
raise EngineImportError(engine, error_detail)
@@ -1127,7 +1127,7 @@ async def create_engine(
11271127
elif response.status == 500:
11281128
# Engine initialization failed
11291129
error_detail = (await response.json()).get(
1130-
"detail", "Unknown error"
1130+
"error", "Unknown error"
11311131
)
11321132
raise EngineCreationError(worker_id, error_detail, 500)
11331133
else:
@@ -1399,15 +1399,15 @@ async def async_call_engine(
13991399
elif response.status == 400:
14001400
# Bad request (e.g., method doesn't exist) - don't retry
14011401
error_detail = (await response.json()).get(
1402-
"detail", "Unknown error"
1402+
"error", "Unknown error"
14031403
)
14041404
raise EngineCallError(
14051405
worker_id, method, error_detail, attempt
14061406
)
14071407
elif response.status == 500:
14081408
# Engine method failed - don't retry
14091409
error_detail = (await response.json()).get(
1410-
"detail", "Unknown error"
1410+
"error", "Unknown error"
14111411
)
14121412
raise EngineCallError(
14131413
worker_id, method, error_detail, attempt
@@ -1492,11 +1492,11 @@ def _handle_call_response(
14921492
return deserialized_result, False, None
14931493
elif response.status_code == 400:
14941494
# Bad request (e.g., method doesn't exist) - don't retry
1495-
error_detail = response.json().get("detail", "Unknown error")
1495+
error_detail = response.json().get("error", "Unknown error")
14961496
raise EngineCallError(worker_id, method, error_detail, attempt)
14971497
elif response.status_code == 500:
14981498
# Engine method failed - don't retry
1499-
error_detail = response.json().get("detail", "Unknown error")
1499+
error_detail = response.json().get("error", "Unknown error")
15001500
raise EngineCallError(worker_id, method, error_detail, attempt)
15011501
elif response.status_code == 503:
15021502
# Service unavailable - retry

areal/infra/scheduler/slurm.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -309,10 +309,10 @@ def _configure_worker(self, worker_info: SlurmWorkerInfo, worker_rank: int) -> N
309309
logger.info(f"Configuration successful on worker '{worker_id}'")
310310
return
311311
elif response.status_code == 400:
312-
error_detail = response.json().get("detail", "Unknown error")
312+
error_detail = response.json().get("error", "Unknown error")
313313
raise WorkerConfigurationError(worker_id, error_detail, str(400))
314314
elif response.status_code == 500:
315-
error_detail = response.json().get("detail", "Unknown error")
315+
error_detail = response.json().get("error", "Unknown error")
316316
raise WorkerConfigurationError(worker_id, error_detail, str(500))
317317
else:
318318
raise WorkerConfigurationError(
@@ -1329,15 +1329,15 @@ async def create_engine(
13291329
return result.get("result")
13301330
elif response.status == 400:
13311331
error_detail = (await response.json()).get(
1332-
"detail", "Unknown error"
1332+
"error", "Unknown error"
13331333
)
13341334
if "Failed to import" in error_detail:
13351335
raise EngineImportError(engine, error_detail)
13361336
else:
13371337
raise EngineCreationError(worker_id, error_detail, 400)
13381338
elif response.status == 500:
13391339
error_detail = (await response.json()).get(
1340-
"detail", "Unknown error"
1340+
"error", "Unknown error"
13411341
)
13421342
raise EngineCreationError(worker_id, error_detail, 500)
13431343
else:
@@ -1439,7 +1439,7 @@ def call_engine(
14391439
result = response.json()
14401440
return deserialize_value(result.get("result"))
14411441
elif response.status_code == 500:
1442-
error_detail = response.json().get("detail", "Unknown error")
1442+
error_detail = response.json().get("error", "Unknown error")
14431443
# Check if retryable
14441444
if attempt < max_retries and "timeout" in error_detail.lower():
14451445
last_error = f"Engine method timeout: {error_detail}"
@@ -1457,7 +1457,7 @@ def call_engine(
14571457
f"Worker temporarily unavailable, retry {attempt}/{max_retries}"
14581458
)
14591459
else:
1460-
error_detail = response.json().get("detail", "Unknown error")
1460+
error_detail = response.json().get("error", "Unknown error")
14611461
raise EngineCallError(
14621462
worker_id,
14631463
method,
@@ -1580,7 +1580,7 @@ async def async_call_engine(
15801580
return deserialize_value(result.get("result"))
15811581
elif response.status == 500:
15821582
error_detail = (await response.json()).get(
1583-
"detail", "Unknown error"
1583+
"error", "Unknown error"
15841584
)
15851585
if (
15861586
attempt < max_retries
@@ -1601,7 +1601,7 @@ async def async_call_engine(
16011601
)
16021602
else:
16031603
error_detail = (await response.json()).get(
1604-
"detail", "Unknown error"
1604+
"error", "Unknown error"
16051605
)
16061606
raise EngineCallError(
16071607
worker_id,

tests/test_local_scheduler.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1315,7 +1315,7 @@ def test_create_engine_import_error(self, scheduler, tmp_path):
13151315
mock_response = AsyncMock()
13161316
mock_response.status = 400
13171317
mock_response.json = AsyncMock(
1318-
return_value={"detail": "Failed to import 'nonexistent.Engine'"}
1318+
return_value={"error": "Failed to import 'nonexistent.Engine'"}
13191319
)
13201320
mock_response.__aenter__.return_value = mock_response
13211321
mock_response.__aexit__.return_value = None
@@ -1343,7 +1343,7 @@ def test_create_engine_initialization_error(self, scheduler, tmp_path):
13431343
mock_response = AsyncMock()
13441344
mock_response.status = 500
13451345
mock_response.json = AsyncMock(
1346-
return_value={"detail": "Engine initialization failed: out of memory"}
1346+
return_value={"error": "Engine initialization failed: out of memory"}
13471347
)
13481348
mock_response.__aenter__.return_value = mock_response
13491349
mock_response.__aexit__.return_value = None
@@ -1486,7 +1486,7 @@ def test_call_engine_method_error(self, scheduler, tmp_path):
14861486
scheduler._workers["test"] = [worker]
14871487

14881488
mock_response = create_mock_http_response(
1489-
status_code=400, json_data={"detail": "Method 'nonexistent' not found"}
1489+
status_code=400, json_data={"error": "Method 'nonexistent' not found"}
14901490
)
14911491

14921492
with patch.object(requests, "post", return_value=mock_response):

0 commit comments

Comments
 (0)