Skip to content

Commit 11a74b7

Browse files
author
zhangyue
committed
feat(bindings): support scalar default wrapper generation
1 parent 64751ea commit 11a74b7

2 files changed

Lines changed: 141 additions & 1 deletion

File tree

scripts/generate_wrappers.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __call__(self, op_name):
8282
"-x",
8383
"c++",
8484
"-I",
85-
"src",
85+
str(_SRC_DIR),
8686
"-I",
8787
str(_GENERATION_DIR),
8888
) + _get_system_include_flags()
@@ -160,10 +160,32 @@ def _find_vector_int64_params(op_name):
160160
return set(re.findall(r"std::vector<int64_t>\s+(\w+)", source))
161161

162162

163+
def _find_params_with_defaults(op_name):
164+
"""Return `{param_name: default_literal}` for scalar params with defaults.
165+
166+
`libclang`'s cursor API does not expose defaults reliably, so we regex-scan
167+
the source. Only used for plain scalar defaults such as
168+
`bool pre_gathered = false`.
169+
"""
170+
source = _find_base_header(op_name).read_text()
171+
172+
mapping = {}
173+
174+
for name, default in re.findall(
175+
r"\b(?:bool|int(?:64_t|32_t|8_t|16_t)?|std::size_t|std::uint\w+_t|"
176+
r"float|double)\s+(\w+)\s*=\s*([^,\)]+?)\s*(?:,|\))",
177+
source,
178+
):
179+
mapping[name] = default.strip()
180+
181+
return mapping
182+
183+
163184
def _generate_pybind11(operator):
164185
optional_tensor_params = _find_optional_tensor_params(operator.name)
165186
vector_tensor_params = _find_vector_tensor_params(operator.name)
166187
vector_int64_params = _find_vector_int64_params(operator.name)
188+
params_with_defaults = _find_params_with_defaults(operator.name)
167189

168190
def _is_optional_tensor(arg):
169191
if arg.spelling in optional_tensor_params:
@@ -242,6 +264,10 @@ def _generate_py_args(node):
242264

243265
if _is_optional(arg):
244266
parts.append(f'py::arg("{arg.spelling}") = py::none()')
267+
elif arg.spelling in params_with_defaults:
268+
parts.append(
269+
f'py::arg("{arg.spelling}") = {params_with_defaults[arg.spelling]}'
270+
)
245271
else:
246272
parts.append(f'py::arg("{arg.spelling}")')
247273

tests/test_generate_wrappers.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import functools
2+
import importlib.util
3+
from pathlib import Path
4+
5+
import pytest
6+
7+
8+
pytest.importorskip("clang.cindex")
9+
10+
11+
@functools.lru_cache(maxsize=1)
12+
def _load_generator():
13+
script = Path(__file__).parents[1] / "scripts" / "generate_wrappers.py"
14+
spec = importlib.util.spec_from_file_location("generate_wrappers", script)
15+
module = importlib.util.module_from_spec(spec)
16+
spec.loader.exec_module(module)
17+
18+
return module
19+
20+
21+
def _generate_binding(op_name, tmp_path, monkeypatch, source):
22+
generator = _load_generator()
23+
src_dir = tmp_path / "src"
24+
base_dir = src_dir / "base"
25+
base_dir.mkdir(parents=True)
26+
(base_dir / f"{op_name}.h").write_text(source)
27+
monkeypatch.setattr(generator, "_SRC_DIR", src_dir)
28+
monkeypatch.setattr(generator, "_BASE_DIR", base_dir)
29+
operator = generator._OperatorExtractor()(op_name)
30+
31+
return generator._generate_pybind11(operator)
32+
33+
34+
def test_mha_varlen_fwd_requires_out_binding(tmp_path, monkeypatch):
35+
text = _generate_binding(
36+
"mha_varlen_fwd",
37+
tmp_path,
38+
monkeypatch,
39+
"""
40+
#include <cstdint>
41+
#include <optional>
42+
43+
namespace infini::ops {
44+
45+
struct Tensor {};
46+
47+
template <typename T>
48+
class Operator {};
49+
50+
class MhaVarlenFwd : public Operator<MhaVarlenFwd> {
51+
public:
52+
MhaVarlenFwd(const Tensor q, const Tensor k, const Tensor v, Tensor out,
53+
const Tensor cu_seqlens_q, const Tensor cu_seqlens_k,
54+
std::optional<Tensor> block_table, float softmax_scale,
55+
bool is_causal, int64_t num_splits = 0) {}
56+
57+
virtual void operator()(const Tensor q, const Tensor k, const Tensor v,
58+
Tensor out, const Tensor cu_seqlens_q,
59+
const Tensor cu_seqlens_k,
60+
std::optional<Tensor> block_table,
61+
float softmax_scale, bool is_causal,
62+
int64_t num_splits = 0) const = 0;
63+
};
64+
65+
} // namespace infini::ops
66+
""",
67+
)
68+
69+
assert 'py::arg("out"), py::arg("cu_seqlens_q")' in text
70+
assert 'py::arg("num_splits") = 0' in text
71+
assert 'py::arg("out") = py::none()' not in text
72+
assert "std::optional<py::object> out" not in text
73+
assert "OptionalTensorFromPybind11Handle(out)" not in text
74+
75+
76+
def test_mha_fwd_kvcache_requires_out_binding(tmp_path, monkeypatch):
77+
text = _generate_binding(
78+
"mha_fwd_kvcache",
79+
tmp_path,
80+
monkeypatch,
81+
"""
82+
#include <cstdint>
83+
#include <optional>
84+
85+
namespace infini::ops {
86+
87+
struct Tensor {};
88+
89+
template <typename T>
90+
class Operator {};
91+
92+
class MhaFwdKvcache : public Operator<MhaFwdKvcache> {
93+
public:
94+
MhaFwdKvcache(const Tensor q, const Tensor kcache, const Tensor vcache,
95+
std::optional<Tensor> k, std::optional<Tensor> v, Tensor out,
96+
float softmax_scale, bool is_causal,
97+
int64_t num_splits = 0) {}
98+
99+
virtual void operator()(const Tensor q, const Tensor kcache,
100+
const Tensor vcache, std::optional<Tensor> k,
101+
std::optional<Tensor> v, Tensor out,
102+
float softmax_scale, bool is_causal,
103+
int64_t num_splits = 0) const = 0;
104+
};
105+
106+
} // namespace infini::ops
107+
""",
108+
)
109+
110+
assert 'py::arg("out"), py::arg("softmax_scale")' in text
111+
assert 'py::arg("num_splits") = 0' in text
112+
assert 'py::arg("out") = py::none()' not in text
113+
assert "std::optional<py::object> out" not in text
114+
assert "OptionalTensorFromPybind11Handle(out)" not in text

0 commit comments

Comments
 (0)