Skip to content
129 changes: 105 additions & 24 deletions lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from lmdeploy.pytorch import envs as _envs
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
from lmdeploy.utils import get_logger

from ..moe import DlinferMoECommType, DlinferMoeMetadata
Expand Down Expand Up @@ -156,8 +157,23 @@ def update_step_context(cls, step_context):

block_num, block_size, *_ = step_context.kv_caches[0][0].shape
is_prefill_no_cache = False
num_spec_tokens = get_step_ctx_manager().build_ctx.num_spec_tokens

if not step_context.is_decoding:
is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
is_multi_token_decoding = False
is_decoding = False
else:
is_multi_token_decoding = step_context.q_seqlens.max().item() > 1
# is_decoding: True only for regular single-token decode (original semantics)
is_decoding = not is_multi_token_decoding

# MoE EP dispatch/combine and graph capture are collective ops shared by all
# DP ranks, so they must agree on decode-vs-prefill. Use the DP-global state
# (if any rank is prefill, all ranks are prefill) for those paths; the local
# is_decoding / is_multi_token_decoding above stay rank-local for attention.
global_is_decoding = step_context.global_is_decoding()

if step_context.block_offsets.dtype != torch.int32:
step_context.block_offsets = step_context.block_offsets.to(torch.int32)
if step_context.kv_seqlens.dtype != torch.int32:
Expand All @@ -180,8 +196,6 @@ def get_cpu_seqlens(is_decoding, is_prefill_no_cache):
q_seqlens_cpu: query sequence lengths (per sequence).
kv_seqlens_cpu: kv sequence lengths (per sequence), used for
list/max seqlens calculation.
kv_seqlens_expanded: kv sequence lengths expanded per token via
repeat_interleave, used for attention metadata.
"""
if is_decoding:
q_seqlens_cpu = None
Expand Down Expand Up @@ -219,7 +233,8 @@ def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None):
return torch.arange(1, batch_size + 1, dtype=torch.int32)
elif is_prefill_no_cache:
return q_seqlens_cpu
return q_seqlens_cpu.cumsum(dim=0)
# for paged_prefill, eg. MTP, prefix caching
return q_seqlens_cpu.cumsum(dim=0).to(torch.int32)

def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list,
max_q_seq_len, max_kv_seq_len):
Expand Down Expand Up @@ -277,12 +292,29 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group):
if ep_size <= 1:
return 0, 0, 0
# get padded_tokens_current_rank
is_graph = cls.enable_graph and step_context.is_decoding
is_graph = cls.enable_graph and global_is_decoding and (is_decoding or is_multi_token_decoding)
if is_graph:
from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size
actual_tokens_current_rank = step_context.q_seqlens.shape[0]
padded_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank),
cls.max_batches)
# The cudagraph is keyed/captured on the GLOBAL padded batch
# (max over all DP ranks), so every DP rank executes the MoE with
# the same global token count. padded_tokens_current_rank must
# therefore mirror that global captured size; deriving it from this
# rank's local batch makes DP ranks disagree on the MC2
# dispatch/combine token count and corrupts the collective
# (MoeDistributeCombineV2 AICORE out-of-bounds). dp_meta.dp_batches
# holds the per-rank sequence counts; its max is the global batch
# the graph capture uses.
dp_meta = step_context.dp_meta
if dp_meta is not None and dp_meta.dp_batches:
global_batch = max(dp_meta.dp_batches)
else:
global_batch = step_context.q_seqlens.shape[0]
query_len = (num_spec_tokens + 1) if is_multi_token_decoding else 1
# actual tokens: this rank's real (non-padded) token count, used to
# build x_active_mask so MC2 ignores the graph padding region.
actual_tokens_current_rank = step_context.q_seqlens.sum().item()
padded_tokens_current_rank = min(get_ascend_compatible_size(global_batch),
cls.max_batches) * query_len
else:
actual_tokens_current_rank = step_context.q_seqlens.sum().item()
padded_tokens_current_rank = actual_tokens_current_rank
Expand All @@ -303,15 +335,15 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group):

@lru_cache
def init_mc2_token_capacity(tp_size):
max_num_tokens = min(cls.max_batches, 512)
max_num_tokens = min(cls.max_batches * (num_spec_tokens + 1), 512)
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
return num_tokens_per_tp_rank * tp_size

def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
if ep_size <= 1:
return DlinferMoECommType.ALLGATHER
mc2_token_capacity = init_mc2_token_capacity(tp_size)
is_graph = cls.enable_graph and step_context.is_decoding
is_graph = cls.enable_graph and global_is_decoding and (is_decoding or is_multi_token_decoding)
if is_graph:
max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size
if SocVersion.is_A2():
Expand All @@ -320,7 +352,7 @@ def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
else:
return DlinferMoECommType.ALLGATHER
elif SocVersion.is_A3():
if max_tokens_across_dp <= mc2_token_capacity:
if max_tokens_across_dp <= mc2_token_capacity and global_is_decoding:
return DlinferMoECommType.MC2
else:
return DlinferMoECommType.ALLTOALL
Expand All @@ -337,7 +369,7 @@ def get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank, max_tok
dtype=torch.bool,
device=torch.npu.current_device())
elif moe_comm_type == DlinferMoECommType.ALLTOALL:
pad_size = tp_size - padded_tokens_current_rank
pad_size = (-padded_tokens_current_rank) % tp_size
elif moe_comm_type == DlinferMoECommType.ALLGATHER:
pad_size = max_tokens_across_dp - padded_tokens_current_rank
else:
Expand All @@ -353,17 +385,17 @@ def get_moe_group_name(group):
group_name = backend.get_hccl_comm_name(local_rank)
return group_name

q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_prefill_no_cache)
q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu,
q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(is_decoding, is_prefill_no_cache)
q_seqlens_list, kv_seqlens_list = get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu,
kv_seqlens_cpu)
max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_list,
max_q_seq_len, max_kv_seq_len = get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list,
kv_seqlens_list)
kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding,
kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(is_decoding,
is_prefill_no_cache, q_seqlens_list,
kv_seqlens_list, max_q_seq_len,
max_kv_seq_len)
q_seqlens_cpu = update_q_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu)

q_seqlens_cpu = update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu)
if not cls.enable_graph and step_context.kv_quant_policy == 8:
record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE'
Expand All @@ -379,18 +411,57 @@ def get_moe_group_name(group):

cu_seqlens = None
has_initial_state = None

spec_conv_offsets = None
spec_state_offsets = None
cache_seqlens = None
is_gated_delta = step_context.model_config.is_gated_delta
if is_gated_delta:
q_start_loc = step_context.q_start_loc.to(dtype=step_context.q_seqlens.dtype,
device=step_context.q_seqlens.device)
cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()
if not step_context.is_decoding:
has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens)
q_seqlens = step_context.q_seqlens
kv_seqlens = step_context.kv_seqlens

q_start_loc = step_context.q_start_loc.to(dtype=q_seqlens.dtype,
device=q_seqlens.device)
cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
cache_seqlens = (kv_seqlens - q_seqlens).contiguous()


states_shapes = step_context.model_config.states_shapes
if not is_decoding and not is_multi_token_decoding and len(states_shapes) > 0:
has_initial_state = ~(q_seqlens == kv_seqlens)
# # Conv ring buffer: conv_state_len = conv_kernel_size + num_spec_tokens.
conv_state_len = states_shapes[0][0][0]
conv_kernel_size = conv_state_len - num_spec_tokens

if num_spec_tokens > 0:
state_slots = 1 + num_spec_tokens
spec_state_offsets = (
torch.remainder(cache_seqlens, state_slots),
torch.remainder(kv_seqlens, state_slots),
)

range_idx = torch.arange(
-conv_kernel_size,
0,
device=cache_seqlens.device,
dtype=torch.int32,
)
# Read the (conv_kernel_size - 1) tokens preceding the current write
# window from the circular buffer.
read_conv_offsets = torch.remainder(
cache_seqlens[:, None] + range_idx[1:][None],
conv_state_len,
).to(torch.int64)
# Write the last conv_kernel_size tokens of this prefill batch into
# circular-buffer slots so the next decode read aligns naturally.
write_conv_offsets = torch.remainder(
kv_seqlens[:, None] + range_idx[None],
conv_state_len,
).to(torch.int64)
spec_conv_offsets = (read_conv_offsets, write_conv_offsets)

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
is_decoding,
step_context.block_offsets,
# cu_seqlens is only used in GDN and is passed down via q_start_loc.
# Otherwise, q_start_loc is None.
Expand All @@ -406,6 +477,10 @@ def get_moe_group_name(group):
quant_policy=step_context.kv_quant_policy,
quant_meta=AscendKVQuantMeta.quant_meta,
has_initial_state=has_initial_state,
is_multi_token_decoding=is_multi_token_decoding,
spec_conv_offsets=spec_conv_offsets,
spec_state_offsets=spec_state_offsets,
cache_seqlens=cache_seqlens,
)
step_context.attn_metadata = attn_metadata

Expand Down Expand Up @@ -462,6 +537,12 @@ def init():
logger.warning(f'Error during Ascend initialization: {str(e)}. '
'Please check your Ascend environment configuration.')

try:
import dlinfer.framework.lmdeploy_ext.device # noqa: F401 — triggers vendor_device_init()
except ImportError:
logger.warning('dlinfer framework extensions not found. '
'Ascend-specific model patches will not be applied.')

try:
from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton
init_device_properties_triton()
Expand Down
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class DlinferAttentionMetadata(AttentionMetadata):
quant_meta: dict = None
cu_seq_lens_kv: Tensor | None = None
has_initial_state: Tensor | None = None
is_multi_token_decoding: bool = False
spec_conv_offsets: Sequence[Tensor] = tuple()
spec_state_offsets: Sequence[Tensor] = tuple()
cache_seqlens: Tensor | None = None


class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ def from_config(
block_size=target_cache_cfg.block_size,
model_format=model_format,
hf_overrides=hf_overrides,
device_type=target_cache_cfg.device_type,
)
cache_config = None
# include medusa
Expand Down
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ def warmup(self):
if dp > 1:
num_tokens = inputs.input_ids.numel()
inputs.build_dp_meta([num_tokens] * world_size)
inputs.dp_meta.dp_is_decoding = False
logger.debug('Warmup prefill start.')
self._forward_impl(inputs)
torch.cuda.synchronize()
Expand All @@ -423,6 +424,7 @@ def warmup(self):
if dp > 1:
num_tokens = inputs.input_ids.numel()
inputs.build_dp_meta([num_tokens] * world_size)
inputs.dp_meta.dp_is_decoding = True
logger.debug(f'Warmup decoding num_tokens={num_tokens} start.')
self._forward_impl(inputs)
torch.cuda.synchronize()
Expand Down
Loading