|
| 1 | +--- |
| 2 | +name: tilelang-tvm-ir |
| 3 | +description: Use when editing TileLang C++ passes or TVM TIRX code that handles ObjectRef/NodeRef types such as For, Buffer, Var, SBlock, Stmt, PrimExpr, or their *Node raw node counterparts; especially when choosing function parameters, optional values, identity maps/sets, or equality checks. |
| 4 | +--- |
| 5 | + |
| 6 | +# TileLang TVM IR Handle Conventions |
| 7 | + |
| 8 | +## Core Rule |
| 9 | + |
| 10 | +In TVM C++, `For`, `Buffer`, `Var`, `SBlock`, `Stmt`, `PrimExpr`, `SeqStmt`, etc. are `ObjectRef` smart handles. `ForNode`, `BufferNode`, `VarNode`, `SBlockNode`, etc. are raw node structs reached through visitor callbacks, `as<TNode>()`, `operator->`, or `.get()`. |
| 11 | + |
| 12 | +When a value needs to cross a function boundary, be stored, be optional, be used as an identity key, or survive beyond a local inspection branch, prefer the handle type over `const *Node`. |
| 13 | + |
| 14 | +## Preferred Patterns |
| 15 | + |
| 16 | +- Function parameters and return values: use handles such as `For`, `Buffer`, `Var`, `SBlock`, `Stmt`, or `SeqStmt`. |
| 17 | +- Nullable AST values: use `Optional<For>`, `Optional<SeqStmt>`, etc., not `const ForNode* = nullptr`. |
| 18 | +- Identity maps and sets: use handle keys with TVM identity hashing: |
| 19 | + |
| 20 | +```cpp |
| 21 | +using BufferSet = std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>; |
| 22 | +using BufferMap = std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>; |
| 23 | +using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>; |
| 24 | +``` |
| 25 | + |
| 26 | +- Identity comparisons: use `.same_as(other)` when comparing two handles. |
| 27 | +- Visitor callback node pointers: convert to a handle with `GetRef<T>(op)` when the value must be retained or passed elsewhere. |
| 28 | + |
| 29 | +```cpp |
| 30 | +Stmt VisitStmt_(const ForNode* op) final { |
| 31 | + For loop = GetRef<For>(op); |
| 32 | + Optional<For> candidate = FindPipelineLoop(loop->body); |
| 33 | + if (candidate.defined() && candidate.value().same_as(loop)) { |
| 34 | + ... |
| 35 | + } |
| 36 | +} |
| 37 | +``` |
| 38 | +
|
| 39 | +- Pattern matching and local mutation may still use node pointers: |
| 40 | + - `if (const auto* seq = stmt.as<SeqStmtNode>()) { ... }` |
| 41 | + - `BufferStoreNode* n = store.CopyOnWrite();` |
| 42 | + - visitor overrides such as `VisitStmt_(const SeqStmtNode* op)` |
| 43 | +
|
| 44 | +Keep these raw pointers local to the immediate inspection or mutation site. |
| 45 | +
|
| 46 | +## Avoid |
| 47 | +
|
| 48 | +- Passing `const ForNode*`, `const BufferNode*`, `const VarNode*`, or `const SBlockNode*` between helper functions when a handle exists. |
| 49 | +- Storing raw node pointers in `std::unordered_map` or `std::unordered_set` for identity tracking. |
| 50 | +- Using `.get()` as a key unless a callee requires a raw TVM node API and the pointer is not retained. |
| 51 | +- Comparing handles through `.get() == other.get()`; prefer `.same_as()`. |
| 52 | +- Reconstructing handles from raw pointers repeatedly when a handle is already available. |
| 53 | +
|
| 54 | +## Common Refactors |
| 55 | +
|
| 56 | +```cpp |
| 57 | +// Before |
| 58 | +const SeqStmtNode* pipeline_body_seq = nullptr; |
| 59 | +pipeline_body_seq = seq_stmt; |
| 60 | +ICHECK(pipeline_body_seq != nullptr); |
| 61 | +
|
| 62 | +// After |
| 63 | +Optional<SeqStmt> pipeline_body_seq; |
| 64 | +pipeline_body_seq = GetRef<SeqStmt>(seq_stmt); |
| 65 | +ICHECK(pipeline_body_seq.defined()); |
| 66 | +SeqStmt pipeline_body = pipeline_body_seq.value(); |
| 67 | +``` |
| 68 | + |
| 69 | +```cpp |
| 70 | +// Before |
| 71 | +std::unordered_set<const BufferNode*> seen; |
| 72 | +seen.insert(buffer.get()); |
| 73 | +if (seen.count(read->buffer.get())) { ... } |
| 74 | + |
| 75 | +// After |
| 76 | +BufferSet seen; |
| 77 | +seen.insert(buffer); |
| 78 | +if (seen.count(read->buffer)) { ... } |
| 79 | +``` |
| 80 | +
|
| 81 | +```cpp |
| 82 | +// Before |
| 83 | +std::unordered_set<const VarNode*> vars; |
| 84 | +vars.insert(loop->loop_var.get()); |
| 85 | +bool uses = UsesVar(expr, [&](const VarNode* vn) { |
| 86 | + return vars.count(vn) > 0; |
| 87 | +}); |
| 88 | +
|
| 89 | +// After |
| 90 | +VarSet vars; |
| 91 | +vars.insert(loop->loop_var); |
| 92 | +bool uses = UsesVar(expr, [&](const VarNode* vn) { |
| 93 | + return vars.count(GetRef<Var>(vn)) > 0; |
| 94 | +}); |
| 95 | +``` |
| 96 | + |
| 97 | +## Review Checklist |
| 98 | + |
| 99 | +When reviewing TileLang TIR passes, search for: |
| 100 | + |
| 101 | +```bash |
| 102 | +rg -n "std::unordered_(set|map)<const .*Node \\*|const (For|SeqStmt|SBlock).*Node \\*|\\.get\\(\\) ==|\\.find\\([^\\n]*\\.get\\(\\)|\\.count\\([^\\n]*\\.get\\(\\)|\\.insert\\([^\\n]*\\.get\\(\\)" src/transform |
| 103 | +``` |
| 104 | + |
| 105 | +Do not mechanically remove every raw node pointer. Keep visitor signatures, `as<TNode>()` pattern checks, and `CopyOnWrite()` mutation pointers. Refactor only the places that store, pass, compare, or key identities through raw pointers. |
0 commit comments