Skip to content

Commit 1e7f794

Browse files
committed
fix: add base worker npu support
1 parent df7d186 commit 1e7f794

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

roll/pipeline/base_worker.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -451,8 +451,7 @@ async def offload_states_partial(self, target_dp_ranks: List[int]):
451 451

452 452
# Verify offloaded workers have near-zero GPU memory usage
453 453
if self.rank_info.dp_rank in target_dp_ranks:
454-
import torch
455-
gpu_memory_gb = torch.cuda.memory_allocated() / 1024**3
454+
gpu_memory_gb = current_platform.memory_allocated() / 1024**3
456 455
if gpu_memory_gb > 1.0:
457 456
raise RuntimeError(
458 457
f"GPU memory not properly offloaded for Worker {self.rank} (DP {self.rank_info.dp_rank}): "
@@ -501,7 +500,7 @@ async def generate(self, data: DataProto):
501 500
global_step = data.meta_info.get("global_step", 0)
502 501
self.logger.info(f"{self.worker_name} generate global step {global_step}")
503 502

504-
data = data.to("cuda")
503+
data = data.to(current_platform.device_type)
505 504
data.meta_info["micro_batch_size"] = self.worker_config.infer_batch_size
506 505

507 506
output = await self.strategy.generate(batch=data, generation_config=generation_config)

0 commit comments

Comments
 (0)