Skip to content

Commit fb8eba4

Browse files
committed
issue/1124: minicpm-sala model
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 6e88052 commit fb8eba4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+3574
-21
lines changed

.gitmodules

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,7 @@
55
path = third_party/nlohmann_json
66
url = https://github.com/nlohmann/json.git
77
branch = master
8+
[submodule "third_party/infllmv2_cuda_impl"]
9+
path = third_party/infllmv2_cuda_impl
10+
url = https://github.com/Ceng23333/infllmv2_cuda_impl.git
11+
branch = minicpm_sala_patches

include/infinicore/ops.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,20 @@
1919
#include "ops/flash_attention.hpp"
2020
#include "ops/fmin.hpp"
2121
#include "ops/fmod.hpp"
22+
#include "ops/simple_gla_attention.hpp"
23+
#include "ops/simple_gla_decode_step.hpp"
24+
#include "ops/simple_gla_recurrent_state_append.hpp"
25+
#include "ops/simple_gla_prefill.hpp"
26+
#include "ops/infllmv2_attention.hpp"
2227
#include "ops/hardswish.hpp"
2328
#include "ops/hardtanh.hpp"
2429
#include "ops/kv_caching.hpp"
2530
#include "ops/matmul.hpp"
31+
#include "ops/mha_kvcache.hpp"
32+
#include "ops/mha_varlen.hpp"
33+
#include "ops/mul.hpp"
2634
#include "ops/ones.hpp"
35+
#include "ops/zeros.hpp"
2736
#include "ops/paged_attention.hpp"
2837
#include "ops/paged_attention_prefill.hpp"
2938
#include "ops/paged_caching.hpp"
@@ -34,6 +43,7 @@
3443
#include "ops/reciprocal.hpp"
3544
#include "ops/rms_norm.hpp"
3645
#include "ops/rope.hpp"
46+
#include "ops/sigmoid.hpp"
3747
#include "ops/silu.hpp"
3848
#include "ops/silu_and_mul.hpp"
3949
#include "ops/swiglu.hpp"
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/**
 * C++ API declarations for the InfLLM-V2 attention kernels.
 *
 * When ENABLE_INFLLMV2 is defined, the build must link against an
 * InfLLM-V2 library (e.g. infllmv2_cuda_impl) that defines these symbols.
 * ENABLE_ATEN is additionally required because the interface is expressed
 * in terms of at::Tensor.
 *
 * NOTE: the declarations deliberately live in the global namespace so
 * they match the definitions in entry.cu.
 */
#pragma once

#if defined(ENABLE_INFLLMV2) && defined(ENABLE_ATEN)

#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <vector>

/// Variable-length forward pass over unpadded Q/K/V, with sequences
/// delimited by the cu_seqlens_q / cu_seqlens_k prefix sums.
/// Returns {out, softmax_lse, ...}.
std::vector<at::Tensor> mha_varlen_fwd(
    at::Tensor &q,                                // [total_q, nheads, head_dim]; non-const per library interface
    const at::Tensor &k,                          // [total_k, nheads_k, head_dim]
    const at::Tensor &v,                          // [total_k, nheads_k, head_dim]
    c10::optional<at::Tensor> &out_,              // optional pre-allocated output buffer
    const at::Tensor &cu_seqlens_q,               // [batch + 1] query offsets
    const at::Tensor &cu_seqlens_k,               // [batch + 1] key offsets
    c10::optional<at::Tensor> &seqused_k,         // per-sequence used-K lengths
    c10::optional<const at::Tensor> &leftpad_k_,  // per-sequence left padding of K
    c10::optional<at::Tensor> &block_table_,      // paged-KV block table
    c10::optional<at::Tensor> &alibi_slopes_,     // ALiBi bias slopes
    int max_seqlen_q,
    int max_seqlen_k,
    float p_dropout,
    float softmax_scale,
    bool zero_tensors,
    bool is_causal,
    int window_size_left,                         // sliding-window bounds; -1 presumably disables (FlashAttn convention) — confirm
    int window_size_right,
    float softcap,
    bool return_softmax,
    c10::optional<at::Generator> gen_,            // RNG for dropout
    c10::optional<at::Tensor> &blockmask_);       // presumably the InfLLM-V2 block-sparsity mask — confirm against entry.cu

