Commit eaa9569

Merge main into metal-gemm
2 parents ce6e4c2 + f3a550b commit eaa9569

468 files changed

Lines changed: 57865 additions & 10872 deletions

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
---
name: tilelang-tvm-ir
description: Use when editing TileLang C++ passes or TVM TIRX code that handles ObjectRef/NodeRef types such as For, Buffer, Var, SBlock, Stmt, PrimExpr, or their *Node raw node counterparts; especially when choosing function parameters, optional values, identity maps/sets, or equality checks.
---

# TileLang TVM IR Handle Conventions

## Core Rule

In TVM C++, `For`, `Buffer`, `Var`, `SBlock`, `Stmt`, `PrimExpr`, `SeqStmt`, etc. are `ObjectRef` smart handles. `ForNode`, `BufferNode`, `VarNode`, `SBlockNode`, etc. are raw node structs reached through visitor callbacks, `as<TNode>()`, `operator->`, or `.get()`.

When a value needs to cross a function boundary, be stored, be optional, be used as an identity key, or survive beyond a local inspection branch, prefer the handle type over `const *Node`.

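To make the handle/node distinction concrete without depending on TVM headers, here is a minimal toy model. `ToyFor`, `ToyForNode`, `LoopRecord`, and `Remember` are hypothetical stand-ins built on `std::shared_ptr`; they are not TVM's classes and only imitate the reference-counted ownership that an `ObjectRef` handle provides.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for a raw node struct such as ForNode (not TVM's type).
struct ToyForNode {
  std::string loop_var;
};

// Hypothetical stand-in for a handle such as For: it shares ownership of the node.
class ToyFor {
 public:
  ToyFor() = default;
  explicit ToyFor(std::string loop_var)
      : node_(std::make_shared<ToyForNode>(ToyForNode{std::move(loop_var)})) {}

  const ToyForNode* operator->() const { return node_.get(); }
  const ToyForNode* get() const { return node_.get(); }
  bool defined() const { return node_ != nullptr; }
  // Identity comparison: true only when both handles point at the same node.
  bool same_as(const ToyFor& other) const { return node_ == other.node_; }
  long use_count() const { return node_.use_count(); }

 private:
  std::shared_ptr<ToyForNode> node_;
};

// A value that outlives the local inspection site stores the handle,
// so the node stays alive for as long as the record does.
struct LoopRecord {
  ToyFor loop;
};

LoopRecord Remember(const ToyFor& loop) { return LoopRecord{loop}; }
```

In this sketch, storing a `ToyFor` inside `LoopRecord` raises the reference count, whereas storing `loop.get()` would leave a pointer that says nothing about the node's lifetime.
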
## Preferred Patterns

- Function parameters and return values: use handles such as `For`, `Buffer`, `Var`, `SBlock`, `Stmt`, or `SeqStmt`.
- Nullable AST values: use `Optional<For>`, `Optional<SeqStmt>`, etc., not `const ForNode* = nullptr`.
- Identity maps and sets: use handle keys with TVM identity hashing:

```cpp
using BufferSet = std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>;
using BufferMap = std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>;
using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
```

- Identity comparisons: use `.same_as(other)` when comparing two handles.
- Visitor callback node pointers: convert to a handle with `GetRef<T>(op)` when the value must be retained or passed elsewhere.

```cpp
Stmt VisitStmt_(const ForNode* op) final {
  For loop = GetRef<For>(op);
  Optional<For> candidate = FindPipelineLoop(loop->body);
  if (candidate.defined() && candidate.value().same_as(loop)) {
    ...
  }
}
```

- Pattern matching and local mutation may still use node pointers:
  - `if (const auto* seq = stmt.as<SeqStmtNode>()) { ... }`
  - `BufferStoreNode* n = store.CopyOnWrite();`
  - visitor overrides such as `VisitStmt_(const SeqStmtNode* op)`

Keep these raw pointers local to the immediate inspection or mutation site.

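The identity-keyed containers above can be illustrated with a small standalone sketch. It assumes nothing from TVM: `ToyBuffer`, `PtrHash`, `PtrEqual`, `ToyBufferSet`, and `MarkSeen` are hypothetical names, and the two functors only imitate the role that `ObjectPtrHash` and `ObjectPtrEqual` play in the aliases, namely hashing and comparing by node address instead of by content.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for a handle such as Buffer (not TVM's type).
struct ToyBufferNode {
  std::string name;
};
using ToyBuffer = std::shared_ptr<ToyBufferNode>;

// Hash by node address, so two handles to the same node collide on purpose.
struct PtrHash {
  std::size_t operator()(const ToyBuffer& b) const {
    return std::hash<const ToyBufferNode*>()(b.get());
  }
};

// Equal only when both handles refer to the same node, regardless of content.
struct PtrEqual {
  bool operator()(const ToyBuffer& a, const ToyBuffer& b) const {
    return a.get() == b.get();
  }
};

using ToyBufferSet = std::unordered_set<ToyBuffer, PtrHash, PtrEqual>;

// Returns true the first time a given node is seen, false for later aliases.
bool MarkSeen(ToyBufferSet* seen, const ToyBuffer& buffer) {
  return seen->insert(buffer).second;
}
```

Because the set stores handles, every key also keeps its node alive, which a set of raw `const ToyBufferNode*` keys would not do. Two distinct nodes with the same `name` remain distinct keys.
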
## Avoid

- Passing `const ForNode*`, `const BufferNode*`, `const VarNode*`, or `const SBlockNode*` between helper functions when a handle exists.
- Storing raw node pointers in `std::unordered_map` or `std::unordered_set` for identity tracking.
- Using `.get()` as a key unless a callee requires a raw TVM node API and the pointer is not retained.
- Comparing handles through `.get() == other.get()`; prefer `.same_as()`.
- Reconstructing handles from raw pointers repeatedly when a handle is already available.

## Common Refactors

```cpp
// Before
const SeqStmtNode* pipeline_body_seq = nullptr;
pipeline_body_seq = seq_stmt;
ICHECK(pipeline_body_seq != nullptr);

// After
Optional<SeqStmt> pipeline_body_seq;
pipeline_body_seq = GetRef<SeqStmt>(seq_stmt);
ICHECK(pipeline_body_seq.defined());
SeqStmt pipeline_body = pipeline_body_seq.value();
```

```cpp
// Before
std::unordered_set<const BufferNode*> seen;
seen.insert(buffer.get());
if (seen.count(read->buffer.get())) { ... }

// After
BufferSet seen;
seen.insert(buffer);
if (seen.count(read->buffer)) { ... }
```

```cpp
// Before
std::unordered_set<const VarNode*> vars;
vars.insert(loop->loop_var.get());
bool uses = UsesVar(expr, [&](const VarNode* vn) {
  return vars.count(vn) > 0;
});

// After
// VarSet is not among the aliases listed under Preferred Patterns; it is assumed
// here to be std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>, like BufferSet.
VarSet vars;
vars.insert(loop->loop_var);
bool uses = UsesVar(expr, [&](const VarNode* vn) {
  return vars.count(GetRef<Var>(vn)) > 0;
});
```

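The nullable-pointer refactor in the first example can also be tried outside TVM. The sketch below uses `std::optional` (C++17) over a `std::shared_ptr` handle as a rough analogue of `Optional<SeqStmt>`; `ToyStmt`, `ToyStmtNode`, and `FindFirstSeq` are hypothetical names, and TVM's `Optional<T>` has its own API (`defined()`, `value()`) that this does not reproduce.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <vector>

// Hypothetical stand-in for a statement handle (not TVM's Stmt).
struct ToyStmtNode {
  bool is_seq;
};
using ToyStmt = std::shared_ptr<ToyStmtNode>;

// Returns the first sequence-like statement as an owning handle,
// or std::nullopt when none exists. No raw pointer escapes the loop.
std::optional<ToyStmt> FindFirstSeq(const std::vector<ToyStmt>& stmts) {
  for (const ToyStmt& s : stmts) {
    if (s->is_seq) return s;
  }
  return std::nullopt;
}
```

The caller checks for presence explicitly and then holds a handle of its own, which mirrors the `defined()` check followed by `value()` in the refactored code.
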
## Review Checklist

When reviewing TileLang TIR passes, search for:

```bash
rg -n "std::unordered_(set|map)<const .*Node \\*|const (For|SeqStmt|SBlock).*Node \\*|\\.get\\(\\) ==|\\.find\\([^\\n]*\\.get\\(\\)|\\.count\\([^\\n]*\\.get\\(\\)|\\.insert\\([^\\n]*\\.get\\(\\)" src/transform
```

Do not mechanically remove every raw node pointer. Keep visitor signatures, `as<TNode>()` pattern checks, and `CopyOnWrite()` mutation pointers. Refactor only the places that store, pass, compare, or key identities through raw pointers.

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
@@ -44,19 +44,19 @@ jobs:
           fetch-depth: 0
           submodules: recursive
 
-      - name: Setup Python 3.9
+      - name: Setup Python 3.10
         id: setup-pylowest
         uses: actions/setup-python@v6
         with:
-          python-version: "3.9"
+          python-version: "3.10"
           update-environment: true
           cache: pip
           cache-dependency-path: |
             pyproject.toml
             requirements*.txt
             .pre-commit-config.yaml
 
-      - name: Check AST with Python 3.9
+      - name: Check AST with Python 3.10
         run: |
           "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang
 

.github/workflows/dist.yml

Lines changed: 5 additions & 7 deletions
@@ -18,8 +18,7 @@ on:
       - CMakeLists.txt
       - version_provider.py
       - .github/workflows/dist.yml
-      # temporarily add to dist check
-      # until we have type checking in ci / move to python 3.10
+      # Type aliases can affect package import/build behavior.
       - tilelang/_typing.py
   release:
     types:
@@ -115,12 +114,11 @@ jobs:
     strategy:
       matrix:
         target:
-          # Build wheels for different Python ABIs.
-          # Windows CUDA 13.0 uses cp310 because PyTorch cu130 does not publish cp39 wheels.
-          - { runner: ubuntu-latest, toolkit: "CUDA-12.8", test_backends: "cu118 cu130", python_version: "3.9" }
-          - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8", test_backends: "cu126 cu130", python_version: "3.9" }
+          # Build wheels for the minimum supported Python ABI.
+          - { runner: ubuntu-latest, toolkit: "CUDA-12.8", test_backends: "cu118 cu130", python_version: "3.10" }
+          - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8", test_backends: "cu126 cu130", python_version: "3.10" }
           - { runner: windows-latest, toolkit: "CUDA-13.0", test_backends: "cu130", python_version: "3.10" }
-          - { runner: macos-latest, toolkit: "Metal", python_version: "3.9" }
+          - { runner: macos-latest, toolkit: "Metal", python_version: "3.10" }
           # - "3.14t" # let user to build from source for now
           # TODO: Add cp315-abi3.abi3t after PEP 803
       fail-fast: false
