Skip to content

Commit 401a0fd

Browse files
committed
issue/394 - feat: support flash-attn via MooreThreads/mate for moore gpu
1 parent b2eccc2 commit 401a0fd

2 files changed

Lines changed: 37 additions & 30 deletions

File tree

csrc/pybind11/engine/engine.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,13 @@ inline void bind_infer_engine(py::module &m) {
7070
return state_dict_tp_all;
7171
})
7272
.def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)")
73-
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
74-
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
73+
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); },
74+
"Run inference on all ranks with arbitrary arguments",
75+
py::call_guard<py::gil_scoped_release>())
76+
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); },
77+
py::arg("cache_config") = py::none(),
78+
py::call_guard<py::gil_scoped_release>())
79+
7580
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
7681
auto cfg = self.get_cache_config();
7782
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })

examples/bench.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -325,49 +325,51 @@ def run(
325325
# Warmup
326326
# ---------------------------------------------------------------------------- #
327327
if cfg.warmup:
328-
warmup_steps = 1
328+
print("=================== warmup start ===================")
329329

330-
# warmup cache capacity
331-
warmup_cache_len = 128
332-
warmup_batch = len(test.input_ids_list)
330+
import time
331+
# time.sleep(5) # ← 制造 2 秒空白,msys Timeline 里清晰可见
333332

334-
test.model.reset_cache(
335-
StaticKVCacheConfig(
336-
max_batch_size=warmup_batch,
337-
max_cache_len=warmup_cache_len,
338-
)
339-
)
340-
341-
avg_prompt_len = min(64, max(len(ids) for ids in test.input_ids_list))
333+
# 使用实际的 cache_config 触发 mate JIT
334+
if cache_config is not None:
335+
test.model.reset_cache(cache_config)
342336

343-
warmup_ids = [
344-
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
345-
for ids in test.input_ids_list
346-
]
337+
# 收集所有不同的 (batch_size, input_len) shape
338+
warmup_shapes = []
339+
seen = set()
340+
for _, case in cases_dict.items():
341+
key = (case["batch_size"], case["input_len"])
342+
if key in seen:
343+
continue
344+
seen.add(key)
345+
warmup_shapes.append((case["batch_size"], case["input_len"]))
346+
347+
for w_batch, w_input_len in warmup_shapes:
348+
tqdm.write(
349+
f"\033[93m[warmup] batch={w_batch}, input_len={w_input_len}, "
350+
f"will prefill + 3 decode steps\033[0m"
351+
)
347352

348-
input_ids_infini = infinicore.from_list(warmup_ids)
353+
warmup_ids = repeat_prompt(test.input_ids_list[0], target_length=w_input_len)
354+
warmup_ids_list = [warmup_ids] * w_batch
355+
warmup_input = infinicore.from_list(warmup_ids_list)
349356

350-
print("=================== warmup start ===================")
351-
352-
for _ in range(warmup_steps):
353357
_ = test.model.generate(
354-
input_ids_infini,
358+
warmup_input,
355359
GenerationConfig(
356-
max_new_tokens=5, # decode kernel warmup
357-
temperature=cfg.temperature,
360+
max_new_tokens=3,
361+
eos_token_id=[],
358362
top_k=cfg.top_k,
359363
top_p=cfg.top_p,
364+
temperature=cfg.temperature,
360365
stop_on_eos=False,
361366
),
362367
_measure_and_log_time=False,
363368
)
364-
365-
print("=================== warmup done ====================")
366-
367-
# reset cache back to benchmark config
368369
if cache_config is not None:
369370
test.model.reset_cache(cache_config)
370-
371+
# time.sleep(5) # ← 制造 2 秒空白,msys Timeline 里清晰可见
372+
print("=================== warmup done ====================")
371373
# ---------------------------------------------------------------------------- #
372374
# Warmup done
373375
# ---------------------------------------------------------------------------- #

0 commit comments

Comments
 (0)