Skip to content

feat(torch): expose optional codegen parameters#619

Open
voltjia wants to merge 2 commits into
masterfrom
feat/torch-codegen-optional-overloads
Open

feat(torch): expose optional codegen parameters#619
voltjia wants to merge 2 commits into
masterfrom
feat/torch-codegen-optional-overloads

Conversation

@voltjia
Copy link
Copy Markdown
Collaborator

@voltjia voltjia commented May 20, 2026

Summary

  • Expose supported ATen optional parameters as stable InfiniOps C++ parameters in generated PyTorch operator bases.
  • Bind generated PyTorch backends to existing src/base/<op>.h overloads when available, forwarding omitted optional/default ATen parameters as typed defaults.
  • Add std::optional<T> support to operator cache hashing and update the generated torch-op test harness for optional arguments and known vendor-specific PyTorch crashes/divergences.
  • Add generator tests for optional parameter exposure and existing-base overload binding.

Motivation

The PyTorch code generator previously hid optional ATen schema parameters and always forwarded typed nullopt values. That made generated APIs unable to exercise non-default optional behavior and caused drift against operator base headers that intentionally expose optional parameters. This PR makes optional schema handling explicit while keeping existing hand-written bases as the public API source of truth when they are present.

Closes # N/A — this is follow-up work from the PyTorch codegen/base drift discussion.

Type of Change

  • feat — new feature / new operator / new platform
  • N/A — fix — bug fix.
  • N/A — perf — performance improvement (no behavioral change).
  • N/A — refactor — code restructuring without behavior change.
  • N/A — test — adding or fixing tests only.
  • N/A — docs — documentation only.
  • N/A — build / ci — build system or CI configuration.
  • N/A — chore — tooling, formatting, or other non-code changes.
  • N/A — Breaking change.

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • N/A — Build system / CMake / CI; no CMake or CI files are changed.
  • Python bindings / user-facing API

Test Results on Supported Platforms

All runs were rebased on current master, generated PyTorch operator sources before build, installed with WITH_TORCH=ON, and ran full verbose pytest as python3 -m pytest -v without tests/, --devices, or -n.

Platform Built pytest Result Build Time Test Time Total Time Notes / Hardware
NVIDIA Yes 6303 passed, 11538 skipped 1009s 341s 1350s PyTorch backend compiled and generated torch-op tests were included. Smoke confirmed generated classes and active implementations for representative torch ops.
Iluvatar Yes 4803 passed, 11520 skipped 798s 533s 1331s PyTorch backend compiled and generated torch-op tests were included. Smoke confirmed generated classes and active implementations for representative torch ops.
MetaX Yes 5803 passed, 10520 skipped 1392s 365s 1757s PyTorch backend compiled and generated torch-op tests were included. Smoke confirmed generated classes and active implementations for representative torch ops.
Cambricon Yes 3081 passed, 12858 skipped 2216s 922s 3138s PyTorch backend compiled and generated torch-op tests were included. Smoke confirmed generated classes and active implementations for representative torch ops.
Moore Yes 5767 passed, 10574 skipped 2194s 572s 2766s PyTorch backend compiled and generated torch-op tests were included. Smoke confirmed generated classes and active implementations for representative torch ops.
Ascend Yes 4480 passed, 11801 skipped 1112s 594s 1706s PyTorch backend compiled and generated torch-op tests were included. NPU availability was confirmed before build, and smoke confirmed active implementations for representative torch ops. The container exited with code 137 after pytest had already emitted a passing summary.
Validation details
python3 scripts/generate_torch_ops.py
python3 -m pip install --no-build-isolation --no-deps . \
  --config-settings=cmake.define.AUTO_DETECT_BACKENDS=OFF \
  --config-settings=cmake.define.WITH_CPU=ON \
  --config-settings=cmake.define.WITH_<PLATFORM>=ON \
  --config-settings=cmake.define.WITH_TORCH=ON
python3 -m pytest -v

Representative smoke checks after install confirmed generated PyTorch operator classes and active platform implementations for InternalSoftmax, LinalgDet, ClampMax, and SpecialPsi on every supported platform.

The test counts differ from earlier PR-body snapshots because this branch was rebased after splitting generated operator-base files into PR #622 and after switching generation to the local installed PyTorch native_functions.yaml. PyTorch-backed tests are still collected and executed on every platform.

python -m ruff format --check scripts/generate_torch_ops.py scripts/generate_wrappers.py tests/test_generate_torch_ops.py tests/test_generate_wrappers.py tests/test_torch_ops.py
python -m ruff check scripts/generate_torch_ops.py scripts/generate_wrappers.py tests/test_generate_torch_ops.py tests/test_generate_wrappers.py tests/test_torch_ops.py
clang-format --dry-run --Werror src/hash.h

All checks passed on the rebased branch.

Benchmark / Performance Impact

N/A — this PR changes generated API/backend plumbing and tests. The table above records build and test wall time for each platform to support follow-up compile-time optimization work.

Notes for Reviewers

  • Downstream PR feat(torch): add generated operator bases #622 was regenerated from this PR after the public C++ parameter-name fix (self remains an ATen schema name internally, while generated public C++ signatures use input) and passed full-platform validation with WITH_TORCH=ON. Those results are recorded on PR feat(torch): add generated operator bases #622 to avoid mixing downstream generated-base changes into this PR's own table.

  • Existing src/base/<op>.h overloads are treated as the public API when present. The generator binds compatible overloads to ATen schema parameters and fills omitted optional/default schema parameters at the ATen call site.

  • Generated fresh bases now expose supported optional types as std::optional<...>. PyTorch-internal optional types without stable InfiniOps representations remain hidden and are forwarded as typed empty optionals.

  • The generator reads the locally installed PyTorch torchgen packaged native_functions.yaml, so generated op availability follows the PyTorch schema available in the build environment.

  • The test harness skips only known vendor-kernel crashes/divergences that otherwise terminate the Python process or compare mismatched vendor paths; PyTorch-backed tests are still collected and executed on every platform.


Checklist

Title, Branch, and Commits

  • PR title follows Conventional Commits (e.g. feat(nvidia): …, fix(cuda/gemm): …).
  • Branch name follows <type>/xxx-yyyy-zzzz where <type> matches the PR title's Conventional Commits type and words are joined with hyphens (see CONTRIBUTING.md §Branches).
  • Each commit message follows Conventional Commits.
  • Small PR is a single squashable commit; or, for a large PR, every commit is meaningful, well-formed, and independently reviewable (see CONTRIBUTING.md §Pull Requests).
  • No stray merge commits from master — the branch is rebased cleanly on top of the current master.
  • No fixup! / squash! / wip commits remain.

Scope and Design

  • Changes are minimal — nothing unrelated to the stated motivation was added (CONTRIBUTING.md §Code/General).
  • No dead code, commented-out blocks, debug prints, printf/std::cout/print(...) left behind, or TODO without an owner and issue link.
  • No unrelated formatting churn that would obscure the diff.
  • Public API changes are intentional, documented in this PR, and reflected in affected callers/tests.

General Code Hygiene

  • The code is self-explanatory; comments were added only where the why is non-obvious (CONTRIBUTING.md §Code/General).
  • Every modified or added file ends with a single trailing newline (CONTRIBUTING.md §Code/General).
  • No trailing whitespace, tab/space mixing, or stray BOMs.
  • Identifiers in comments and error messages are wrapped in backticks (e.g. the `seqlens_k` tensor) (CONTRIBUTING.md §Code/General).
  • All comments and error messages are in English (CONTRIBUTING.md §Code/General).
  • Comments and error messages are complete sentences — capitalized first letter, terminal punctuation — unless the language/framework convention says otherwise (CONTRIBUTING.md §Code/General; §Python).

C++ Specific

  • Code follows the Google C++ Style Guide strictly.
  • clang-format --dry-run --Werror src/hash.h passes.
  • N/A — clang-tidy was not run; no kernel or algorithm implementation path is added.
  • Operator parameter order is inputs first, outputs last; attributes are between inputs and outputs; naming follows PyTorch → ONNX → CUDA API precedence (CONTRIBUTING.md §C++).
  • No exceptions are thrown. No new C++ error path was added.
  • N/A — No new C++ error or warning message was added.
  • N/A — No kernel files are added or renamed.
  • N/A — No kernel launcher files are added or changed.
  • Constructor initializer list order matches member declaration order (CONTRIBUTING.md §C++).
  • Exactly one blank line between classes, between classes and functions, and between functions (CONTRIBUTING.md §C++).
  • Exactly one blank line between members within a class (CONTRIBUTING.md §C++).
  • Exactly one blank line before and after the contents of a namespace (CONTRIBUTING.md §C++).
  • N/A — No new hand-written operator implementation is added under src/base/<op>.h or platform implementation directories.
  • No raw new/delete; RAII / smart pointers / existing allocators are used.

