Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashi
return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list)


def get_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend:
def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend:
args = get_env_start_args()
llm_dtype = args.llm_kv_type
backend_str = args.llm_decode_att_backend[index]
Expand All @@ -120,7 +120,7 @@ def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "fl
return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list)


def get_mla_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend:
def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend:
args = get_env_start_args()
llm_dtype = args.llm_kv_type
backend_str = args.llm_decode_att_backend[index]
Expand Down
35 changes: 34 additions & 1 deletion lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from typing import Optional, TYPE_CHECKING
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.sgl_utils import flash_attn_with_kvcache, get_scheduler_metadata
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor


_DECODE_MAX_NUM_SPLITS = 32
_DECODE_PACK_GQA = True


class Fa3AttBackend(BaseAttBackend):
def __init__(self, model):
super().__init__(model=model)
Expand Down Expand Up @@ -119,6 +123,7 @@ class Fa3DecodeAttState(BaseDecodeAttState):
cu_seqlens_k: torch.Tensor = None
page_table: torch.Tensor = None
b_att_seq_len: torch.Tensor = None
scheduler_metadata: torch.Tensor = None
# 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。
decode_max_q_seq_len: int = None

Expand Down Expand Up @@ -179,8 +184,33 @@ def init_state(self):
)
self.b_att_seq_len = self.infer_state.b_seq_len
self.decode_max_q_seq_len = 1
self._init_scheduler_metadata()
return

def _init_scheduler_metadata(self):
if get_scheduler_metadata is None:
self.scheduler_metadata = None
return

model = self.backend.model
self.scheduler_metadata = get_scheduler_metadata(
batch_size=self.b_att_seq_len.shape[0],
max_seqlen_q=self.decode_max_q_seq_len,
max_seqlen_k=self.infer_state.max_kv_seq_len,
num_heads=model.config["num_attention_heads"] // model.tp_world_size_,
num_heads_k=model.tp_k_head_num_,
headdim=model.head_dim_,
cache_seqlens=self.b_att_seq_len,
qkv_dtype=model.data_type,
headdim_v=model.head_dim_,
cu_seqlens_q=self.cu_seqlens_q,
cu_seqlens_k_new=self.cu_seqlens_k,
page_size=1,
causal=True,
num_splits=_DECODE_MAX_NUM_SPLITS,
pack_gqa=_DECODE_PACK_GQA,
)

def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"):
super().copy_for_decode_cuda_graph(new_state)

Expand Down Expand Up @@ -235,6 +265,9 @@ def _normal_decode_att(
causal=True,
window_size=window_size,
softcap=0.0,
scheduler_metadata=self.scheduler_metadata,
num_splits=_DECODE_MAX_NUM_SPLITS,
pack_gqa=_DECODE_PACK_GQA,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=False,
Expand Down
106 changes: 100 additions & 6 deletions lightllm/common/basemodel/attention/flashinfer/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,82 @@
from .env_utils import set_flashinfer_envs


def _fast_plan_tensor_core_decode(
decode_wrapper,
indptr,
indices,
last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
indptr_host,
kv_lens_arr_host,
max_kv_len,
):
batch_size = len(last_page_len)
if batch_size != decode_wrapper._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime batch size {} "
"mismatches the batch size set during initialization {}".format(
batch_size, decode_wrapper._fixed_batch_size
)
)
if len(indices) > len(decode_wrapper._paged_kv_indices_buf):
raise ValueError("The size of indices should be less than or equal to the allocated buffer")

qo_indptr_host = getattr(decode_wrapper, "_lightllm_qo_indptr_host", None)
if qo_indptr_host is None or len(qo_indptr_host) != batch_size + 1:
from flashinfer.decode import _get_range_buf

qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
decode_wrapper._lightllm_qo_indptr_host = qo_indptr_host

if indptr_host is None:
indptr_host = indptr.to("cpu")
if kv_lens_arr_host is None:
from flashinfer.decode import get_seq_lens

last_page_len_host = last_page_len.to("cpu")
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
if max_kv_len is None:
max_kv_len = max(kv_lens_arr_host).item()

