feat: add a script to check for kernel public api changes#980
Conversation
|
|
||
|
|
||
| def extract_api(kernel_root: Path) -> dict: | ||
| ext = kernel_root / "torch-ext" |
There was a problem hiding this comment.
Okay this is only for PyTorch now. Maybe add a comment about it?
There was a problem hiding this comment.
sounds good to me, added a comment to the updated function in latest changes
| if not ext.is_dir(): | ||
| return {} | ||
| api = {} | ||
| for src in sorted([*ext.rglob("*.cpp"), *ext.rglob("*.cc")]): |
There was a problem hiding this comment.
Better to read the source files from the build.toml?
There was a problem hiding this comment.
yea good point - updated to use the build.toml list in latest changes
| schema = " ".join(match.group(1).split()) | ||
| if schema: | ||
| api[f"op {schema.split('(', 1)[0].strip()}"] = schema | ||
| for py in sorted(ext.rglob("*.py")): |
There was a problem hiding this comment.
The condition above skips tvm-ffi kernels, but the Python bits would apply there as well. Seems nicer to just do them separately?
Besides that, probably only stuff exported through __all__ in __init__.py (when present) is public API, even more so because you cannot import other submodules. So, I think it should be checked what is reachable through __init__.py, otherwise we'll get a lot of false positives that are internals.
There was a problem hiding this comment.
thanks for catching - the code should now inspect both torch-ext and tvm-ffi-ext's init files now.
we either use the values from __all__ if available, or fallback to parse all exported values that do not start with _
|
|
||
| def changed_kernels(ref: str) -> list: | ||
| found = set() | ||
| for line in git("diff", "--name-only", ref).splitlines(): |
There was a problem hiding this comment.
Passing ref names directly to a command enables injection attacks, especially because refs can be user-controlled. Probably better to use something like libgit2?
There was a problem hiding this comment.
great catch, updated to prefer using pygit2 in latest to avoid this injection risk
|
|
||
| # Private by name but still part of the public contract. | ||
| PUBLIC_DUNDERS = {"__init__", "__call__"} | ||
| DEF_RE = re.compile(r"\.def\(\s*\"((?:[^\"\\]|\\.)*)\"") |
There was a problem hiding this comment.
Is there a C++ parser that we easily use from Python?
There was a problem hiding this comment.
yes, I've updated to use libclang to parse the cpp source into tokens where the ops.def( calls are identified and the signature is extracted
| for match in DEF_RE.finditer(src.read_text(errors="replace")): | ||
| schema = " ".join(match.group(1).split()) | ||
| if schema: | ||
| api[f"op {schema.split('(', 1)[0].strip()}"] = schema |
There was a problem hiding this comment.
Torch has a schema parser: torch._C.parse_schema.
There was a problem hiding this comment.
once the schemas are extracted in the step above they are parsed with this function in the latest changes
|
PR updated and I think its much more robust/solid now. example usage uv run scripts/check_public_api.py flash-attn2output Public API: 303e2f7c..71bd3f8f (origin/main..HEAD)
[ok] flash-attn2: 18 symbols
|- op bwd = bwd(Tensor($0! -> ) dout, Tensor($1! -> ) q, Tensor($2! -...
|- op fwd = fwd(Tensor($0! -> ) q, Tensor k, Tensor v, Tensor(out_!)?...
|- op fwd_kvcache = fwd_kvcache(Tensor($0! -> ) q, Tensor($1! -> ) kcache, Te...
|- op varlen_bwd = varlen_bwd(Tensor($0! -> ) dout, Tensor($1! -> ) q, Tenso...
|- op varlen_fwd = varlen_fwd(Tensor($0! -> ) q, Tensor k, Tensor v, Tensor?...
|- py flash_attn2 __all__ = ['bwd', 'flash_attn_func', 'flash_attn_kvpacked_func', 'f...
`- ... and 12 morenow if you add a new attribute to one of the signature (either on the cpp or python side) output Public API: 303e2f7c..71bd3f8f (origin/main..HEAD)
[changed] flash-attn2: 1 changed
`- changed op fwd
- fwd(Tensor($0! -> ) q, Tensor k, Tensor v, Tensor(out_!)? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]
+ fwd(Tensor! q, Tensor k, Tensor v, Tensor(out_!)? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal,int window_size_left, int window_size_right, float softcap, bool return_softmax, bool testGenerator? gen_) -> Tensor[]
ERROR: public API changed - bump the kernel version if intentional.NOTE: TLDR; its important to document the public api's via |
this PR runs a new script to check if there are any changes to a kernels public api - when a kernels public api changes we should bump the version