@@ -17,7 +17,7 @@ class Descriptor final : public InfiniopDescriptor {
1717 infiniopTensorDescriptor_t q_desc,
1818 infiniopTensorDescriptor_t k_desc,
1919 infiniopTensorDescriptor_t v_desc,
20- std:: size_t total_kv_len,
20+ infiniopTensorDescriptor_t total_kv_len,
2121 double scale,
2222 char is_causal) : InfiniopDescriptor{handle->device , handle->device_id },
2323 _query_shape{q_desc->shape ()},
@@ -26,12 +26,12 @@ class Descriptor final : public InfiniopDescriptor {
2626 _key_strides{k_desc->strides ()},
2727 _value_shape{v_desc->shape ()},
2828 _value_strides{v_desc->strides ()},
29+ _total_kv_shape{total_kv_len->shape ()},
30+ _total_kv_strides{total_kv_len->strides ()},
2931 _output_strides{out_desc->strides ()},
3032 _dtype{q_desc->dtype ()},
3133 _scale{scale},
3234 _is_causal{is_causal} {
33- _key_shape[_key_shape.size () - 2 ] = total_kv_len;
34- _value_shape[_key_shape.size () - 2 ] = total_kv_len;
3535 }
3636
3737 ~Descriptor () = default ;
@@ -46,13 +46,15 @@ class Descriptor final : public InfiniopDescriptor {
4646 const void *q,
4747 const void *k,
4848 const void *v,
49+ const void *total_kv_len,
4950 void *stream) const {
5051 uint64_t empty_shape[4 ];
5152 int64_t empty_strides[4 ];
5253
5354 auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
5455 auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
5556 auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
57+ auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}};
5658
5759 NineToothedTensor attn_mask{nullptr , empty_shape, empty_strides};
5860 NineToothedTensor is_causal;
@@ -75,6 +77,7 @@ class Descriptor final : public InfiniopDescriptor {
7577 query,
7678 key,
7779 value,
80+ total_kv_length,
7881 attn_mask,
7982 is_causal,
8083 scale,
@@ -101,7 +104,7 @@ class Descriptor final : public InfiniopDescriptor {
101104 infiniopTensorDescriptor_t q_desc,
102105 infiniopTensorDescriptor_t k_desc,
103106 infiniopTensorDescriptor_t v_desc,
104- std:: size_t total_kv_len,
107+ infiniopTensorDescriptor_t total_kv_len,
105108 double scale,
106109 char is_causal) {
107110 *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal};
@@ -126,6 +129,10 @@ class Descriptor final : public InfiniopDescriptor {
126129
127130 std::vector<Stride> _value_strides;
128131
132+ std::vector<Size> _total_kv_shape;
133+
134+ std::vector<Stride> _total_kv_strides;
135+
129136 std::vector<Stride> _output_strides;
130137
131138 infiniDtype_t _dtype;
0 commit comments