File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -451,8 +451,7 @@ async def offload_states_partial(self, target_dp_ranks: List[int]):
451451
452452 # Verify offloaded workers have near-zero GPU memory usage
453453 if self .rank_info .dp_rank in target_dp_ranks :
454- import torch
455- gpu_memory_gb = torch .cuda .memory_allocated () / 1024 ** 3
454+ gpu_memory_gb = current_platform .memory_allocated () / 1024 ** 3
456455 if gpu_memory_gb > 1.0 :
457456 raise RuntimeError (
458457 f"GPU memory not properly offloaded for Worker { self .rank } (DP { self .rank_info .dp_rank } ): "
@@ -501,7 +500,7 @@ async def generate(self, data: DataProto):
501500 global_step = data .meta_info .get ("global_step" , 0 )
502501 self .logger .info (f"{ self .worker_name } generate global step { global_step } " )
503502
504- data = data .to ("cuda" )
503+ data = data .to (current_platform . device_type )
505504 data .meta_info ["micro_batch_size" ] = self .worker_config .infer_batch_size
506505
507506 output = await self .strategy .generate (batch = data , generation_config = generation_config )
You can’t perform that action at this time.
0 commit comments