/// KV-cache forward pass (decode). Returns {out, softmax_lse}.
std::vector<at::Tensor> mha_fwd_kvcache(
    at::Tensor &q,                                     // query; non-const per library interface
    const at::Tensor &kcache,                          // cached keys
    const at::Tensor &vcache,                          // cached values
    c10::optional<const at::Tensor> &k_,               // new K to append, if any
    c10::optional<const at::Tensor> &v_,               // new V to append, if any
    c10::optional<const at::Tensor> &seqlens_k_,       // current KV lengths per sequence
    c10::optional<const at::Tensor> &rotary_cos_,      // RoPE cos table
    c10::optional<const at::Tensor> &rotary_sin_,      // RoPE sin table
    c10::optional<const at::Tensor> &cache_batch_idx_, // indirection into the cache batch dim
    c10::optional<const at::Tensor> &leftpad_k_,
    c10::optional<at::Tensor> &block_table_,           // paged-KV block table
    c10::optional<at::Tensor> &alibi_slopes_,
    c10::optional<at::Tensor> &out_,                   // optional pre-allocated output buffer
    float softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    float softcap,
    bool is_rotary_interleaved,
    int num_splits,                                    // split-KV parallelism factor
    c10::optional<at::Tensor> &blockmask_);            // presumably the InfLLM-V2 block-sparsity mask — confirm against entry.cu

#endif // ENABLE_INFLLMV2 && ENABLE_ATEN
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

/// Varlen InfLLM-V2 attention over unpadded Q/K/V.
///
/// Shapes follow the FlashAttn-style varlen convention:
///   q            : [total_q, nheads, head_dim]
///   k, v         : [total_k, nheads_k, head_dim]
///   cu_seqlens_q : [batch_size + 1] (int32)
///   cu_seqlens_k : [batch_size + 1] (int32)
///
/// @return [total_q, nheads, head_dim]
Tensor infllmv2_varlen(const Tensor &q, const Tensor &k, const Tensor &v,
                       const Tensor &cu_seqlens_q, const Tensor &cu_seqlens_k,
                       int max_seqlen_q, int max_seqlen_k,
                       float scale, bool causal,
                       int window_size_left = -1, int window_size_right = -1);

/// Decode-time InfLLM-V2 attention with a KV cache.
///
/// Shapes:
///   q          : [batch, seqlen_q, nheads, head_dim]
///   k_cache    : [num_blocks, block_size, nheads_k, head_dim]
///                or [batch, seqlen_cache, nheads_k, head_dim]
///   v_cache    : same layout as k_cache
///   cache_lens : [batch] (int32) total KV length per sequence
///
/// @return [batch, seqlen_q, nheads, head_dim]
Tensor infllmv2_kvcache(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
                        const Tensor &cache_lens,
                        float scale, bool causal,
                        int window_size_left = -1, int window_size_right = -1);

/// Decode-time InfLLM-V2 attention with a KV cache, updating the cache
/// in place (k_new/v_new are appended at the cache_lens offsets).
///
/// Shapes:
///   q            : [batch, seqlen_q, nheads, head_dim]
///   k_cache      : [batch, seqlen_cache, nheads_k, head_dim] (dense cache)
///   v_cache      : same layout as k_cache
///   k_new, v_new : [batch, seqlen_new, nheads_k, head_dim]
///   cache_lens   : [batch] (int32) current KV length per sequence BEFORE appending
///
/// @return [batch, seqlen_q, nheads, head_dim]
Tensor infllmv2_kvcache_update(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
                               const Tensor &k_new, const Tensor &v_new,
                               const Tensor &cache_lens,
                               float scale, bool causal,
                               int window_size_left = -1, int window_size_right = -1);

} // namespace infinicore::op
73+

include/infinicore/ops/sigmoid.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

/// Element-wise sigmoid op, dispatched per device.
class Sigmoid {
public:
    /// Kernel signature: (output, input).
    using schema = void (*)(Tensor, Tensor);

    /// Run the kernel registered for the current device, writing into output.
    static void execute(Tensor output, Tensor input);

    /// Per-device kernel registry.
    static common::OpDispatcher<schema> &dispatcher();
};

/// Out-of-place variant: returns the result as a new tensor.
Tensor sigmoid(Tensor input);

/// Writes sigmoid(input) into the caller-provided output tensor.
void sigmoid_(Tensor output, Tensor input);
} // namespace infinicore::op
17+
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

/// Simple GLA (recurrent linear) attention with a per-head decay gate.
///
/// Shapes: q, k, v are [B, T, H, D]; g_gamma is [H] (log-decay per head).
/// Recurrence:
///   gate = exp(g_gamma)
///   S    = S * gate + outer(k_t, v_t)
///   o_t  = (q_t * scale) @ S
/// Output shape: [B, T, H, D].
class SimpleGlaAttention {
public:
    using schema = void (*)(Tensor &out, const Tensor &q, const Tensor &k, const Tensor &v,
                            const Tensor &g_gamma, float scale);

    /// Invoke the kernel registered for the current device.
    static void execute(Tensor &out, const Tensor &q, const Tensor &k, const Tensor &v,
                        const Tensor &g_gamma, float scale);

    /// Per-device kernel registry.
    static common::OpDispatcher<schema> &dispatcher();
};

/// Convenience wrapper returning the [B, T, H, D] output tensor.
Tensor simple_gla_attention(const Tensor &q, const Tensor &k, const Tensor &v,
                            const Tensor &g_gamma, float scale);

} // namespace infinicore::op
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

/// One decode timestep of Simple GLA (same recurrence as SimpleGlaAttention).
///
/// Shapes: q, k, v are [B, 1, H, D]; g_gamma is [H] (log-decay per head);
/// state is [B, H, D, D] float32 and is updated in place:
///   state = state * exp(g_gamma) + outer(k, v)
///   out[b, 0, h, :] = (q * scale) @ state[b, h]
/// Output has shape [B, 1, H, D] and the same dtype as q.
class SimpleGlaDecodeStep {
public:
    using schema = void (*)(Tensor &out, Tensor &state, const Tensor &q, const Tensor &k, const Tensor &v,
                            const Tensor &g_gamma, float scale);

    /// Invoke the kernel registered for the current device.
    static void execute(Tensor &out, Tensor &state, const Tensor &q, const Tensor &k, const Tensor &v,
                        const Tensor &g_gamma, float scale);

    /// Per-device kernel registry.
    static common::OpDispatcher<schema> &dispatcher();
};

/// Convenience wrapper; mutates `state` in place and returns the output.
Tensor simple_gla_decode_step(const Tensor &q, const Tensor &k, const Tensor &v, Tensor &state,
                              const Tensor &g_gamma, float scale);

} // namespace infinicore::op
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "../tensor.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Graph-op boilerplate for the prefill kernel; declares the op class with
// signature Tensor(const Tensor &, const Tensor &, const Tensor &, const Tensor &, float).
INFINICORE_GRAPH_OP_CLASS(SimpleGLAPrefill,
                          Tensor,
                          const Tensor &,
                          const Tensor &,
                          const Tensor &,
                          const Tensor &,
                          float);

/// Fused/chunked Simple GLA prefill forward.
/// q, k, v: [B, T, H, D] (F16/BF16); g_gamma: [H] (F32).
/// @return [B, T, H, D], same dtype as the inputs.
Tensor simple_gla_prefill(const Tensor &q, const Tensor &k, const Tensor &v,
                          const Tensor &g_gamma, float scale);

} // namespace infinicore::op
27+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#pragma once

#include "../device.hpp"
#include "../tensor.hpp"
#include "common/dispatcher.hpp"

namespace infinicore::op {

/// Batched update of the Simple GLA recurrent state (float32 [B, H, D, D])
/// for a contiguous K/V segment [B, L, H, D], equivalent to applying
/// simple_gla_decode_step L times:
///   S <- g^L * S + sum_{j=0}^{L-1} g^{L-1-j} * outer(k_j, v_j)
/// g_gamma: [H], same log-gate as simple_gla_decode_step (gate = exp(g_gamma)).
class SimpleGlaRecurrentStateAppend {
public:
    using schema = void (*)(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, const Tensor &g_gamma);

    /// Invoke the kernel registered for the current device.
    static void execute(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, const Tensor &g_gamma);

    /// Per-device kernel registry.
    static common::OpDispatcher<schema> &dispatcher();
};

/// Free-function entry point; mutates `state` in place.
void simple_gla_recurrent_state_append_segment(Tensor &state, const Tensor &k_seg, const Tensor &v_seg,
                                               const Tensor &g_gamma);

} // namespace infinicore::op

include/infinicore/ops/zeros.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once

#include "common/op.hpp"

namespace infinicore::op {

/// Zero-fill op, dispatched per device.
class Zeros {
public:
    /// Kernel signature: (output).
    using schema = void (*)(Tensor);

    /// Run the kernel registered for the current device on `output`.
    static void execute(Tensor output);

    /// Per-device kernel registry.
    static common::OpDispatcher<schema> &dispatcher();
};

/// In-place variant: fills `output` with zeros.
void zeros_(Tensor output);
} // namespace infinicore::op

0 commit comments

Comments
 (0)