decode_wrapper._batch_size = batch_size
decode_wrapper._num_qo_heads = num_qo_heads
decode_wrapper._num_kv_heads = num_kv_heads
decode_wrapper._block_tables = None
decode_wrapper._max_kv_len = max_kv_len

args = [
decode_wrapper._float_workspace_buffer,
decode_wrapper._int_workspace_buffer,
decode_wrapper._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
decode_wrapper.is_cuda_graph_enabled,
head_dim,
head_dim,
False,
-1,
]
if decode_wrapper._backend == "fa2":
args.extend([-1, False, 0])
decode_wrapper._plan_info = decode_wrapper._cached_module.plan(*args)
decode_wrapper._pos_encoding_mode = "NONE"
decode_wrapper._window_left = -1
decode_wrapper._logits_soft_cap = 0.0
decode_wrapper._sm_scale = None
decode_wrapper._rope_scale = None
decode_wrapper._rope_theta = None


class FlashInferAttBackend(BaseAttBackend):
def __init__(self, model):
set_flashinfer_envs()
Expand All @@ -25,6 +101,10 @@ def __init__(self, model):
model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id()
),
]
self.kv_starts_host_buffer = [
torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"),
torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"),
]
self.q_data_type = model.data_type
self.kv_data_type = model.data_type

Expand Down Expand Up @@ -124,11 +204,11 @@ class FlashInferDecodeAttState(BaseDecodeAttState):
kv_last_page_len_buffer: torch.Tensor = None
kv_indices: torch.Tensor = None
kv_starts: torch.Tensor = None
kv_starts_host: torch.Tensor = None
kv_seq_lens_host: torch.Tensor = None
decode_wrapper: object = None

def init_state(self):
import flashinfer

self.backend: FlashInferAttBackend = self.backend
device = self.infer_state.input_ids.device
model = self.backend.model
Expand All @@ -154,8 +234,21 @@ def init_state(self):
self.infer_state.b_kv_start_loc,
self.infer_state.max_kv_seq_len,
self.kv_indices,
zero_output=False,
)
self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int()
if self.infer_state.b_seq_len_cpu is not None:
self.kv_seq_lens_host = self.infer_state.b_seq_len_cpu
self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][
: self.infer_state.batch_size + 1
]
self.kv_starts_host[0] = 0
torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:])
if self.infer_state.skip_decode_att_wrapper_init:
return

import flashinfer

assert self.decode_wrapper is None
self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
self.backend.workspace_buffer,
Expand All @@ -182,17 +275,18 @@ def init_state(self):

