Skip to content

Commit 6b2d3ba

Browse files
committed
issue/394 - feat: support flash-attn via MooreThreads/mate for moore gpu
1 parent b2eccc2 commit 6b2d3ba

2 files changed

Lines changed: 30 additions & 25 deletions

File tree

csrc/pybind11/engine/engine.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,13 @@ inline void bind_infer_engine(py::module &m) {
7070
return state_dict_tp_all;
7171
})
7272
.def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)")
73-
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
74-
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
73+
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); },
74+
"Run inference on all ranks with arbitrary arguments",
75+
py::call_guard<py::gil_scoped_release>())
76+
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); },
77+
py::arg("cache_config") = py::none(),
78+
py::call_guard<py::gil_scoped_release>())
79+
7580
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
7681
auto cfg = self.get_cache_config();
7782
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })

examples/bench.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -325,38 +325,38 @@ def run(
325325
# Warmup
326326
# ---------------------------------------------------------------------------- #
327327
if cfg.warmup:
328-
warmup_steps = 1
328+
print("=================== warmup start ===================")
329329

330-
# warmup cache capacity
331-
warmup_cache_len = 128
332-
warmup_batch = len(test.input_ids_list)
330+
if cache_config is not None:
331+
test.model.reset_cache(cache_config)
333332

334-
test.model.reset_cache(
335-
StaticKVCacheConfig(
336-
max_batch_size=warmup_batch,
337-
max_cache_len=warmup_cache_len,
333+
warmup_shapes = []
334+
seen = set()
335+
for _, case in cases_dict.items():
336+
key = (case["batch_size"], case["input_len"])
337+
if key in seen:
338+
continue
339+
seen.add(key)
340+
warmup_shapes.append((case["batch_size"], case["input_len"]))
341+
342+
for w_batch, w_input_len in warmup_shapes:
343+
tqdm.write(
344+
f"\033[93m[warmup] batch={w_batch}, input_len={w_input_len}, "
345+
f"will prefill + 3 decode steps\033[0m"
338346
)
339-
)
340-
341-
avg_prompt_len = min(64, max(len(ids) for ids in test.input_ids_list))
342347

343-
warmup_ids = [
344-
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
345-
for ids in test.input_ids_list
346-
]
348+
warmup_ids = repeat_prompt(test.input_ids_list[0], target_length=w_input_len)
349+
warmup_ids_list = [warmup_ids] * w_batch
350+
warmup_input = infinicore.from_list(warmup_ids_list)
347351

348-
input_ids_infini = infinicore.from_list(warmup_ids)
349-
350-
print("=================== warmup start ===================")
351-
352-
for _ in range(warmup_steps):
353352
_ = test.model.generate(
354-
input_ids_infini,
353+
warmup_input,
355354
GenerationConfig(
356-
max_new_tokens=5, # decode kernel warmup
357-
temperature=cfg.temperature,
355+
max_new_tokens=3,
356+
eos_token_id=[],
358357
top_k=cfg.top_k,
359358
top_p=cfg.top_p,
359+
temperature=cfg.temperature,
360360
stop_on_eos=False,
361361
),
362362
_measure_and_log_time=False,

0 commit comments

Comments
 (0)