Skip to content

Commit ca2d374

Browse files
committed
issue/394 - feat: support flash-attn via MooreThreads/mate for moore gpu
1 parent 4c3e266 commit ca2d374

2 files changed

Lines changed: 67 additions & 27 deletions

File tree

csrc/pybind11/engine/engine.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,13 @@ inline void bind_infer_engine(py::module &m) {
7070
return state_dict_tp_all;
7171
})
7272
.def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)")
73-
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
74-
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
73+
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); },
74+
"Run inference on all ranks with arbitrary arguments",
75+
py::call_guard<py::gil_scoped_release>())
76+
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); },
77+
py::arg("cache_config") = py::none(),
78+
py::call_guard<py::gil_scoped_release>())
79+
7580
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
7681
auto cfg = self.get_cache_config();
7782
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })

examples/bench.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -328,48 +328,83 @@ def run(
328328
# Warmup
329329
# ---------------------------------------------------------------------------- #
330330
if cfg.warmup:
331-
warmup_steps = 1
332-
333-
# warmup cache capacity
334-
warmup_cache_len = 128
335-
warmup_batch = len(test.input_ids_list)
336-
337-
test.model.reset_cache(
338-
StaticKVCacheConfig(
339-
max_batch_size=warmup_batch,
340-
max_cache_len=warmup_cache_len,
331+
print("=================== warmup start ===================")
332+
# -------------------------------------------------------- #
333+
# reset cache before warmup
334+
# support both paged cache and static cache
335+
# -------------------------------------------------------- #
336+
if cache_config is not None:
337+
# Paged KVCache
338+
test.model.reset_cache(cache_config)
339+
else:
340+
# Static KVCache
341+
max_batch_size = max(c["batch_size"] for _, c in cases_dict.items())
342+
max_cache_len = max(
343+
c["input_len"] + c["output_len"]
344+
for _, c in cases_dict.items()
341345
)
342-
)
343346

344-
avg_prompt_len = min(64, max(len(ids) for ids in test.input_ids_list))
345-
346-
warmup_ids = [
347-
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
348-
for ids in test.input_ids_list
349-
]
347+
test.model.reset_cache(
348+
StaticKVCacheConfig(
349+
max_batch_size=max_batch_size,
350+
max_cache_len=max_cache_len,
351+
)
352+
)
350353

351-
input_ids_infini = infinicore.from_list(warmup_ids)
354+
warmup_shapes = []
355+
seen = set()
356+
for _, case in cases_dict.items():
357+
key = (case["batch_size"], case["input_len"])
358+
if key in seen:
359+
continue
360+
seen.add(key)
361+
warmup_shapes.append((case["batch_size"], case["input_len"]))
362+
363+
for w_batch, w_input_len in warmup_shapes:
364+
tqdm.write(
365+
f"\033[93m[warmup] batch={w_batch}, input_len={w_input_len}, "
366+
f"will prefill + 3 decode steps\033[0m"
367+
)
352368

353-
print("=================== warmup start ===================")
369+
warmup_ids = repeat_prompt(test.input_ids_list[0], target_length=w_input_len)
370+
warmup_ids_list = [warmup_ids] * w_batch
371+
warmup_input = infinicore.from_list(warmup_ids_list)
354372

355-
for _ in range(warmup_steps):
356373
_ = test.model.generate(
357-
input_ids_infini,
374+
warmup_input,
358375
GenerationConfig(
359-
max_new_tokens=5, # decode kernel warmup
360-
temperature=cfg.temperature,
376+
max_new_tokens=3,
377+
eos_token_id=[],
361378
top_k=cfg.top_k,
362379
top_p=cfg.top_p,
380+
temperature=cfg.temperature,
363381
stop_on_eos=False,
364382
),
365383
_measure_and_log_time=False,
366384
)
367385

368386
print("=================== warmup done ====================")
369-
370-
# reset cache back to benchmark config
387+
# -------------------------------------------------------- #
388+
# reset cache back to benchmark config
389+
# support both paged cache and static cache
390+
# -------------------------------------------------------- #
371391
if cache_config is not None:
392+
# Paged KVCache
372393
test.model.reset_cache(cache_config)
394+
else:
395+
# Static KVCache
396+
max_batch_size = max(c["batch_size"] for _, c in cases_dict.items())
397+
max_cache_len = max(
398+
c["input_len"] + c["output_len"]
399+
for _, c in cases_dict.items()
400+
)
401+
402+
test.model.reset_cache(
403+
StaticKVCacheConfig(
404+
max_batch_size=max_batch_size,
405+
max_cache_len=max_cache_len,
406+
)
407+
)
373408

374409
# ---------------------------------------------------------------------------- #
375410
# Warmup done

0 commit comments

Comments
 (0)