def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"):
super().copy_for_decode_cuda_graph(new_state)
self.decode_wrapper.plan(
_fast_plan_tensor_core_decode(
self.decode_wrapper,
new_state.kv_starts,
new_state.kv_indices,
new_state.kv_last_page_len_buffer,
new_state.backend.tp_q_head_num,
new_state.backend.tp_kv_head_num,
new_state.backend.head_dim,
1,
q_data_type=new_state.backend.q_data_type,
kv_data_type=new_state.backend.kv_data_type,
non_blocking=True,
new_state.kv_starts_host,
new_state.kv_seq_lens_host,
new_state.infer_state.max_kv_seq_len,
)

def decode_att(
Expand Down
8 changes: 6 additions & 2 deletions lightllm/common/basemodel/attention/flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ class MlaFlashInferDecodeAttState(BaseDecodeAttState):
decode_wrapper: object = None

def init_state(self):
import flashinfer

self.backend: MlaFlashInferAttBackend = self.backend
model = self.backend.model
device = self.infer_state.input_ids.device
Expand All @@ -144,7 +142,13 @@ def init_state(self):
self.infer_state.b_kv_start_loc,
self.infer_state.max_kv_seq_len,
self.kv_indices,
zero_output=False,
)
if self.infer_state.skip_decode_att_wrapper_init:
return

import flashinfer

assert self.decode_wrapper is None

self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
Expand Down
17 changes: 16 additions & 1 deletion lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0]
infer_state.b_req_idx = model_input.b_req_idx
infer_state.b_seq_len = model_input.b_seq_len
infer_state.b_seq_len_cpu = model_input.b_seq_len_cpu
infer_state.b_mtp_index = model_input.b_mtp_index
if model_input.is_prefill:
if model_input.b_ready_cache_len is not None:
Expand Down Expand Up @@ -371,6 +372,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0
)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2)
if new_model_input.b_seq_len_cpu is not None:
new_model_input.b_seq_len_cpu = F.pad(
new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2
)
new_model_input.mem_indexes = F.pad(
new_model_input.mem_indexes,
(0, padded_batch_size),
Expand Down Expand Up @@ -562,6 +567,8 @@ def _decode(
model_input=model_input, new_batch_size=infer_batch_size
)
infer_state = self._create_inferstate(model_input)
need_capture = self.graph.need_capture(infer_batch_size)
infer_state.skip_decode_att_wrapper_init = not need_capture
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
Expand All @@ -571,7 +578,7 @@ def _decode(
infer_state.init_some_extra_state(self)
infer_state.init_att_state()

if self.graph.need_capture(infer_batch_size):
if need_capture:
infer_state.is_cuda_graph = True
model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state)
else:
Expand Down Expand Up @@ -1037,6 +1044,9 @@ def autotune_layers(self):
# 控制autotune的层数,用于适配不同模型
return self.config.get("first_k_dense_replace", 0) + 1

def _autotune_extra_warmup(self):
return

@final
@torch.no_grad()
@post_empty_cache
Expand Down Expand Up @@ -1106,6 +1116,11 @@ def _autotune_warmup(self):
self.mem_manager.free_all()
gc.collect()
torch.cuda.empty_cache()
try:
self._autotune_extra_warmup()
except Exception as e:
logger.warning(f"extra autotune warmup failed: {str(e)}")
logger.exception(str(e))
self.layers_num = layer_num_bak
torch.distributed.barrier()
Autotuner.end_autotune_warmup()
Expand Down
3 changes: 3 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class ModelInput:
multimodal_params: list = None
# cpu 变量
mem_indexes_cpu: torch.Tensor = None
b_seq_len_cpu: torch.Tensor = None
# prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理
# 的一些变量
b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出
Expand All @@ -64,6 +65,8 @@ def to_cuda(self):
assert self.is_prefill

self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
if not self.b_seq_len.is_cuda:
self.b_seq_len_cpu = self.b_seq_len
self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
if self.b_ready_cache_len is not None:
Expand Down
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/infer_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self):
self.b_mtp_index: torch.Tensor = None

self.b_seq_len: torch.Tensor = None
self.b_seq_len_cpu: torch.Tensor = None
# max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度
self.max_cache_len: int = None
# prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度
Expand All @@ -56,6 +57,7 @@ def __init__(self):
self.return_all_prompt_logics: bool = False
self.multimodal_params: dict = None
self.is_cuda_graph: bool = False # 标记是否是cuda graph的捕获推理
self.skip_decode_att_wrapper_init: bool = False
self.dist_group: CustomProcessGroup = None

# 在microbatch overlap的运行模式下,用于标记当前 microbatch 的 index 序号
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tens
def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
raise Exception("need to impl")

def _add_residual_ffn_norm(self, input_embdings, residual, infer_state: InferStateInfo, layer_weight):
add_rmsnorm = getattr(layer_weight.ffn_norm_weight_, "add_rmsnorm", None)
if add_rmsnorm is None:
input_embdings.add_(residual.view(-1, self.embed_dim_))
return self._ffn_norm(input_embdings, infer_state, layer_weight)
return add_rmsnorm(
input=input_embdings,
residual=residual.view(-1, self.embed_dim_),
eps=self.eps_,
alloc_func=self.alloc_tensor,
)

def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)
Expand Down Expand Up @@ -89,10 +101,9 @@ def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, l
def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.token_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._add_residual_ffn_norm(input_embdings, o, infer_state, layer_weight)
o = None

input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)

input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
Expand Down
Loading
Loading