Skip to content

feat: add a script to check for kernel public api changes#980

Open
drbh wants to merge 2 commits into
mainfrom
check-kernel-public-apis
Open

feat: add a script to check for kernel public api changes#980
drbh wants to merge 2 commits into
mainfrom
check-kernel-public-apis

Conversation

@drbh

@drbh drbh commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

this PR runs a new script to check if there are any changes to a kernels public api - when a kernels public api changes we should bump the version

sayakpaul
sayakpaul previously approved these changes Jun 20, 2026
Comment thread scripts/check_public_api.py Outdated


def extract_api(kernel_root: Path) -> dict:
ext = kernel_root / "torch-ext"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay this is only for PyTorch now. Maybe add a comment about it?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good to me, added a comment to the updated function in latest changes

Comment thread scripts/check_public_api.py Outdated
if not ext.is_dir():
return {}
api = {}
for src in sorted([*ext.rglob("*.cpp"), *ext.rglob("*.cc")]):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to read the source files from the build.toml?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea good point - updated to use the build.toml list in latest changes

Comment thread scripts/check_public_api.py Outdated
schema = " ".join(match.group(1).split())
if schema:
api[f"op {schema.split('(', 1)[0].strip()}"] = schema
for py in sorted(ext.rglob("*.py")):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition above skips tvm-ffi kernels, but the Python bits would apply there as well. Seems nicer to just do them separately?

Besides that, probably only stuff exported through __all__ in __init__.py (when present) is public API, even more so because you cannot import other submodules. So, I think it should be checked what is reachable through __init__.py, otherwise we'll get a lot of false positives that are internals.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for catching - the code should now inspect both torch-ext and tvm-ffi-ext's init files now.

we either use the values from __all__ if available, or fallback to parse all exported values that do not start with _

Comment thread scripts/check_public_api.py Outdated

def changed_kernels(ref: str) -> list:
found = set()
for line in git("diff", "--name-only", ref).splitlines():

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing ref names directly to a command enables injection attacks, especially because refs can be user-controlled. Probably better to use something like libgit2?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great catch, updated to prefer using pygit2 in latest to avoid this injection risk

Comment thread scripts/check_public_api.py Outdated

# Private by name but still part of the public contract.
PUBLIC_DUNDERS = {"__init__", "__call__"}
DEF_RE = re.compile(r"\.def\(\s*\"((?:[^\"\\]|\\.)*)\"")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a C++ parser that we easily use from Python?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I've updated to use libclang to parse the cpp source into tokens where the ops.def( calls are identified and the signature is extracted

Comment thread scripts/check_public_api.py Outdated
for match in DEF_RE.finditer(src.read_text(errors="replace")):
schema = " ".join(match.group(1).split())
if schema:
api[f"op {schema.split('(', 1)[0].strip()}"] = schema

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Torch has a schema parser: torch._C.parse_schema.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

once the schemas are extracted in the step above they are parsed with this function in the latest changes

@drbh

drbh commented Jun 24, 2026

Copy link
Copy Markdown
Collaborator Author

PR updated and I think its much more robust/solid now.

example usage

 uv run scripts/check_public_api.py flash-attn2

output

Public API: 303e2f7c..71bd3f8f (origin/main..HEAD)

  [ok] flash-attn2: 18 symbols
     |- op bwd = bwd(Tensor($0! -> ) dout, Tensor($1! -> ) q, Tensor($2! -...
     |- op fwd = fwd(Tensor($0! -> ) q, Tensor k, Tensor v, Tensor(out_!)?...
     |- op fwd_kvcache = fwd_kvcache(Tensor($0! -> ) q, Tensor($1! -> ) kcache, Te...
     |- op varlen_bwd = varlen_bwd(Tensor($0! -> ) dout, Tensor($1! -> ) q, Tenso...
     |- op varlen_fwd = varlen_fwd(Tensor($0! -> ) q, Tensor k, Tensor v, Tensor?...
     |- py flash_attn2 __all__ = ['bwd', 'flash_attn_func', 'flash_attn_kvpacked_func', 'f...
     `- ... and 12 more

now if you add a new attribute to one of the signature (either on the cpp or python side)

uv run scripts/check_public_api.py flash-attn2

output

Public API: 303e2f7c..71bd3f8f (origin/main..HEAD)

  [changed] flash-attn2: 1 changed
     `- changed op fwd
        - fwd(Tensor($0! -> ) q, Tensor k, Tensor v, Tensor(out_!)? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]
        + fwd(Tensor! q, Tensor k, Tensor v, Tensor(out_!)? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal,int window_size_left, int window_size_right, float softcap, bool return_softmax, bool testGenerator? gen_) -> Tensor[]

ERROR: public API changed - bump the kernel version if intentional.

NOTE:
since we prefer parsing __all__ and fallback to exposed functions - if a __all__ var is defined we expect all public functions to be included. functions that are not in a defined __all__ will not be considered as public. Therefore if a downstream user is using an exported function that is not in __all__ it will not be caught by this check.

TLDR; its important to document the public api's via __all__ in all cases.

@drbh drbh added area: github-actions GitHub Actions workflows and action versions area: repo-automation Repository scripts, bots, freshness checks, and generated automation size: L Diff <= 1000 lines type: feature New functionality / capability labels Jun 30, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area: github-actions GitHub Actions workflows and action versions area: repo-automation Repository scripts, bots, freshness checks, and generated automation size: L Diff <= 1000 lines type: feature New functionality / capability

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants