
Commit c0122a6

[Contrib] Fix CUDA contrib build after FFI/header cleanups
Six CUDA sources in src/runtime/contrib used LOG(FATAL) via transitive includes that #19483 trimmed; add an explicit #include <tvm/runtime/logging.h> to thrust.cu, attention_kernels.cu, and the four cutlass kernel headers (the fp16 group-GEMM runners for sm90 and sm100, gemm_runner, and fp8_groupwise_scaled_gemm). cache_kernels.cu used the bare Array{...} alias that #19483 removed; switch to ffi::Array<Tensor>{...}. attention_kernels.cu registered FFI functions whose parameters were raw DLTensor*; the new reflection registry requires a TypeSchema for each parameter, so wrap both TVM_FFI_STATIC_INIT_BLOCK registrations in lambdas that take Tensor and forward to the unchanged launchers via GetDLTensorPtr(), with const_cast for the output tensors, matching the mt_random_engine / cudnn pattern.
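
For reference, the forwarding pattern applied in attention_kernels.cu has roughly the shape sketched below. This is a minimal sketch, not code from this commit: example_launcher and the "tvm.contrib.example.launcher" name are hypothetical stand-ins for the unchanged DLTensor*-based launchers and their registered names.

#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/tensor.h>

namespace tvm {
namespace runtime {

// Hypothetical launcher standing in for the unchanged DLTensor*-based kernels.
void example_launcher(const DLTensor* input, int block_size, DLTensor* output);

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def(
      "tvm.contrib.example.launcher",
      // Typed Tensor parameters give the reflection registry a TypeSchema to record.
      [](Tensor input, int block_size, Tensor output) {
        // GetDLTensorPtr() recovers the raw DLTensor* the launcher expects; the
        // output goes through const_cast, as in the mt_random_engine / cudnn wrappers.
        example_launcher(input.GetDLTensorPtr(), block_size,
                         const_cast<DLTensor*>(output.GetDLTensorPtr()));
      });
}

}  // namespace runtime
}  // namespace tvm

The typed Tensor parameters give the registry the schema it needs, while the launchers themselves keep their raw-pointer signatures; only the registration glue changes.
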
1 parent c0406a5 commit c0122a6

7 files changed

Lines changed: 47 additions & 14 deletions


src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
  * under the License.
  */

+#include <tvm/runtime/logging.h>
+
 #include <fstream>
 #include <iostream>
 #include <sstream>

src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
  * under the License.
  */

+#include <tvm/runtime/logging.h>
+
 #include <fstream>
 #include <iostream>
 #include <sstream>

src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 #include <float.h>
 #include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
+#include <tvm/runtime/logging.h>
 #include <tvm/runtime/tensor.h>

 #include "cutlass/bfloat16.h"

src/runtime/contrib/cutlass/gemm_runner.cuh

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
  * under the License.
  */

+#include <tvm/runtime/logging.h>
+
 #include <fstream>
 #include <iostream>
 #include <sstream>

src/runtime/contrib/thrust/thrust.cu

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@
 #include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/logging.h>

 #include <algorithm>
 #include <functional>

src/runtime/contrib/vllm/attention_kernels.cu

Lines changed: 37 additions & 12 deletions
@@ -37,6 +37,7 @@
 #include <float.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/logging.h>
 #include <tvm/runtime/tensor.h>

 #include <algorithm>
@@ -756,10 +757,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def(
       "tvm.contrib.vllm.single_query_cached_kv_attention",
-      [](const DLTensor* query, const DLTensor* key_cache, const DLTensor* value_cache,
-         const DLTensor* block_tables, const DLTensor* context_lens, int block_size,
-         const DLTensor* max_context_len_tensor,  // TODO(masahi): pass integer
-         DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out, DLTensor* out) {
+      [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor block_tables,
+         Tensor context_lens, int block_size,
+         Tensor max_context_len_tensor,  // TODO(masahi): pass integer
+         Tensor exp_sums, Tensor max_logits, Tensor tmp_out, Tensor out) {
         int num_seqs = query->shape[0];
         int num_heads = query->shape[1];
         int max_context_len = static_cast<int*>(max_context_len_tensor->data)[0];
@@ -768,13 +769,19 @@ TVM_FFI_STATIC_INIT_BLOCK() {
         bool use_v1 =
             max_context_len <= 8192 && (max_num_partitions == 1 || num_seqs * num_heads > 512);
         if (use_v1) {
-          single_query_cached_kv_attention_v1(query, key_cache, value_cache, block_tables,
-                                              context_lens, block_size, max_context_len_tensor,
-                                              out);
+          single_query_cached_kv_attention_v1(
+              query.GetDLTensorPtr(), key_cache.GetDLTensorPtr(), value_cache.GetDLTensorPtr(),
+              block_tables.GetDLTensorPtr(), context_lens.GetDLTensorPtr(), block_size,
+              max_context_len_tensor.GetDLTensorPtr(), const_cast<DLTensor*>(out.GetDLTensorPtr()));
         } else {
-          single_query_cached_kv_attention_v2(query, key_cache, value_cache, block_tables,
-                                              context_lens, block_size, max_context_len_tensor,
-                                              exp_sums, max_logits, tmp_out, out);
+          single_query_cached_kv_attention_v2(
+              query.GetDLTensorPtr(), key_cache.GetDLTensorPtr(), value_cache.GetDLTensorPtr(),
+              block_tables.GetDLTensorPtr(), context_lens.GetDLTensorPtr(), block_size,
+              max_context_len_tensor.GetDLTensorPtr(),
+              const_cast<DLTensor*>(exp_sums.GetDLTensorPtr()),
+              const_cast<DLTensor*>(max_logits.GetDLTensorPtr()),
+              const_cast<DLTensor*>(tmp_out.GetDLTensorPtr()),
+              const_cast<DLTensor*>(out.GetDLTensorPtr()));
         }
       });
 }
@@ -784,9 +791,27 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("tvm.contrib.vllm.single_query_cached_kv_attention_v1",
-           single_query_cached_kv_attention_v1)
+           [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor block_tables,
+              Tensor context_lens, int block_size, Tensor max_context_len_tensor, Tensor out) {
+             single_query_cached_kv_attention_v1(
+                 query.GetDLTensorPtr(), key_cache.GetDLTensorPtr(), value_cache.GetDLTensorPtr(),
+                 block_tables.GetDLTensorPtr(), context_lens.GetDLTensorPtr(), block_size,
+                 max_context_len_tensor.GetDLTensorPtr(),
+                 const_cast<DLTensor*>(out.GetDLTensorPtr()));
+           })
       .def("tvm.contrib.vllm.single_query_cached_kv_attention_v2",
-           single_query_cached_kv_attention_v2);
+           [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor block_tables,
+              Tensor context_lens, int block_size, Tensor max_context_len_tensor, Tensor exp_sums,
+              Tensor max_logits, Tensor tmp_out, Tensor out) {
+             single_query_cached_kv_attention_v2(
+                 query.GetDLTensorPtr(), key_cache.GetDLTensorPtr(), value_cache.GetDLTensorPtr(),
+                 block_tables.GetDLTensorPtr(), context_lens.GetDLTensorPtr(), block_size,
+                 max_context_len_tensor.GetDLTensorPtr(),
+                 const_cast<DLTensor*>(exp_sums.GetDLTensorPtr()),
+                 const_cast<DLTensor*>(max_logits.GetDLTensorPtr()),
+                 const_cast<DLTensor*>(tmp_out.GetDLTensorPtr()),
+                 const_cast<DLTensor*>(out.GetDLTensorPtr()));
+           });
 }

 }  // namespace runtime

src/runtime/contrib/vllm/cache_kernels.cu

Lines changed: 2 additions & 2 deletions
@@ -154,7 +154,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
                 static_cast<const int*>(slot_mapping->data), key_stride, value_stride, num_heads,
                 head_size, block_size, vec_size);

-            return Array{key_cache, value_cache};
+            return ffi::Array<Tensor>{key_cache, value_cache};
           })
       .def("tvm.contrib.vllm.reconstruct_from_cache",
            [](Tensor key_cache, Tensor value_cache, Tensor slot_mapping) {
@@ -182,7 +182,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
                 static_cast<scalar_t*>(value->data), key_stride, value_stride, num_heads,
                 head_size, block_size, vec_size);

-            return Array{key, value};
+            return ffi::Array<Tensor>{key, value};
           })
       .def("tvm.contrib.vllm.copy_blocks", [](ffi::Array<Tensor> key_value_caches,
                                               Tensor block_mapping) {
