Skip to content

Commit dd12102

Browse files
committed
issue/889 - added interface definitions
1 parent 3b5afff commit dd12102

23 files changed

Lines changed: 1067 additions & 69 deletions

include/infinicore/ops.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44
#include "ops/add_rms_norm.hpp"
55
#include "ops/attention.hpp"
66
#include "ops/causal_softmax.hpp"
7+
#include "ops/embedding.hpp"
8+
#include "ops/flash_attention.hpp"
9+
#include "ops/kv_caching.hpp"
710
#include "ops/matmul.hpp"
811
#include "ops/ones.hpp"
912
#include "ops/paged_attention.hpp"
1013
#include "ops/paged_attention_prefill.hpp"
1114
#include "ops/paged_caching.hpp"
1215
#include "ops/random_sample.hpp"
16+
#include "ops/random_sample_batched.hpp"
1317
#include "ops/rearrange.hpp"
1418
#include "ops/rms_norm.hpp"
1519
#include "ops/rope.hpp"

include/infinicore/ops/embedding.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55
namespace infinicore::op {
66

7+
class Embedding {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor);
10+
static void execute(Tensor out, Tensor input, Tensor weight);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
714
Tensor embedding(Tensor input, Tensor weight);
815
void embedding_(Tensor out, Tensor input, Tensor weight);
916
} // namespace infinicore::op
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class FlashAttention {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, bool, size_t);
10+
static void execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal, size_t pos);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal, size_t pos);
15+
void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal, size_t pos);
16+
} // namespace infinicore::op
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#pragma
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class KVCaching {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int);
10+
static void execute(Tensor k_cache,
11+
Tensor v_cache,
12+
Tensor k,
13+
Tensor v,
14+
Tensor offsets,
15+
Tensor cache_lengths,
16+
Tensor cache_ids,
17+
int max_cache_size);
18+
static common::OpDispatcher<schema> &dispatcher();
19+
};
20+
21+
Tensor kv_caching(Tensor k_cache,
22+
Tensor v_cache,
23+
Tensor k,
24+
Tensor v,
25+
Tensor offsets,
26+
Tensor cache_lengths,
27+
Tensor cache_ids,
28+
int max_cache_size);
29+
void kv_caching_(Tensor k_cache,
30+
Tensor v_cache,
31+
Tensor k,
32+
Tensor v,
33+
Tensor offsets,
34+
Tensor cache_lengths,
35+
Tensor cache_ids,
36+
int max_cache_size);
37+
} // namespace infinicore::op
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
8+
class RandomSampleBatched {
9+
public:
10+
using schema = void (*)(Tensor, Tensor, const float *, const float *, const int *, const float *, int);
11+
static void execute(Tensor result, Tensor probs, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
12+
static common::OpDispatcher<schema> &dispatcher();
13+
};
14+
15+
// Out-of-place API
16+
Tensor random_sample_batched(Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
17+
// In-place API
18+
void random_sample_batched_(Tensor indices, Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
19+
20+
} // namespace infinicore::op

include/infiniop.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
#include "infiniop/ops/clip.h"
1010
#include "infiniop/ops/conv.h"
1111
#include "infiniop/ops/dequantize_awq.h"
12+
#include "infiniop/ops/embedding.h"
13+
#include "infiniop/ops/flash_attention.h"
1214
#include "infiniop/ops/gelu.h"
1315
#include "infiniop/ops/gemm.h"
16+
#include "infiniop/ops/kv_caching.h"
1417
#include "infiniop/ops/layer_norm.h"
1518
#include "infiniop/ops/logsoftmax.h"
1619
#include "infiniop/ops/lp_norm.h"
@@ -20,6 +23,7 @@
2023
#include "infiniop/ops/paged_attention_prefill.h"
2124
#include "infiniop/ops/paged_caching.h"
2225
#include "infiniop/ops/random_sample.h"
26+
#include "infiniop/ops/random_sample_batched.h"
2327
#include "infiniop/ops/rearrange.h"
2428
#include "infiniop/ops/relu.h"
2529
#include "infiniop/ops/rms_norm.h"

include/infiniop/ops/embedding.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef __INFINIOP_EMBEDDING_API_H__
2+
#define __INFINIOP_EMBEDDING_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopEmbeddingDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t output_desc,
12+
infiniopTensorDescriptor_t input_desc,
13+
infiniopTensorDescriptor_t weight_desc);
14+
15+
__C __export infiniStatus_t infiniopEmbedding(
16+
infiniopEmbeddingDescriptor_t desc,
17+
void *output,
18+
const void *input,
19+
const void *weight,
20+
void *stream);
21+
22+
__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor(
23+
infiniopEmbeddingDescriptor_t desc);
24+
25+
#endif
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
2+
#define __INFINIOP_FLASH_ATTENTION_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopFlashAttentionDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t out_desc,
12+
infiniopTensorDescriptor_t q_desc,
13+
infiniopTensorDescriptor_t k_desc,
14+
infiniopTensorDescriptor_t v_desc,
15+
float scale,
16+
char is_causal,
17+
size_t pos);
18+
19+
__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
20+
infiniopFlashAttentionDescriptor_t desc,
21+
size_t *size);
22+
23+
__C __export infiniStatus_t infiniopFlashAttention(
24+
infiniopFlashAttentionDescriptor_t desc,
25+
void *workspace,
26+
size_t workspace_size,
27+
void *out,
28+
const void *q,
29+
const void *k,
30+
const void *v,
31+
void *stream);
32+
33+
__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
34+
infiniopFlashAttentionDescriptor_t desc);
35+
#endif

include/infiniop/ops/kv_caching.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#ifndef __INFINIOP_KV_CACHING_API_H__
2+
#define __INFINIOP_KV_CACHING_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateKVCachingDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopKVCachingDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t k_cache,
12+
infiniopTensorDescriptor_t v_cache,
13+
infiniopTensorDescriptor_t k,
14+
infiniopTensorDescriptor_t v,
15+
infiniopTensorDescriptor_t offsets,
16+
infiniopTensorDescriptor_t cache_lengths,
17+
infiniopTensorDescriptor_t cache_ids,
18+
int max_cache_size);
19+
20+
__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size);
21+
22+
__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc,
23+
void *workspace,
24+
size_t workspace_size,
25+
const void *k_src,
26+
const void *v_src,
27+
void *k_dst,
28+
void *v_dst,
29+
void *stream);
30+
31+
__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc);
32+
33+
#endif

include/infiniop/ops/random_sample.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
1515
infiniopRandomSampleDescriptor_t desc,
1616
size_t *size);
1717

18-
__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor(
19-
infiniopHandle_t handle,
20-
infiniopRandomSampleDescriptor_t *desc_ptr,
21-
infiniopTensorDescriptor_t result,
22-
infiniopTensorDescriptor_t probs);
23-
2418
__C __export infiniStatus_t infiniopRandomSample(
2519
infiniopRandomSampleDescriptor_t desc,
2620
void *workspace,

0 commit comments

Comments
 (0)