Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@ dist
.vscode
tmp/
requirements-musa.txt
logs/
logs/

/benchmark/
artifacts/
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashi
return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list)


def get_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend:
def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend:
args = get_env_start_args()
llm_dtype = args.llm_kv_type
backend_str = args.llm_decode_att_backend[index]
Expand All @@ -120,7 +120,7 @@ def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "fl
return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list)


def get_mla_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend:
def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend:
args = get_env_start_args()
llm_dtype = args.llm_kv_type
backend_str = args.llm_decode_att_backend[index]
Expand Down
35 changes: 34 additions & 1 deletion lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from typing import Optional, TYPE_CHECKING
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.sgl_utils import flash_attn_with_kvcache, get_scheduler_metadata
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor


# Decode-path settings for the FA3 attention backend. Each constant is forwarded
# unchanged to both get_scheduler_metadata (when scheduler metadata is precomputed)
# and flash_attn_with_kvcache, so the two calls receive the same values.
# NOTE(review): presumably the kernel call must be given the same settings the
# metadata was planned with -- confirm before changing one call site only.
_DECODE_MAX_NUM_SPLITS = 32  # passed as the `num_splits` argument
_DECODE_PACK_GQA = True  # passed as the `pack_gqa` argument


class Fa3AttBackend(BaseAttBackend):
def __init__(self, model):
super().__init__(model=model)
Expand Down Expand Up @@ -119,6 +123,7 @@ class Fa3DecodeAttState(BaseDecodeAttState):
cu_seqlens_k: torch.Tensor = None
page_table: torch.Tensor = None
b_att_seq_len: torch.Tensor = None
scheduler_metadata: torch.Tensor = None
# 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。
decode_max_q_seq_len: int = None

Expand Down Expand Up @@ -179,8 +184,33 @@ def init_state(self):
)
self.b_att_seq_len = self.infer_state.b_seq_len
self.decode_max_q_seq_len = 1
self._init_scheduler_metadata()
return

def _init_scheduler_metadata(self):
    """Precompute and cache the FA3 scheduler metadata for this decode state.

    The result of ``get_scheduler_metadata`` is stored on
    ``self.scheduler_metadata``. When that helper is not available
    (``get_scheduler_metadata is None``) the attribute is set to ``None``
    instead, and nothing else is computed.
    """
    if get_scheduler_metadata is None:
        self.scheduler_metadata = None
        return

    net = self.backend.model
    # Keyword arguments are assembled first so the call itself stays short;
    # the evaluation order of the individual values is unchanged.
    plan_kwargs = dict(
        batch_size=self.b_att_seq_len.shape[0],
        max_seqlen_q=self.decode_max_q_seq_len,
        max_seqlen_k=self.infer_state.max_kv_seq_len,
        # Query heads handled by this tensor-parallel rank.
        num_heads=net.config["num_attention_heads"] // net.tp_world_size_,
        num_heads_k=net.tp_k_head_num_,
        headdim=net.head_dim_,
        cache_seqlens=self.b_att_seq_len,
        qkv_dtype=net.data_type,
        # The value head dim is taken to be the same as the q/k head dim.
        headdim_v=net.head_dim_,
        cu_seqlens_q=self.cu_seqlens_q,
        cu_seqlens_k_new=self.cu_seqlens_k,
        page_size=1,
        causal=True,
        num_splits=_DECODE_MAX_NUM_SPLITS,
        pack_gqa=_DECODE_PACK_GQA,
    )
    self.scheduler_metadata = get_scheduler_metadata(**plan_kwargs)

def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"):
super().copy_for_decode_cuda_graph(new_state)

Expand Down Expand Up @@ -235,6 +265,9 @@ def _normal_decode_att(
causal=True,
window_size=window_size,
softcap=0.0,
scheduler_metadata=self.scheduler_metadata,
num_splits=_DECODE_MAX_NUM_SPLITS,
pack_gqa=_DECODE_PACK_GQA,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=False,
Expand Down
28 changes: 15 additions & 13 deletions lightllm/common/basemodel/attention/fa3/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ def init_state(self):
torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
)
# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)

self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

def prefill_att(
self,
Expand Down Expand Up @@ -116,20 +119,19 @@ def init_state(self):
super().init_state()
self.backend: Fp8Fa3AttBackend = self.backend

args_mtp_step = get_env_start_args().mtp_step
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0

device = self.infer_state.input_ids.device
batch_size = att_batch_size
batch_size = self.b_att_seq_len.shape[0]
mem_manager = self.backend.model.mem_manager

offline_scales: torch.Tensor = mem_manager.scales
head_num = mem_manager.head_num

# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

return

Expand Down Expand Up @@ -180,11 +182,11 @@ def _fp8_decode_att(
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,
cache_seqlens=self.infer_state.b_seq_len,
cache_seqlens=self.b_att_seq_len,
cu_seqlens_q=self.cu_seqlens_q,
cu_seqlens_k_new=self.cu_seqlens_k,
max_seqlen_q=self.decode_max_q_seq_len,
causal=False,
causal=True,
window_size=(-1, -1),
softcap=0.0,
q_descale=q_scale.view(self.infer_state.batch_size, k_head_num),
Expand Down
106 changes: 100 additions & 6 deletions lightllm/common/basemodel/attention/flashinfer/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,82 @@
from .env_utils import set_flashinfer_envs


def _fast_plan_tensor_core_decode(
decode_wrapper,
indptr,
indices,
last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
indptr_host,
kv_lens_arr_host,
max_kv_len,
):
batch_size = len(last_page_len)
if batch_size != decode_wrapper._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime batch size {} "
"mismatches the batch size set during initialization {}".format(
batch_size, decode_wrapper._fixed_batch_size
)
)
if len(indices) > len(decode_wrapper._paged_kv_indices_buf):
raise ValueError("The size of indices should be less than or equal to the allocated buffer")

qo_indptr_host = getattr(decode_wrapper, "_lightllm_qo_indptr_host", None)
if qo_indptr_host is None or len(qo_indptr_host) != batch_size + 1:
from flashinfer.decode import _get_range_buf

qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
decode_wrapper._lightllm_qo_indptr_host = qo_indptr_host

if indptr_host is None:
indptr_host = indptr.to("cpu")
if kv_lens_arr_host is None:
from flashinfer.decode import get_seq_lens

last_page_len_host = last_page_len.to("cpu")
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
if max_kv_len is None:
max_kv_len = max(kv_lens_arr_host).item()

decode_wrapper._batch_size = batch_size
decode_wrapper._num_qo_heads = num_qo_heads
decode_wrapper._num_kv_heads = num_kv_heads
decode_wrapper._block_tables = None
decode_wrapper._max_kv_len = max_kv_len

args = [
decode_wrapper._float_workspace_buffer,
decode_wrapper._int_workspace_buffer,
decode_wrapper._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
decode_wrapper.is_cuda_graph_enabled,
head_dim,
head_dim,
False,
-1,
]
if decode_wrapper._backend == "fa2":
args.extend([-1, False, 0])
decode_wrapper._plan_info = decode_wrapper._cached_module.plan(*args)
decode_wrapper._pos_encoding_mode = "NONE"
decode_wrapper._window_left = -1
decode_wrapper._logits_soft_cap = 0.0
decode_wrapper._sm_scale = None
decode_wrapper._rope_scale = None
decode_wrapper._rope_theta = None


class FlashInferAttBackend(BaseAttBackend):
def __init__(self, model):
set_flashinfer_envs()
Expand All @@ -25,6 +101,10 @@ def __init__(self, model):
model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id()
),
]
self.kv_starts_host_buffer = [
torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"),
torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"),
]
self.q_data_type = model.data_type
self.kv_data_type = model.data_type

Expand Down Expand Up @@ -124,11 +204,11 @@ class FlashInferDecodeAttState(BaseDecodeAttState):
kv_last_page_len_buffer: torch.Tensor = None
kv_indices: torch.Tensor = None
kv_starts: torch.Tensor = None
kv_starts_host: torch.Tensor = None
kv_seq_lens_host: torch.Tensor = None
decode_wrapper: object = None

def init_state(self):
import flashinfer

self.backend: FlashInferAttBackend = self.backend
device = self.infer_state.input_ids.device
model = self.backend.model
Expand All @@ -154,8 +234,21 @@ def init_state(self):
self.infer_state.b_kv_start_loc,
self.infer_state.max_kv_seq_len,
self.kv_indices,
zero_output=False,
)
self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int()
if self.infer_state.b_seq_len_cpu is not None:
self.kv_seq_lens_host = self.infer_state.b_seq_len_cpu
self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][
: self.infer_state.batch_size + 1
]
self.kv_starts_host[0] = 0
torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:])
if self.infer_state.skip_decode_att_wrapper_init:
return

import flashinfer

assert self.decode_wrapper is None
self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
self.backend.workspace_buffer,
Expand All @@ -182,17 +275,18 @@ def init_state(self):

def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"):
super().copy_for_decode_cuda_graph(new_state)
self.decode_wrapper.plan(
_fast_plan_tensor_core_decode(
self.decode_wrapper,
new_state.kv_starts,
new_state.kv_indices,
new_state.kv_last_page_len_buffer,
new_state.backend.tp_q_head_num,
new_state.backend.tp_kv_head_num,
new_state.backend.head_dim,
1,
q_data_type=new_state.backend.q_data_type,
kv_data_type=new_state.backend.kv_data_type,
non_blocking=True,
new_state.kv_starts_host,
new_state.kv_seq_lens_host,
new_state.infer_state.max_kv_seq_len,
)

def decode_att(
Expand Down
8 changes: 6 additions & 2 deletions lightllm/common/basemodel/attention/flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ class MlaFlashInferDecodeAttState(BaseDecodeAttState):
decode_wrapper: object = None

def init_state(self):
import flashinfer

self.backend: MlaFlashInferAttBackend = self.backend
model = self.backend.model
device = self.infer_state.input_ids.device
Expand All @@ -144,7 +142,13 @@ def init_state(self):
self.infer_state.b_kv_start_loc,
self.infer_state.max_kv_seq_len,
self.kv_indices,
zero_output=False,
)
if self.infer_state.skip_decode_att_wrapper_init:
return

import flashinfer

assert self.decode_wrapper is None

self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
Expand Down
Loading