Skip to content

Commit a7bfc85

Browse files
[FIX] Inline ceil_log2 in gpu_2d_continuous_cumsum to fix MakePackedAPI error (#18957)
- The intermediate variable `ceil_log2` in `gpu_2d_continuous_cumsum` created a `LetStmt`-bound `Var` in the TIR function - When `MakePackedAPI` processed the function, it reported `ceil_log2` as an undefined variable not passed as an API argument - Inline the expression directly into `total_rounds` to avoid the intermediate `Var` — the computation is identical ## Test plan - Compile a model that uses GPU sampling (e.g. any LLM with top-p sampling on Metal) and verify compilation succeeds - The error this fixes: `Check failed: undefined.size() == 0: In PrimFunc gpu_2d_continuous_cumsum variables [ceil_log2] are used, but are not passed in as API arguments` Co-authored-by: Akaash Parthasarathy <43900735+akaashrp@users.noreply.github.com>
1 parent ec0daad commit a7bfc85

1 file changed

Lines changed: 1 addition & 2 deletions

File tree

  • python/tvm/relax/backend/gpu_generic

python/tvm/relax/backend/gpu_generic/cumsum.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,7 @@ def cumsum(var_a: T.handle, var_out: T.handle):
159159
A = T.match_buffer(var_a, [m, n], dtype=in_dtype)
160160
Out = T.match_buffer(var_out, [m, n], dtype=out_dtype)
161161
Tmp = T.alloc_buffer([m, n], dtype=out_dtype)
162-
ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
163-
total_rounds = ceil_log2 // LOG_BLOCK_N
162+
total_rounds = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n)))) // LOG_BLOCK_N
164163

165164
block_inclusive_inside_block(
166165
m, n, A, Out, Tmp, src_offset=T.int64(0), tmp_offset=T.int64(0)

0 commit comments

Comments
 (0)