Python Specific

  • Code is PEP 8 compliant; ruff check passes cleanly.
  • ruff format --check passes cleanly.
  • Comments are complete English sentences, starting with a capital letter and ending with punctuation; Markdown backticks are used for code references (CONTRIBUTING.md §Python).
  • Framework-specific conventions are honored where applicable (CONTRIBUTING.md §Python).
  • No blank line between the function signature and the body when there is no docstring or comment (CONTRIBUTING.md §Python).
  • A blank line is present before and after if, for, and similar control-flow statements (CONTRIBUTING.md §Python).
  • A blank line appears before each return, except when it directly follows a control-flow statement like if or for (CONTRIBUTING.md §Python).
  • Docstrings follow PEP 257 conventions.
  • Type hints are added / kept consistent with the surrounding code.

Testing

  • Full-platform pytest was run on all supported platforms with WITH_TORCH=ON.
  • N/A — No platform was unreachable.
  • New functionality has matching tests under tests/.
  • Tests use pytest.mark.parametrize correctly.
  • N/A — pytest.mark.auto_act_and_assert is not used by the generator unit tests or generated torch-op harness touched here.
  • Default dtype / device parameterization is relied on, or overridden with an explicit pytest.mark.parametrize when necessary.
  • Known vendor-kernel crashes/divergences are skipped explicitly to keep the full run progressing.
  • N/A — This is a feature PR rather than a bug-fix regression test PR.

Build, CI, and Tooling

  • The project builds cleanly from a fresh directory on affected platforms.
  • compile_commands.json still regenerates through the existing CMake/scikit-build configuration path.
  • N/A — No new backend or device auto-detection is added.
  • Only one CUDA-like GPU backend is selectable at a time — the existing mutual-exclusion check in CMakeLists.txt is not changed.
  • ruff and clang-format checks are green.
  • No new runtime dependency was added without updating pyproject.toml's [project.optional-dependencies].

Documentation

  • N/A — No README, CONTRIBUTING, build flag, or developer workflow change is introduced.
  • N/A — No new operator, dispatch helper, or public utility is added outside generated code behavior.
  • N/A — No user-visible breaking change is intentionally introduced.

Security and Safety

  • No secrets, access tokens, internal URLs, customer data, IP addresses, or personal hardware identifiers have been committed or included in this PR description.
  • N/A — No third-party code is added.
  • No unsafe pointer arithmetic, uninitialized reads, or missing bounds checks were introduced.

@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch 5 times, most recently from 5e043a8 to d9714e7 Compare May 20, 2026 13:33
@voltjia voltjia marked this pull request as ready for review May 20, 2026 14:13
@voltjia voltjia requested a review from a team May 20, 2026 14:13
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from d9714e7 to d5b04cc Compare May 23, 2026 00:27
@voltjia
Copy link
Copy Markdown
Collaborator Author

voltjia commented May 23, 2026

Generated source/header archive for review:

  • Local workspace path: ci-results/pr619-generated-sources.tar.gz
  • Contents: generated/torch_ops_metadata.json, generated/torch/<op>/{<op>.h,<op>.cc}, and generated torch operator sources/headers.
  • Generation command used in a PyTorch-enabled container: rm -rf generated && python scripts/generate_torch_ops.py && tar -czf pr619-generated-sources.tar.gz generated

gh/GitHub REST does not provide a safe PR binary-attachment upload path, so I did not publish this archive to an external gist or unrelated branch.

@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from d5b04cc to 109b72f Compare May 23, 2026 12:51
@voltjia
Copy link
Copy Markdown
Collaborator Author

voltjia commented May 23, 2026

