Skip to content

Commit 455697b

Browse files
author
zhangyue
committed
feat(bindings): support scalar default wrapper generation
1 parent 8720c39 commit 455697b

2 files changed

Lines changed: 145 additions & 4 deletions

File tree

scripts/generate_wrappers.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ def _get_compilers():
5353
system_include_flags = _get_system_include_flags()
5454

5555
index = clang.cindex.Index.create()
56-
args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags)
57-
translation_unit = index.parse(f"src/base/{op_name}.h", args=args)
56+
args = ("-std=c++17", "-x", "c++", "-I", str(_SRC_DIR)) + tuple(
57+
system_include_flags
58+
)
59+
translation_unit = index.parse(str(_BASE_DIR / f"{op_name}.h"), args=args)
5860

5961
nodes = tuple(type(self)._find(translation_unit.cursor, op_name))
6062

@@ -112,9 +114,31 @@ def _find_vector_tensor_params(op_name):
112114
return set(re.findall(r"std::vector<Tensor>\s+(\w+)", source))
113115

114116

117+
def _find_params_with_defaults(op_name):
118+
"""Return `{param_name: default_literal}` for scalar params with defaults.
119+
120+
`libclang`'s cursor API does not expose defaults reliably, so we regex-scan
121+
the source. Only used for plain scalar defaults such as
122+
`bool pre_gathered = false`.
123+
"""
124+
source = (_BASE_DIR / f"{op_name}.h").read_text()
125+
126+
mapping = {}
127+
128+
for name, default in re.findall(
129+
r"\b(?:bool|int(?:64_t|32_t|8_t|16_t)?|std::size_t|std::uint\w+_t|"
130+
r"float|double)\s+(\w+)\s*=\s*([^,\)]+?)\s*(?:,|\))",
131+
source,
132+
):
133+
mapping[name] = default.strip()
134+
135+
return mapping
136+
137+
115138
def _generate_pybind11(operator):
116139
optional_tensor_params = _find_optional_tensor_params(operator.name)
117140
vector_tensor_params = _find_vector_tensor_params(operator.name)
141+
params_with_defaults = _find_params_with_defaults(operator.name)
118142

119143
def _is_optional_tensor(arg):
120144
if arg.spelling in optional_tensor_params:
@@ -186,6 +210,10 @@ def _generate_py_args(node):
186210

187211
if _is_optional(arg):
188212
parts.append(f'py::arg("{arg.spelling}") = py::none()')
213+
elif arg.spelling in params_with_defaults:
214+
parts.append(
215+
f'py::arg("{arg.spelling}") = {params_with_defaults[arg.spelling]}'
216+
)
189217
else:
190218
parts.append(f'py::arg("{arg.spelling}")')
191219

@@ -257,8 +285,7 @@ def _generate_call(op_name, call, method=True):
257285
}})
258286
.def_static("clear_cache", &Self::clear_cache);
259287
260-
{callers}
261-
}}
288+
{callers}}}
262289
263290
}} // namespace infini::ops
264291

tests/test_generate_wrappers.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import functools
2+
import importlib.util
3+
from pathlib import Path
4+
5+
import pytest
6+
7+
8+
pytest.importorskip("clang.cindex")
9+
10+
11+
@functools.lru_cache(maxsize=1)
12+
def _load_generator():
13+
script = Path(__file__).parents[1] / "scripts" / "generate_wrappers.py"
14+
spec = importlib.util.spec_from_file_location("generate_wrappers", script)
15+
module = importlib.util.module_from_spec(spec)
16+
spec.loader.exec_module(module)
17+
18+
return module
19+
20+
21+
def _generate_binding(op_name, tmp_path, monkeypatch, source):
22+
generator = _load_generator()
23+
src_dir = tmp_path / "src"
24+
base_dir = src_dir / "base"
25+
base_dir.mkdir(parents=True)
26+
(base_dir / f"{op_name}.h").write_text(source)
27+
monkeypatch.setattr(generator, "_SRC_DIR", src_dir)
28+
monkeypatch.setattr(generator, "_BASE_DIR", base_dir)
29+
operator = generator._OperatorExtractor()(op_name)
30+
31+
return generator._generate_pybind11(operator)
32+
33+
34+
def test_mha_varlen_fwd_requires_out_binding(tmp_path, monkeypatch):
35+
text = _generate_binding(
36+
"mha_varlen_fwd",
37+
tmp_path,
38+
monkeypatch,
39+
"""
40+
#include <cstdint>
41+
#include <optional>
42+
43+
namespace infini::ops {
44+
45+
struct Tensor {};
46+
47+
template <typename T>
48+
class Operator {};
49+
50+
class MhaVarlenFwd : public Operator<MhaVarlenFwd> {
51+
public:
52+
MhaVarlenFwd(const Tensor q, const Tensor k, const Tensor v, Tensor out,
53+
const Tensor cu_seqlens_q, const Tensor cu_seqlens_k,
54+
std::optional<Tensor> block_table, float softmax_scale,
55+
bool is_causal, int64_t num_splits = 0) {}
56+
57+
virtual void operator()(const Tensor q, const Tensor k, const Tensor v,
58+
Tensor out, const Tensor cu_seqlens_q,
59+
const Tensor cu_seqlens_k,
60+
std::optional<Tensor> block_table,
61+
float softmax_scale, bool is_causal,
62+
int64_t num_splits = 0) const = 0;
63+
};
64+
65+
} // namespace infini::ops
66+
""",
67+
)
68+
69+
assert 'py::arg("out"), py::arg("cu_seqlens_q")' in text
70+
assert 'py::arg("num_splits") = 0' in text
71+
assert 'py::arg("out") = py::none()' not in text
72+
assert "std::optional<py::object> out" not in text
73+
assert "OptionalTensorFromPybind11Handle(out)" not in text
74+
75+
76+
def test_mha_fwd_kvcache_requires_out_binding(tmp_path, monkeypatch):
77+
text = _generate_binding(
78+
"mha_fwd_kvcache",
79+
tmp_path,
80+
monkeypatch,
81+
"""
82+
#include <cstdint>
83+
#include <optional>
84+
85+
namespace infini::ops {
86+
87+
struct Tensor {};
88+
89+
template <typename T>
90+
class Operator {};
91+
92+
class MhaFwdKvcache : public Operator<MhaFwdKvcache> {
93+
public:
94+
MhaFwdKvcache(const Tensor q, const Tensor kcache, const Tensor vcache,
95+
std::optional<Tensor> k, std::optional<Tensor> v, Tensor out,
96+
float softmax_scale, bool is_causal,
97+
int64_t num_splits = 0) {}
98+
99+
virtual void operator()(const Tensor q, const Tensor kcache,
100+
const Tensor vcache, std::optional<Tensor> k,
101+
std::optional<Tensor> v, Tensor out,
102+
float softmax_scale, bool is_causal,
103+
int64_t num_splits = 0) const = 0;
104+
};
105+
106+
} // namespace infini::ops
107+
""",
108+
)
109+
110+
assert 'py::arg("out"), py::arg("softmax_scale")' in text
111+
assert 'py::arg("num_splits") = 0' in text
112+
assert 'py::arg("out") = py::none()' not in text
113+
assert "std::optional<py::object> out" not in text
114+
assert "OptionalTensorFromPybind11Handle(out)" not in text

0 commit comments

Comments
 (0)