Skip to content

Commit feac819

Browse files
author
niushengxiao
committed
fix
1 parent d59ba94 commit feac819

7 files changed

Lines changed: 301 additions & 113 deletions

File tree

docker/Dockerfile

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,8 +62,6 @@ RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cu
6262

6363
RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*
6464

65-
RUN pip install --no-cache-dir "flash-attn-4[13]==4.0.0b13"
66-
6765
ENV CUDA_HOME=/usr/local/cuda \
6866
GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/
6967

lightllm/common/basemodel/attention/fa4/fp.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ def _ensure_fa4_paged_kv_supported(
5656

5757
@dataclasses.dataclass
5858
class Fa4PrefillAttState(PagedFa3PrefillAttState):
59-
def _nomarl_prefill_att(
59+
def _normal_prefill_att(
6060
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
6161
) -> torch.Tensor:
6262
if att_control.use_sliding_window:

lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,15 +26,15 @@ def __init__(
2626
dtype=self.linear_config.conv_state_dtype,
2727
shape=self.linear_config.get_conv_state_shape(),
2828
layer_num=self.linear_config.linear_layer_num,
29-
device="cpu",
29+
device="cuda",
3030
size_first=True,
3131
)
3232
self.ssm_state_cache = LayerCache(
3333
size=self.size,
3434
dtype=self.linear_config.ssm_state_dtype,
3535
shape=self.linear_config.get_ssm_state_shape(),
3636
layer_num=self.linear_config.linear_layer_num,
37-
device="cpu",
37+
device="cuda",
3838
size_first=True,
3939
)
4040
self.clear_to_init_state()

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,3 +98,4 @@ nixl==1.1.0
9898
xformers==0.0.35
9999
redis==7.3.0
100100
litellm>=1.52.0,<1.85
101+
flash-attn-4[13]==4.0.0b13

test/benchmark/service/benchmark_client.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,13 @@ def get_tokenizer(
2727
return tokenizer
2828

2929

30+
def normalize_model_name(model_name: str) -> str:
31+
if not model_name:
32+
return model_name
33+
normalized = model_name.rstrip("/\\")
34+
return normalized or model_name
35+
36+
3037
def get_output_length(input_num: int, output_len: int) -> List[int]:
3138
min_len, max_len = 2, output_len * 2
3239
mean = (min_len + max_len) * 0.5
@@ -162,7 +169,7 @@ def main():
162169
return
163170

164171
assert args.tokenizer_path is not None
165-
model_name.append(args.tokenizer_path)
172+
model_name.append(normalize_model_name(args.tokenizer_path))
166173
seed_all(args.seed)
167174
url = args.url
168175
tokenizer = get_tokenizer(args.tokenizer_path)

0 commit comments

Comments
 (0)