Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions slime/slime/backends/megatron_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import torch

from slime.utils.common import is_npu
if is_npu():
import mindspeed.megatron_adaptor

try:
import deep_ep
from torch_memory_saver import torch_memory_saver
Expand Down Expand Up @@ -39,4 +43,22 @@ def _patched_forward(self, *args, packed_seq_params=None, **kwargs):
except ImportError:
pass

def _drop_loss_mask_kwarg(model_cls):
    """Patch ``model_cls.forward`` so a ``loss_mask`` keyword is accepted and discarded.

    The replacement ``forward`` takes ``loss_mask`` as a keyword-only argument
    (default ``None``) and calls the original ``forward`` with every other
    positional and keyword argument unchanged, so the original implementation
    never receives ``loss_mask``.

    Args:
        model_cls: Class whose ``forward`` attribute is replaced in place.
    """
    # Bind the original in a closure so each patched class keeps its own
    # reference without needing a separately numbered module-level global.
    original_forward = model_cls.forward

    def _forward_without_loss_mask(self, *args, loss_mask=None, **kwargs):
        return original_forward(self, *args, **kwargs)

    model_cls.forward = _forward_without_loss_mask


# Each import is attempted independently; a missing package is skipped silently.
try:
    from mbridge.models.qwen3_vl.model import Qwen3VLModel

    _drop_loss_mask_kwarg(Qwen3VLModel)
except ImportError:
    pass

try:
    from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel

    _drop_loss_mask_kwarg(Qwen3VLModel)
except ImportError:
    pass
logging.getLogger("megatron").setLevel(logging.WARNING)
18 changes: 16 additions & 2 deletions slime/slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
import ray
import torch
import torch.distributed as dist

from slime.utils.common import is_npu
if is_npu():
import mindspeed.megatron_adaptor
from mindspeed.megatron_adaptor import repatch
from megatron.core import mpu
from ray.actor import ActorHandle
from torch_memory_saver import torch_memory_saver
Expand Down Expand Up @@ -150,7 +155,10 @@ def _offload_rollout_data_to_cpu(rollout_data: RolloutBatch) -> None:
rollout_data[key] = [v.to("cpu", non_blocking=True) for v in vals]
moved_any = True
if moved_any:
torch.cuda.synchronize()
if not is_npu():
torch.cuda.synchronize()
else:
torch.npu.synchronize()


class MegatronTrainRayActor(TrainRayActor):
Expand Down Expand Up @@ -269,6 +277,8 @@ def init(
super().init(args, role, with_ref)

init(args)
if is_npu():
repatch(args)

if is_megatron_main_rank():
init_tracking(args, primary=False)
Expand Down Expand Up @@ -1054,8 +1064,12 @@ def connect_actor_critic(

group_name = "actor_critic"
world_size = 2
if is_npu():
backend = "hccl"
else:
backend = "nccl"
self._actor_critic_groups = init_process_group(
backend="nccl",
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=0 if self.role == "actor" else 1,
Expand Down
10 changes: 7 additions & 3 deletions slime/slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from slime.utils.metric_utils import compute_pass_rate, compute_rollout_step
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
from slime.utils.types import RolloutBatch
from slime.utils.common import is_npu

from ...utils import logging_utils
from .cp_utils import get_sum_of_sample_mean, slice_with_cp
Expand All @@ -31,9 +32,12 @@ def _to_cuda(val: object) -> object:
if val is None:
return None
if isinstance(val, torch.Tensor):
if val.is_cuda:
return val
return val.to(device=torch.cuda.current_device(), non_blocking=True)
if is_npu():
return val.to(device=torch.npu.current_device(), non_blocking=True)
else:
if val.is_cuda:
return val
return val.to(device=torch.cuda.current_device(), non_blocking=True)
if isinstance(val, list):
return [_to_cuda(v) for v in val]
if isinstance(val, tuple):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
f'-gencode=arch=compute_{arch.replace(".", "")},code=sm_{arch.replace(".", "")}'
for arch in arch_list
]
+ ["-gencode=arch=compute_90a,code=sm_90a"],
+ (["-gencode=arch=compute_90a,code=sm_90a"] if not hasattr(torch,'npu') else []),
},
)
],
Expand Down
10 changes: 7 additions & 3 deletions slime/slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from slime.utils.distributed_utils import distributed_masked_whiten
from slime.utils.misc import load_function
from slime.utils.common import is_npu

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -530,8 +531,11 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
# loss_masks live on CPU (lazy-loading optimisation). We need copies on the
# current accelerator (GPU or NPU) for the advantage / KL / normalisation math
# below. The original CPU tensors in rollout_data["loss_masks"] are NOT modified.
if loss_masks and isinstance(loss_masks[0], torch.Tensor) and not loss_masks[0].is_cuda:
_gpu = torch.cuda.current_device()
if loss_masks and isinstance(loss_masks[0], torch.Tensor) and loss_masks[0].is_cpu:
if is_npu():
_gpu = torch.npu.current_device()
else:
_gpu = torch.cuda.current_device()
loss_masks = [m.to(device=_gpu) for m in loss_masks]