Additional compatibility validation for the latest commit (109b72f0):

  • Overlaid the current feat/torch-operator-bases src/base/ headers onto this PR branch and regenerated the PyTorch backend.
  • Generation completed: 603 overloads across 494 ops; incompatible existing base overloads are skipped with explicit warnings instead of generating mismatched code.
  • tests/test_generate_torch_ops.py: 6 passed.
  • ruff format --check and ruff check passed for the touched generator/test files.
  • WITH_TORCH=ON wheel build succeeded with the overlaid operator bases.
  • Installed the generated wheel and confirmed Std.active_implementation_indices("cpu") == [8] and BatchNormElemt.active_implementation_indices("cpu") == [8].
  • Direct smoke call through the PyTorch slot succeeded: ops.std(..., implementation_index=8) matched torch.std(...) with max error 0.0.

Remaining skipped existing base headers in the overlay are schema/name compatibility warnings rather than build failures: orgqr, ormqr, prod, quantile, slow_conv3d, slow_conv3d_forward, slow_conv_transpose2d, slow_conv_transpose3d, sort, and several upsample backward overloads. These appear to be base/schema drift cases and can be handled as follow-up base updates if those exact bases need generated Torch implementations.

@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from 109b72f to fd11775 Compare May 26, 2026 13:34
@voltjia
Copy link
Copy Markdown
Collaborator Author

voltjia commented May 26, 2026

Rebased onto current master and pushed fd117752.

Additional validation after rebase:

  • python3 -m py_compile scripts/generate_torch_ops.py scripts/generate_wrappers.py tests/test_generate_torch_ops.py tests/test_generate_wrappers.py tests/test_torch_ops.py passed locally.
  • git diff --check passed.
  • Remote container: ruff format --check and ruff check passed for touched Python files.
  • Remote container: python -m pytest tests/test_generate_torch_ops.py tests/test_generate_wrappers.py -q -> 7 passed.
  • Remote container: WITH_TORCH=ON wheel build passed.
  • Smoke test from built wheel passed for both generated clamp optional overload families:
    • scalar optional min/max: ops.clamp(x, -1.0, 1.0, out, implementation_index=8) matched torch.clamp.
    • tensor optional min/max: ops.clamp(x, lo, hi, out, implementation_index=8) matched torch.clamp.
    • Clamp.active_implementation_indices("cpu") == [8].

The rebase exposed a wrapper dispatch generation issue for overloads that reuse the same optional parameter names across scalar and Tensor variants (clamp/clip). This push fixes generated dispatch to keep optional scalar and optional Tensor overloads distinct, with a regression test in tests/test_generate_wrappers.py.

@voltjia
Copy link
Copy Markdown
Collaborator Author

voltjia commented May 27, 2026

Latest operator-bases overlay validation for fd117752:

  • Base branch overlaid: origin/feat/torch-operator-bases at 4c1bbcc9 (src/base/ only, 436 files).
  • Generation completed: 603 overloads across 494 ops, 1061 generated files.
  • Generator tests in the overlay passed: tests/test_generate_torch_ops.py tests/test_generate_wrappers.py -q -> 7 passed.
  • WITH_TORCH=ON wheel build passed from the overlay.
  • Smoke test from the overlay-built wheel passed:
    • Std.active_implementation_indices("cpu") == [8]
    • BatchNormElemt.active_implementation_indices("cpu") == [8]
    • ClampMax.active_implementation_indices("cpu") == [8]
    • ops.std(..., implementation_index=8) matched torch.std with max error 0.0.
    • ops.clamp_max scalar and Tensor overloads both matched torch.clamp_max.

Remaining generation warnings are existing base/schema drift cases that are skipped rather than emitted as broken code: orgqr, ormqr, prod, quantile, slow_conv3d, slow_conv3d_forward, slow_conv_transpose2d, slow_conv_transpose3d, sort, and several upsample backward overloads.

@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from fd11775 to 9444f9c Compare May 27, 2026 13:53
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch 3 times, most recently from c0db647 to 3e3e319 Compare May 27, 2026 21:15
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch 3 times, most recently from 9f591db to 70094a1 Compare May 28, 2026 07:41
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch 2 times, most recently from 327c65e to 87e86ab Compare May 28, 2026 11:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant