
Commit d8d8e4d

Fix ART ty failures in local backend and MoE conversion
1 parent: 86ae933

2 files changed: 31 additions & 4 deletions

src/art/local/backend.py (26 additions & 1 deletion)

@@ -162,6 +162,9 @@ def _allocated_gpu_count(self, model: Model) -> int:
     def __enter__(self) -> Self:
         return self
 
+    async def __aenter__(self) -> Self:
+        return self
+
     def __exit__(
         self,
         exc_type: type[BaseException] | None,
@@ -170,11 +173,19 @@ def __exit__(
     ) -> None:
         self._close()
 
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        tb: TracebackType | None,
+    ) -> None:
+        await self.close()
+
     async def close(self) -> None:
         """
         If running vLLM in a separate process, this will kill that process and close the communication threads.
         """
-        self._close()
+        await self._aclose()
 
     def _close(self) -> None:
         for _, service in self._services.items():
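
Pairing the existing `__enter__`/`__exit__` with `__aenter__`/`__aexit__` lets the same backend object be used in both `with` and `async with` blocks, with the async exit path awaiting coroutine-based cleanup instead of calling the sync one. A minimal self-contained sketch of the dual protocol, assuming Python 3.11+ for `typing.Self`; the `Backend` class and its print-based cleanup are illustrative stand-ins, not ART's real class:

```python
import asyncio
from types import TracebackType
from typing import Self


class Backend:
    """Illustrative stand-in that supports both sync and async `with`."""

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        self._close()  # sync path: plain function call

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        await self.close()  # async path: cleanup can itself await

    def _close(self) -> None:
        print("sync cleanup")

    async def close(self) -> None:
        print("async cleanup")


with Backend():  # exits through __exit__
    ...


async def main() -> None:
    async with Backend():  # exits through __aexit__
        ...


asyncio.run(main())
```
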
@@ -183,6 +194,17 @@ def _close(self) -> None:
                 close()
             close_proxy(service)
 
+    async def _aclose(self) -> None:
+        for _, service in self._services.items():
+            aclose = getattr(service, "aclose", None)
+            if aclose is not None:
+                await aclose()
+            else:
+                close = getattr(service, "close", None)
+                if close is not None:
+                    close()
+            close_proxy(service)
+
     async def register(
         self,
         model: Model,
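
The new `_aclose` prefers a service's async `aclose` when one exists and falls back to the sync `close` otherwise, so mixed service types can share one shutdown loop. A runnable sketch of that duck-typed pattern with made-up service classes (only the `getattr` names mirror the diff):

```python
import asyncio


class SyncService:
    def close(self) -> None:
        print("closed synchronously")


class AsyncService:
    async def aclose(self) -> None:
        print("closed asynchronously")


async def shutdown_all(services: list[object]) -> None:
    for service in services:
        # Prefer the awaitable aclose() if the service provides one...
        aclose = getattr(service, "aclose", None)
        if aclose is not None:
            await aclose()
            continue
        # ...otherwise fall back to a plain close(), if present.
        close = getattr(service, "close", None)
        if close is not None:
            close()


asyncio.run(shutdown_all([SyncService(), AsyncService()]))
```
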
@@ -505,6 +527,7 @@ async def train( # type: ignore[override]
         *,
         # Core training parameters
         learning_rate: float = 5e-6,
+        loss_fn: Literal["cispo", "ppo"] | None = None,
         # KL-penalized advantage adjustment
         kl_penalty_coef: float = 0.0,
         kl_penalty_reference_step: int | None = None,
@@ -600,6 +623,8 @@ async def train( # type: ignore[override]
         # await model.log(metrics=result.metrics, step=result.step)
         """
         groups_list = list(trajectory_groups)
+        if loss_fn is not None:
+            ppo = loss_fn == "ppo"
 
         resolved_kl_ref_adapter_path = kl_ref_adapter_path
         if (
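
The `loss_fn` option is declared as `Literal["cispo", "ppo"] | None` and, when supplied, is narrowed to the boolean `ppo` flag that the rest of `train` presumably consumes. A hedged standalone sketch of that mapping (the function name and the `False` default are illustrative, not taken from the diff):

```python
from typing import Literal


def resolve_ppo_flag(loss_fn: Literal["cispo", "ppo"] | None = None) -> bool:
    # Mirror the diff: only override the flag when loss_fn is given.
    ppo = False  # assumed default when loss_fn is omitted
    if loss_fn is not None:
        ppo = loss_fn == "ppo"
    return ppo


assert resolve_ppo_flag("ppo") is True
assert resolve_ppo_flag("cispo") is False
assert resolve_ppo_flag() is False
```
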

src/art/utils/convert_moe_lora.py (5 additions & 3 deletions)

@@ -12,13 +12,15 @@
 ...
 """
 
+import importlib
 import json
 import os
 import re
 
-import safetensors.torch
 import torch
 
+safetensors_torch = importlib.import_module("safetensors.torch")
+
 
 def _has_fused_moe_lora(tensors: dict[str, torch.Tensor]) -> bool:
     """Check if the adapter contains fused MoE LoRA tensors."""
@@ -152,7 +154,7 @@ def convert_checkpoint_if_needed(checkpoint_dir: str) -> None:
     if not os.path.exists(adapter_path) or not os.path.exists(config_path):
         return
 
-    tensors = safetensors.torch.load_file(adapter_path)
+    tensors = safetensors_torch.load_file(adapter_path)
     if not _has_fused_moe_lora(tensors):
         return
 
@@ -168,7 +170,7 @@ def convert_checkpoint_if_needed(checkpoint_dir: str) -> None:
     )
 
     # Overwrite the adapter with the converted tensors
-    safetensors.torch.save_file(new_tensors, adapter_path)
+    safetensors_torch.save_file(new_tensors, adapter_path)
 
     # Update adapter_config.json target_modules
     adapter_config["target_modules"] = [
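
For completeness, a usage sketch of the conversion entry point. Assuming the repository's `src` layout makes `art.utils.convert_moe_lora` importable, the early-return guards above mean the call is a no-op unless the directory actually contains the adapter weights and `adapter_config.json` (the checkpoint path below is hypothetical):

```python
from art.utils.convert_moe_lora import convert_checkpoint_if_needed

# Rewrites fused MoE LoRA tensors and target_modules in place if present;
# returns immediately when the adapter or config file is missing.
convert_checkpoint_if_needed("/tmp/checkpoints/0042")
```
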