if args.kl_coef == 0 or not log_probs:
Expand Down Expand Up @@ -1198,7 +1202,7 @@ def loss_function(

return (
loss,
(num_tokens if args.calculate_per_token_loss else torch.tensor(1, device=logits.device)),
torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device),
{
"keys": list(log.keys()),
"values": torch.tensor(
Expand Down
6 changes: 5 additions & 1 deletion slime/slime/backends/megatron_utils/model_provider.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Adapt from https://github.com/NVIDIA/Megatron-LM/blob/b1efb3c7126ef7615e8c333432d76e08038e17ff/pretrain_gpt.py
import argparse
import inspect
import re
from contextlib import nullcontext
from typing import Literal

Expand Down Expand Up @@ -114,6 +113,11 @@ def wrapped_model_provider(
provider.recompute_method = args.recompute_method
provider.recompute_num_layers = args.recompute_num_layers

for key, value in vars(args).items():
if hasattr(provider, key):
continue
setattr(provider, key, value)

# CLI flags that materially affect train numerics/quality and per-step
# speed but are NOT derivable from the HF config. Without these, bridge
# mode silently keeps HF-config defaults (e.g. attention_dropout=0.1
Expand Down
5 changes: 5 additions & 0 deletions slime/slime/backends/megatron_utils/update_weight/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from slime.backends.megatron_utils.misc_utils import strip_param_name_prefix
from slime.utils.types import ParamInfo
from slime.utils.common import is_npu


_DISABLE_LINEAR_FC1_RECHUNK = os.getenv("SLIME_QWEN35_DISABLE_LINEAR_FC1_RECHUNK", "0") == "1"

Expand Down Expand Up @@ -41,6 +43,9 @@ def _merge_tp_partitions(
if "linear_fc1.weight" in name and not _DISABLE_LINEAR_FC1_RECHUNK:
param_partitions = [p.chunk(2, dim=0) for p in param_partitions]
param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions]
# TODO: Temporary workaround for NPU to set partition_dim to 0
if is_npu():
partition_dim = 0
# This is a bug in Megatron's grouped MoE.
if "linear_fc2.weight" in name and partition_dim == 0:
partition_dim = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tqdm import tqdm

from slime.utils.distributed_utils import get_gloo_group, init_process_group
from slime.utils.common import is_npu

from ..megatron_to_hf import convert_to_hf
from .common import all_gather_param, named_params_and_buffers
Expand Down Expand Up @@ -253,19 +254,20 @@ def connect_rollout_engines_from_distributed(
master_port = sock.getsockname()[1]
world_size = len(rollout_engines) * args.rollout_num_gpus_per_engine + 1

backend = "hccl" if is_npu() else "nccl"
refs = [
engine.init_weights_update_group.remote(
master_address,
master_port,
i * args.rollout_num_gpus_per_engine + 1,
world_size,
group_name,
backend="nccl",
backend=backend,
)
for i, engine in enumerate(rollout_engines)
]
model_update_groups = init_process_group(
backend="nccl",
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=0,
Expand Down
6 changes: 5 additions & 1 deletion slime/slime/backends/sglang_utils/sglang_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .qwen3_5 import is_qwen35_model_path, maybe_prepare_qwen35_text_model, patch_sglang_qwen35
from slime.ray.ray_actor import RayActor
from slime.utils.http_utils import get_host_info
from slime.utils.common import is_npu

logger = logging.getLogger(__name__)

Expand All @@ -34,7 +35,10 @@ def get_base_gpu_id(args, rank):


def _to_local_gpu_id(physical_gpu_id: int) -> int:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if is_npu():
cvd = os.environ.get("ASCEND_RT_VISIBLE_DEVICES")
else:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if not cvd:
return physical_gpu_id # no remapping
# The visible-devices variable (CUDA_VISIBLE_DEVICES, or ASCEND_RT_VISIBLE_DEVICES on NPU) can be like "4,5,6,7"
Expand Down
7 changes: 4 additions & 3 deletions slime/slime/ray/actor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from slime.ray.utils import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST
from slime.utils.common import is_npu


class RayTrainGroup:
Expand Down Expand Up @@ -87,19 +88,19 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor):

actor_impl = FSDPTrainRayActor

TrainRayActor = ray.remote(num_gpus=1, runtime_env={"env_vars": env_vars})(actor_impl)

TrainRayActor = ray.remote(runtime_env={"env_vars": env_vars})(actor_impl)
device_name = "NPU" if is_npu() else "GPU"
# Create worker actors
self._actor_handlers = []
master_addr, master_port = None, None
for rank in range(world_size):
actor = TrainRayActor.options(
num_cpus=num_gpus_per_actor,
num_gpus=num_gpus_per_actor,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=reordered_bundle_indices[rank],
),
resources={device_name: num_gpus_per_actor}
).remote(world_size, rank, master_addr, master_port)
if rank == 0:
master_addr, master_port = ray.get(actor.get_master_addr_and_port.remote())
Expand Down
18 changes: 13 additions & 5 deletions slime/slime/ray/placement_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from slime.utils.common import is_npu

from .actor_group import RayTrainGroup
from .rollout import RolloutManager

logger = logging.getLogger(__name__)


@ray.remote(num_gpus=1)
@ray.remote
class InfoActor:
def get_ip_and_gpu_id(self):
return ray.util.get_node_ip_address(), ray.get_gpu_ids()[0]
if is_npu():
return ray.util.get_node_ip_address(), ray.get_runtime_context().get_accelerator_ids()["NPU"][0]
else:
return ray.util.get_node_ip_address(), ray.get_gpu_ids()[0]


def sort_key(x):
Expand All @@ -35,12 +39,13 @@ def sort_key(x):
# representation that allows for sorting.
node_ip_parts = [ord(c) for c in node_identifier]

return (node_ip_parts, gpu_id)
return (node_ip_parts, int(gpu_id))


def _create_placement_group(num_gpus):
"""Create a placement group with the specified number of GPUs."""
bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)]
device_name = "NPU" if is_npu() else "GPU"
bundles = [{device_name: 1, "CPU": 1} for _ in range(num_gpus)]
pg = placement_group(bundles, strategy="PACK")
num_bundles = len(bundles)

Expand All @@ -53,7 +58,8 @@ def _create_placement_group(num_gpus):
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=i,
)
),
resources={device_name: 1}
).remote()
)
gpu_ids = ray.get([actor.get_ip_and_gpu_id.remote() for actor in info_actors])
Expand Down Expand Up @@ -201,9 +207,11 @@ def create_training_models(args, pgs, rollout_manager):


def create_rollout_manager(args, pg, prm_pg=None):
device_name = "NPU" if is_npu() else "GPU"
rollout_manager = RolloutManager.options(
num_cpus=1,
num_gpus=0,
resources={device_name: 0}
).remote(args, pg, prm_pg)

# calculate num_rollout from num_epoch
Expand Down
11 changes: 8 additions & 3 deletions slime/slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from slime.utils.misc import Box, group_by, load_function
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
from slime.utils.types import Sample
from slime.utils.common import is_npu

from ..utils.metric_utils import has_repetition
from .utils import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, Lock
Expand Down Expand Up @@ -89,7 +90,8 @@ def __init__(self, args, pg, prm_pg=None):
self.all_prm_engines = []
self.num_new_prm_engines = 0
self.nodes_per_engine = max(1, args.rollout_num_gpus_per_engine // args.num_gpus_per_node)
self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote()
device_name = "NPU" if is_npu() else "GPU"
self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0, resources={device_name: 0}).remote()
self.rollout_id = -1

self._metric_checker = MetricChecker.maybe_create(args)
Expand Down Expand Up @@ -830,6 +832,7 @@ def init_rollout_engines(args, pg, all_rollout_engines):
RolloutRayActor = ray.remote(SGLangEngine)

rollout_engines = []
device_name = "NPU" if is_npu() else "GPU"
for i in range(num_engines):
if all_rollout_engines[i] is not None:
continue
Expand All @@ -849,6 +852,7 @@ def init_rollout_engines(args, pg, all_rollout_engines):
env_vars = {name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST} | {
key: os.environ.get(key, default_val)
for key, default_val in {
"SGL_JIT_DEEPGEMM_PRECOMPILE": "false",
"SGLANG_JIT_DEEPGEMM_PRECOMPILE": "false",
"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK": "true",
"SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK": "true",
Expand All @@ -868,11 +872,11 @@ def init_rollout_engines(args, pg, all_rollout_engines):

rollout_engine = RolloutRayActor.options(
num_cpus=num_cpus,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
runtime_env={
"env_vars": env_vars,
},
resources={device_name: num_gpus}
).remote(args, rank=i, worker_type=worker_type, base_gpu_id=base_gpu_id)

rollout_engines.append((i, rollout_engine))
Expand Down Expand Up @@ -937,11 +941,12 @@ def init_prm_engines(args, pg, all_prm_engines):
}.items()
}

device_name = "NPU" if is_npu() else "GPU"
prm_engine = RolloutRayActor.options(
num_cpus=num_cpus,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
runtime_env={"env_vars": env_vars},
resources={device_name: num_gpus}
).remote(args, rank=i, worker_type="regular", base_gpu_id=base_gpu_id, engine_role="prm")

prm_engines.append((i, prm_engine))
Expand Down
Loading