Skip to content

Commit 9c4d486

Browse files
committed
fix rebased
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 7870e55 commit 9c4d486

File tree

6 files changed

+24
-57
lines changed

6 files changed

+24
-57
lines changed

include/infinicore/adaptor/aten_adaptor.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
#include <ATen/ATen.h>
77

8-
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
8+
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
99
#include <c10/cuda/CUDAStream.h>
1010
#endif
1111

@@ -32,7 +32,9 @@ inline at::ScalarType to_at_dtype(DataType dtype) {
3232
}
3333

3434
inline at::Device to_at_device(const Device &device) {
35-
if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX) {
35+
// PyTorch ATen only exposes standard device types (e.g. kCPU/kCUDA).
36+
// Treat MetaX/QY devices as CUDA devices for ATen tensor interoperability.
37+
if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX || device.getType() == Device::Type::QY) {
3638
return at::Device(at::kCUDA, device.getIndex());
3739
} else if (device.getType() == Device::Type::CPU) {
3840
return at::Device(at::kCPU);
@@ -43,7 +45,7 @@ inline at::Device to_at_device(const Device &device) {
4345

4446
at::Tensor to_aten_tensor(const infinicore::Tensor &t);
4547

46-
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
48+
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
4749
c10::cuda::CUDAStream get_cuda_stream();
4850
#endif
4951
} // namespace infinicore::adaptor

src/infinicore/adaptor/aten_adaptor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
3232
options);
3333
}
3434

35-
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
35+
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
3636
c10::cuda::CUDAStream get_cuda_stream() {
3737
return c10::cuda::getStreamFromExternal(
3838
cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());

src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,22 @@ void run(void *planned_meta) {
5252
#endif
5353
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
5454

55-
// FlashAttention kernels expect standard dense layout (contiguous last dimension).
55+
// Paged KV caches must be contiguous for flash-attn; avoid extra copies for q/metadata when already dense.
5656
auto out_at = infinicore::adaptor::to_aten_tensor(p->out);
5757
const bool out_need_copy_back = !out_at.is_contiguous();
5858
auto out_tensor = out_need_copy_back ? out_at.contiguous() : out_at;
59-
auto q = infinicore::adaptor::to_aten_tensor(p->q).contiguous();
59+
auto q = infinicore::adaptor::to_aten_tensor(p->q);
60+
#if defined(ENABLE_NVIDIA_API)
61+
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
62+
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
63+
#elif defined(ENABLE_QY_API) || defined(ENABLE_METAX_API)
6064
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous();
6165
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous();
62-
auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k).contiguous());
63-
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table).contiguous());
66+
#endif
67+
auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
68+
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
6469
auto alibi_slopes = p->alibi_slopes
65-
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous())
70+
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
6671
: std::nullopt;
6772

6873
std::optional<const at::Tensor> k_new = std::nullopt;

src/infiniop/ops/equal/metax/equal_metax.maca

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,10 @@
11
#include "equal_metax.h"
22

33
#include "../../../elementwise/metax/elementwise_metax.h"
4-
#include <type_traits>
54

6-
namespace op::equal::metax {
7-
8-
struct EqualOp {
9-
static constexpr size_t num_inputs = 2;
5+
#include "../cuda/kernel.cuh"
106

11-
template <typename Tout, typename Tin0, typename Tin1>
12-
__device__ __forceinline__ bool operator()(const Tin0 &a, const Tin1 &b) const {
13-
if constexpr (std::is_same_v<Tin0, Tin1>) {
14-
return static_cast<Tout>(a == b);
15-
} else {
16-
return false;
17-
}
18-
}
19-
};
7+
namespace op::equal::metax {
208

219
Descriptor::~Descriptor() = default;
2210

@@ -25,54 +13,44 @@ infiniStatus_t Descriptor::create(
2513
Descriptor **desc_ptr,
2614
infiniopTensorDescriptor_t out_desc,
2715
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
28-
2916
auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
30-
3117
const auto &a_desc = input_desc_vec.at(0);
3218
auto compute_dtype = a_desc->dtype();
3319
auto out_dtype = out_desc->dtype();
34-
3520
const auto &b_desc = input_desc_vec.at(1);
3621
const auto &c_shape = out_desc->shape();
3722
const auto &a_shape = a_desc->shape();
3823
const auto &b_shape = b_desc->shape();
39-
4024
CHECK_DTYPE(compute_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16,
4125
INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_F64);
42-
4326
CHECK_DTYPE(out_dtype, INFINI_DTYPE_BOOL);
44-
4527
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
46-
4728
CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, compute_dtype, out_desc, input_desc_vec)
48-
4929
return INFINI_STATUS_SUCCESS;
5030
}
51-
5231
infiniStatus_t Descriptor::calculate(
5332
void *workspace,
5433
size_t workspace_size,
5534
void *output,
5635
std::vector<const void *> inputs,
5736
void *stream) const {
58-
5937
if (workspace_size < _workspace_size) {
6038
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
6139
}
6240

6341
switch (_dtype) {
6442
case INFINI_DTYPE_F16:
65-
return _device_info->calculate<256, EqualOp, bool, half, half>(_info, workspace, output, inputs, stream);
43+
return _device_info->calculate<256, cuda::EqualOp, bool, half, half>(_info, workspace, output, inputs, stream);
6644
case INFINI_DTYPE_BF16:
67-
return _device_info->calculate<256, EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream);
45+
return _device_info->calculate<256, cuda::EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream);
6846
case INFINI_DTYPE_F32:
69-
return _device_info->calculate<256, EqualOp, bool, float, float>(_info, workspace, output, inputs, stream);
47+
return _device_info->calculate<256, cuda::EqualOp, bool, float, float>(_info, workspace, output, inputs, stream);
7048
case INFINI_DTYPE_I32:
71-
return _device_info->calculate<256, EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream);
49+
return _device_info->calculate<256, cuda::EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream);
7250
case INFINI_DTYPE_I64:
73-
return _device_info->calculate<256, EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream);
51+
return _device_info->calculate<256, cuda::EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream);
7452
case INFINI_DTYPE_F64:
75-
return _device_info->calculate<256, EqualOp, bool, double, double>(_info, workspace, output, inputs, stream);
53+
return _device_info->calculate<256, cuda::EqualOp, bool, double, double>(_info, workspace, output, inputs, stream);
7654
default:
7755
return INFINI_STATUS_BAD_TENSOR_DTYPE;
7856
}

src/infiniop/ops/hardswish/cuda/kernel.cuh

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,6 @@
22
#define __HARDSWISH_CUDA_H__
33

44
#include <cmath>
5-
#if defined(ENABLE_METAX_API)
6-
#include <hcr/hc_runtime_api.h>
7-
#elif defined(__MACACC__)
8-
#include <maca_bfloat16.h>
9-
#include <maca_fp16.h>
10-
#else
11-
#include <cuda_bf16.h>
12-
#include <cuda_fp16.h>
13-
#endif
145

156
namespace op::hardswish::cuda {
167

src/infiniop/ops/hardtanh/cuda/kernel.cuh

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
#ifndef __HARDTANH_CUDA_H__
22
#define __HARDTANH_CUDA_H__
33

4-
#if defined(ENABLE_METAX_API)
5-
#include <hcr/hc_runtime_api.h>
6-
#elif defined(__MACACC__)
7-
#include <maca_bfloat16.h>
8-
#include <maca_fp16.h>
9-
#else
10-
#include <cuda_bf16.h>
11-
#include <cuda_fp16.h>
12-
#endif
134
#include <type_traits>
145

156
namespace op::hardtanh::cuda {

0 commit comments

Comments
 (0)