Skip to content

Commit 4a83db6

Browse files
committed
issue/394 - feat: support flash-attn via MooreThreads/mate for moore gpu
1 parent b2eccc2 commit 4a83db6

2 files changed

Lines changed: 67 additions & 27 deletions

File tree

csrc/pybind11/engine/engine.hpp

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -70,8 +70,13 @@ inline void bind_infer_engine(py::module &m) {
7070
return state_dict_tp_all;
7171
})
7272
.def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)")
73-
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
74-
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
73+
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); },
74+
"Run inference on all ranks with arbitrary arguments",
75+
py::call_guard<py::gil_scoped_release>())
76+
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); },
77+
py::arg("cache_config") = py::none(),
78+
py::call_guard<py::gil_scoped_release>())
79+
7580
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
7681
auto cfg = self.get_cache_config();
7782
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })

examples/bench.py

Lines changed: 60 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -325,48 +325,83 @@ def run(
325325
# Warmup
326326
# ---------------------------------------------------------------------------- #
327327
if cfg.warmup:
328-
warmup_steps = 1
329-
330-
# warmup cache capacity
331-
warmup_cache_len = 128
332-
warmup_batch = len(test.input_ids_list)
333-
334-
test.model.reset_cache(
335-
StaticKVCacheConfig(
336-
max_batch_size=warmup_batch,
337-
max_cache_len=warmup_cache_len,
328+
print("=================== warmup start ===================")
329+
# -------------------------------------------------------- #
330+
# reset cache before warmup
331+
# support both paged cache and static cache
332+
# -------------------------------------------------------- #
333+
if cache_config is not None:
334+
# Paged KVCache
335+
test.model.reset_cache(cache_config)
336+
else:
337+
# Static KVCache
338+
max_batch_size = max(c["batch_size"] for _, c in cases_dict.items())
339+
max_cache_len = max(
340+
c["input_len"] + c["output_len"]
341+
for _, c in cases_dict.items()
338342
)
339-
)
340343

341-
avg_prompt_len = min(64, max(len(ids) for ids in test.input_ids_list))
342-
343-
warmup_ids = [
344-
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
345-
for ids in test.input_ids_list
346-
]
344+
test.model.reset_cache(
345+
StaticKVCacheConfig(
346+
max_batch_size=max_batch_size,
347+
max_cache_len=max_cache_len,
348+
)
349+
)
347350

348-
input_ids_infini = infinicore.from_list(warmup_ids)
351+
warmup_shapes = []
352+
seen = set()
353+
for _, case in cases_dict.items():
354+
key = (case["batch_size"], case["input_len"])
355+
if key in seen:
356+
continue
357+
seen.add(key)
358+
warmup_shapes.append((case["batch_size"], case["input_len"]))
359+
360+
for w_batch, w_input_len in warmup_shapes:
361+
tqdm.write(
362+
f"\033[93m[warmup] batch={w_batch}, input_len={w_input_len}, "
363+
f"will prefill + 3 decode steps\033[0m"
364+
)
349365

350-
print("=================== warmup start ===================")
366+
warmup_ids = repeat_prompt(test.input_ids_list[0], target_length=w_input_len)
367+
warmup_ids_list = [warmup_ids] * w_batch
368+
warmup_input = infinicore.from_list(warmup_ids_list)
351369

352-
for _ in range(warmup_steps):
353370
_ = test.model.generate(
354-
input_ids_infini,
371+
warmup_input,
355372
GenerationConfig(
356-
max_new_tokens=5, # decode kernel warmup
357-
temperature=cfg.temperature,
373+
max_new_tokens=3,
374+
eos_token_id=[],
358375
top_k=cfg.top_k,
359376
top_p=cfg.top_p,
377+
temperature=cfg.temperature,
360378
stop_on_eos=False,
361379
),
362380
_measure_and_log_time=False,
363381
)
364382

365383
print("=================== warmup done ====================")
366-
367-
# reset cache back to benchmark config
384+
# -------------------------------------------------------- #
385+
# reset cache back to benchmark config
386+
# support both paged cache and static cache
387+
# -------------------------------------------------------- #
368388
if cache_config is not None:
389+
# Paged KVCache
369390
test.model.reset_cache(cache_config)
391+
else:
392+
# Static KVCache
393+
max_batch_size = max(c["batch_size"] for _, c in cases_dict.items())
394+
max_cache_len = max(
395+
c["input_len"] + c["output_len"]
396+
for _, c in cases_dict.items()
397+
)
398+
399+
test.model.reset_cache(
400+
StaticKVCacheConfig(
401+
max_batch_size=max_batch_size,
402+
max_cache_len=max_cache_len,
403+
)
404+
)
370405

371406
# ---------------------------------------------------------------------------- #
372407
# Warmup done

0 commit comments

Comments
 (0)