forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathattentionOp.cpp
More file actions
1575 lines (1450 loc) · 85.4 KB
/
Copy pathattentionOp.cpp
File metadata and controls
1575 lines (1450 loc) · 85.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/attentionOp.h"
#include "tensorrt_llm/common/attentionWorkspace.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/kernels/flashMLA/flash_mla.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include "tensorrt_llm/thop/attentionOp.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <cstdint>
#include <functional>
#include <torch/extension.h>
#include <type_traits>
#include <unordered_set>
TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{
using tensorrt_llm::common::op::AttentionOp;
using tensorrt_llm::common::op::AttentionWorkspaceManager;
using tensorrt_llm::common::op::hash;
using tensorrt_llm::runtime::RequestType;
namespace
{
// Converts a workspace slice into the int64 offset exported to callers. An empty slice is reported as -1 so that
// it can be told apart from a real slice placed at offset 0.
int64_t exportOffset(tensorrt_llm::common::op::WorkspaceSlice const& slice)
{
    return slice.size == 0 ? int64_t{-1} : static_cast<int64_t>(slice.offset);
}

// Reads element [row, col] of a 2-D CPU tensor, honoring its strides.
// Throws (TORCH_CHECK) if the tensor is not on the CPU, is not 2-D, or either index is out of range.
template <typename T>
T readHostTensor2D(at::Tensor const& tensor, int64_t const row, int64_t const col, char const* tensorName)
{
    TORCH_CHECK(tensor.device().is_cpu(), tensorName, " must be a CPU tensor.");
    TORCH_CHECK(tensor.dim() == 2, tensorName, " must be a 2D tensor.");
    TORCH_CHECK(row >= 0 && row < tensor.size(0), tensorName, " row is out of bounds.");
    TORCH_CHECK(col >= 0 && col < tensor.size(1), tensorName, " column is out of bounds.");
    T const* const base = tensor.data_ptr<T>();
    int64_t const elementIndex = row * tensor.stride(0) + col * tensor.stride(1);
    return base[elementIndex];
}

// Reads element [i, j, k] of a 3-D CPU tensor, honoring its strides.
// Throws (TORCH_CHECK) if the tensor is not on the CPU, is not 3-D, or any index is out of range.
template <typename T>
T readHostTensor3D(at::Tensor const& tensor, int64_t const i, int64_t const j, int64_t const k, char const* tensorName)
{
    TORCH_CHECK(tensor.device().is_cpu(), tensorName, " must be a CPU tensor.");
    TORCH_CHECK(tensor.dim() == 3, tensorName, " must be a 3D tensor.");
    TORCH_CHECK(i >= 0 && i < tensor.size(0), tensorName, " dim 0 index is out of bounds.");
    TORCH_CHECK(j >= 0 && j < tensor.size(1), tensorName, " dim 1 index is out of bounds.");
    TORCH_CHECK(k >= 0 && k < tensor.size(2), tensorName, " dim 2 index is out of bounds.");
    T const* const base = tensor.data_ptr<T>();
    int64_t const elementIndex = i * tensor.stride(0) + j * tensor.stride(1) + k * tensor.stride(2);
    return base[elementIndex];
}

// Returns a typed pointer to element [row, col] of a tensor with at least two dimensions.
// Unlike the readers above, this performs no device check, so the pointer may refer to non-host memory and must
// only be dereferenced where that memory is accessible.
template <typename T>
T* tensorPtr2D(at::Tensor const& tensor, int64_t const row, int64_t const col, char const* tensorName)
{
    TORCH_CHECK(tensor.dim() >= 2, tensorName, " must have at least 2 dimensions.");
    TORCH_CHECK(row >= 0 && row < tensor.size(0), tensorName, " row is out of bounds.");
    TORCH_CHECK(col >= 0 && col < tensor.size(1), tensorName, " column is out of bounds.");
    auto* const base = static_cast<std::remove_const_t<T>*>(tensor.data_ptr());
    int64_t const elementIndex = row * tensor.stride(0) + col * tensor.stride(1);
    return base + elementIndex;
}
} // namespace
// Reads the (poolIndex, layerIdxInCachePool) pair for one layer from the host-side pool-mapping table.
// The table must be an int32 CPU tensor with one row per layer and at least two columns; column 0 is the pool index
// and column 1 is the layer index within that pool. Arbitrary strides are honored.
KvCachePoolMapping readKvCachePoolMapping(at::Tensor const& hostKvCachePoolMapping, int64_t const layerIdx)
{
    auto const& table = hostKvCachePoolMapping;
    TORCH_CHECK(table.device().is_cpu(), "host_kv_cache_pool_mapping must be a CPU tensor.");
    TORCH_CHECK(table.dim() == 2, "host_kv_cache_pool_mapping must be a 2D tensor.");
    TORCH_CHECK(table.size(1) >= 2, "host_kv_cache_pool_mapping must have at least two columns.");
    TORCH_CHECK(layerIdx >= 0 && layerIdx < table.size(0), "host_kv_cache_pool_mapping layer index is out of bounds.");

    int32_t const* const layerRow = table.data_ptr<int32_t>() + layerIdx * table.stride(0);
    auto const columnStep = table.stride(1);

    KvCachePoolMapping mapping;
    mapping.poolIndex = layerRow[0];
    mapping.layerIdxInCachePool = layerRow[columnStep];
    return mapping;
}
// Creates a 1-D tensor view of `sizeBytes` bytes starting `offset` bytes into `workspace`, typed as `scalarType`.
//
// Returns std::nullopt for a zero-sized slice (such slices are exported with offset -1 by the layout builders).
// The view is created with torch::from_blob and no deleter, so it does not own the memory: `workspace` must outlive
// every view returned from here.
//
// Throws (TORCH_CHECK) when:
//  - the size is negative,
//  - the offset is negative,
//  - the view does not fit in `workspace`,
//  - the size is not a multiple of the dtype's element size.
std::optional<at::Tensor> TrtllmAttentionWorkspaceManager::makeWorkspaceView(
    at::Tensor const& workspace, int64_t const offset, int64_t const sizeBytes, at::ScalarType const scalarType)
{
    if (sizeBytes == 0)
    {
        return std::nullopt;
    }
    TORCH_CHECK(sizeBytes > 0, "Negative workspace view size is invalid.");
    TORCH_CHECK(offset >= 0, "Negative workspace offset is invalid.");

    auto const workspaceSizeBytes = static_cast<int64_t>(workspace.nbytes());
    // Compare against the remaining capacity rather than computing `offset + sizeBytes`, which could overflow int64.
    TORCH_CHECK(offset <= workspaceSizeBytes && sizeBytes <= workspaceSizeBytes - offset,
        "Workspace view exceeds workspace bounds.");

    auto const itemSize = static_cast<int64_t>(c10::elementSize(scalarType));
    TORCH_CHECK(sizeBytes % itemSize == 0, "Workspace slice is not aligned to dtype size.");

    // data_ptr() already yields a mutable void*, so no const_cast is needed to form the view base.
    auto* const viewBase = static_cast<uint8_t*>(workspace.data_ptr()) + offset;
    auto const options = at::TensorOptions().dtype(scalarType).device(workspace.device());
    return torch::from_blob(viewBase, {sizeBytes / itemSize}, options);
}
// Computes the byte layout of the context-phase (trtllm-gen) workspace.
//
// Each slot's byte size is derived from the batch/token/head dimensions and handed to
// AttentionWorkspaceManager::buildContextLayout, which places the slots at kWorkspaceAlignment boundaries.
// Empty slots are exported with offset -1 (see exportOffset).
//
// Slot sizes:
//  - Q buffer: allocated only when `separateQKvInput` is set; one byte per element when `fp8ContextFmha` is set.
//  - FMHA BMM scales: allocated only when `fp8ContextFmha` is set.
//  - Rotary inverse-frequency buffer: allocated only when `rotaryEmbeddingDim` > 0.
TrtllmGenContextWorkspaceLayout TrtllmAttentionWorkspaceManager::buildContextLayout(at::ScalarType const qDtype,
    int64_t const batchSize, int64_t const numTokens, int64_t const numHeads, int64_t const headSize,
    int64_t const rotaryEmbeddingDim, bool const separateQKvInput, bool const fp8ContextFmha)
{
    int64_t const int32Bytes = static_cast<int64_t>(sizeof(int32_t));
    int64_t const floatBytes = static_cast<int64_t>(sizeof(float));

    // cuQSeqlens, cuKvSeqlens and cuMaskRows all share this size: batchSize + 1 int32 entries.
    int64_t const seqlenArrayBytes = int32Bytes * (batchSize + 1);

    int64_t rotaryBytes = 0;
    if (rotaryEmbeddingDim > 0)
    {
        rotaryBytes = floatBytes * batchSize * rotaryEmbeddingDim / 2;
    }

    int64_t qBufBytes = 0;
    if (separateQKvInput)
    {
        int64_t const qElementBytes = fp8ContextFmha ? 1 : static_cast<int64_t>(c10::elementSize(qDtype));
        qBufBytes = qElementBytes * numTokens * (numHeads * headSize);
    }

    int64_t const tokensInfoBytes = 2 * int32Bytes * numTokens;
    int64_t const tileCounterBytes = static_cast<int64_t>(sizeof(uint32_t));
    int64_t const bmm1ScaleBytes = fp8ContextFmha ? 2 * floatBytes : 0;
    int64_t const bmm2ScaleBytes = fp8ContextFmha ? floatBytes : 0;

    // NOTE(review): the cuBLAS slot of `sizes` is never assigned here, so it keeps the default that
    // AttentionContextWorkspaceSizes gives it. That slot's offset is nevertheless exported below as the trtllm-gen
    // workspace, which is sized kTrtllmGenWorkspaceSize. Confirm the default is at least that large, or
    // the trtllm-gen view may overlap the following slots.
    tensorrt_llm::common::op::AttentionContextWorkspaceSizes sizes{};
    sizes.cuQSeqlens = seqlenArrayBytes;
    sizes.cuKvSeqlens = seqlenArrayBytes;
    sizes.cuMaskRows = seqlenArrayBytes;
    sizes.rotaryInvFreq = rotaryBytes;
    sizes.qBuf = qBufBytes;
    sizes.tokensInfo = tokensInfoBytes;
    sizes.fmhaTileCounter = tileCounterBytes;
    sizes.fmhaBmm1Scale = bmm1ScaleBytes;
    sizes.fmhaBmm2Scale = bmm2ScaleBytes;

    auto const placed = AttentionWorkspaceManager::buildContextLayout(sizes, kWorkspaceAlignment);

    return TrtllmGenContextWorkspaceLayout{
        .trtllmGenWorkspaceOffset = exportOffset(placed.cublasWorkspace),
        .cuQSeqlensOffset = exportOffset(placed.cuQSeqlens),
        .cuKvSeqlensOffset = exportOffset(placed.cuKvSeqlens),
        .cuMaskRowsOffset = exportOffset(placed.cuMaskRows),
        .rotaryInvFreqOffset = exportOffset(placed.rotaryInvFreq),
        .qBufOffset = exportOffset(placed.qBuf),
        .tokensInfoOffset = exportOffset(placed.tokensInfo),
        .fmhaTileCounterOffset = exportOffset(placed.fmhaTileCounter),
        .fmhaBmm1ScaleOffset = exportOffset(placed.fmhaBmm1Scale),
        .fmhaBmm2ScaleOffset = exportOffset(placed.fmhaBmm2Scale),
        .trtllmGenWorkspaceSize = kTrtllmGenWorkspaceSize,
        .cuSeqlensSize = seqlenArrayBytes,
        .rotaryInvFreqSize = rotaryBytes,
        .qBufSize = qBufBytes,
        .tokensInfoSize = tokensInfoBytes,
        .fmhaTileCounterSize = tileCounterBytes,
        .fmhaBmm1ScaleSize = bmm1ScaleBytes,
        .fmhaBmm2ScaleSize = bmm2ScaleBytes,
        .totalSize = static_cast<int64_t>(placed.totalSize),
        .qBufScalarType = fp8ContextFmha ? at::kByte : qDtype,
    };
}
// Computes the byte layout of the generation-phase (trtllm-gen) workspace.
//
// The per-slot buffers are placed with AttentionWorkspaceManager::buildXqaLayout. The Q buffer occupies that
// layout's kernel-workspace slot. The trtllm-gen workspace (kTrtllmGenWorkspaceSize bytes, padded to
// kWorkspaceAlignment) is appended after the XQA layout. The sparse-attention cache slot is allocated only when
// `useSparseAttention` is set, and the rotary buffer only when `rotaryEmbeddingDim` > 0.
TrtllmGenGenerationWorkspaceLayout TrtllmAttentionWorkspaceManager::buildGenerationLayout(at::ScalarType const qDtype,
    int64_t const batchBeam, int64_t const numTokens, int64_t const numHeads, int64_t const headSize,
    int64_t const rotaryEmbeddingDim, int64_t const numKvHeads, int64_t const maxBlocksPerSequence,
    bool const useSparseAttention)
{
    int64_t const int32Bytes = static_cast<int64_t>(sizeof(int32_t));
    int64_t const floatBytes = static_cast<int64_t>(sizeof(float));

    // Both sequence-length arrays hold batchBeam + 1 int32 entries.
    int64_t const cuSeqlensBytes = int32Bytes * (batchBeam + 1);
    int64_t const cuKvSeqlensBytes = int32Bytes * (batchBeam + 1);

    int64_t rotaryBytes = 0;
    if (rotaryEmbeddingDim > 0)
    {
        rotaryBytes = floatBytes * batchBeam * rotaryEmbeddingDim / 2;
    }

    int64_t const tokensInfoBytes = 2 * int32Bytes * numTokens;
    int64_t const qBufBytes = static_cast<int64_t>(c10::elementSize(qDtype)) * numTokens * numHeads * headSize;
    int64_t const bmm1ScaleBytes = 2 * floatBytes;
    int64_t const bmm2ScaleBytes = floatBytes;

    int64_t sparseCacheBytes = 0;
    if (useSparseAttention)
    {
        sparseCacheBytes = int32Bytes * (batchBeam + batchBeam * 2 * maxBlocksPerSequence) * numKvHeads;
    }

    tensorrt_llm::common::op::AttentionXqaWorkspaceSizes xqaSizes{};
    xqaSizes.cuSeqlens = cuSeqlensBytes;
    xqaSizes.cuKvSeqlens = cuKvSeqlensBytes;
    xqaSizes.rotaryInvFreq = rotaryBytes;
    xqaSizes.tokensInfo = tokensInfoBytes;
    xqaSizes.bmm1Scale = bmm1ScaleBytes;
    xqaSizes.bmm2Scale = bmm2ScaleBytes;
    xqaSizes.sparseAttnCache = sparseCacheBytes;
    xqaSizes.kernelWorkspace = qBufBytes;

    auto const xqa = AttentionWorkspaceManager::buildXqaLayout(xqaSizes, kWorkspaceAlignment);

    auto const trtllmGenPadded
        = tensorrt_llm::common::alignSize(static_cast<size_t>(kTrtllmGenWorkspaceSize), kWorkspaceAlignment);
    auto const totalSize = xqa.totalSize + trtllmGenPadded;

    return TrtllmGenGenerationWorkspaceLayout{
        .trtllmGenWorkspaceOffset = static_cast<int64_t>(xqa.totalSize),
        .cuSeqlensOffset = exportOffset(xqa.cuSeqlens),
        .cuKvSeqlensOffset = exportOffset(xqa.cuKvSeqlens),
        .rotaryInvFreqOffset = exportOffset(xqa.rotaryInvFreq),
        .tokensInfoOffset = exportOffset(xqa.tokensInfo),
        .qBufOffset = exportOffset(xqa.kernelWorkspace),
        .bmm1ScaleOffset = exportOffset(xqa.bmm1Scale),
        .bmm2ScaleOffset = exportOffset(xqa.bmm2Scale),
        .sparseAttnCacheOffset = exportOffset(xqa.sparseAttnCache),
        .trtllmGenWorkspaceSize = kTrtllmGenWorkspaceSize,
        .cuSeqlensSize = cuSeqlensBytes,
        .cuKvSeqlensSize = cuKvSeqlensBytes,
        .rotaryInvFreqSize = rotaryBytes,
        .tokensInfoSize = tokensInfoBytes,
        .qBufSize = qBufBytes,
        .bmm1ScaleSize = bmm1ScaleBytes,
        .bmm2ScaleSize = bmm2ScaleBytes,
        .sparseAttnCacheSize = sparseCacheBytes,
        .totalSize = static_cast<int64_t>(totalSize),
        .qBufScalarType = qDtype,
    };
}
// Total bytes needed for the context-phase workspace. This is the `totalSize` of the layout produced by
// buildContextLayout for the same arguments.
int64_t TrtllmAttentionWorkspaceManager::getContextWorkspaceSize(at::ScalarType const qDtype, int64_t const batchSize,
    int64_t const numTokens, int64_t const numHeads, int64_t const headSize, int64_t const rotaryEmbeddingDim,
    bool const separateQKvInput, bool const fp8ContextFmha)
{
    auto const contextLayout = buildContextLayout(
        qDtype, batchSize, numTokens, numHeads, headSize, rotaryEmbeddingDim, separateQKvInput, fp8ContextFmha);
    return contextLayout.totalSize;
}
// Total bytes needed for the generation-phase workspace. This is the `totalSize` of the layout produced by
// buildGenerationLayout for the same arguments, including the trailing trtllm-gen workspace.
int64_t TrtllmAttentionWorkspaceManager::getGenerationWorkspaceSize(at::ScalarType const qDtype,
    int64_t const batchBeam, int64_t const numTokens, int64_t const numHeads, int64_t const headSize,
    int64_t const rotaryEmbeddingDim, int64_t const numKvHeads, int64_t const maxBlocksPerSequence,
    bool const useSparseAttention)
{
    auto const generationLayout = buildGenerationLayout(qDtype, batchBeam, numTokens, numHeads, headSize,
        rotaryEmbeddingDim, numKvHeads, maxBlocksPerSequence, useSparseAttention);
    return generationLayout.totalSize;
}
// Materializes tensor views over `workspace` for every slot of a context-phase layout.
//
// Optional slots are returned as std::optional and are empty when the slot has zero bytes. These are the rotary
// inverse-frequency buffer, the Q buffer and the FMHA BMM scales. Every other slot is mandatory. A mandatory slot
// with zero size now fails with a TORCH_CHECK naming the slot. Previously, dereferencing the empty std::optional
// returned by makeWorkspaceView was undefined behavior.
//
// The views alias `workspace` without owning it, so `workspace` must outlive the returned views.
TrtllmGenContextWorkspaceViews TrtllmAttentionWorkspaceManager::materializeContextWorkspace(
    at::Tensor const& workspace, TrtllmGenContextWorkspaceLayout const& layout)
{
    auto const requireView
        = [&](int64_t const offset, int64_t const sizeBytes, at::ScalarType const scalarType, char const* slotName)
    {
        auto view = makeWorkspaceView(workspace, offset, sizeBytes, scalarType);
        TORCH_CHECK(view.has_value(), "Required context workspace slot '", slotName, "' has zero size.");
        return *view;
    };

    return TrtllmGenContextWorkspaceViews{
        .trtllmGenWorkspace = requireView(
            layout.trtllmGenWorkspaceOffset, layout.trtllmGenWorkspaceSize, at::kByte, "trtllmGenWorkspace"),
        .cuQSeqlens = requireView(layout.cuQSeqlensOffset, layout.cuSeqlensSize, at::kInt, "cuQSeqlens"),
        .cuKvSeqlens = requireView(layout.cuKvSeqlensOffset, layout.cuSeqlensSize, at::kInt, "cuKvSeqlens"),
        .cuMaskRows = requireView(layout.cuMaskRowsOffset, layout.cuSeqlensSize, at::kInt, "cuMaskRows"),
        .rotaryInvFreqBuf
        = makeWorkspaceView(workspace, layout.rotaryInvFreqOffset, layout.rotaryInvFreqSize, at::kFloat),
        .qBuf = makeWorkspaceView(workspace, layout.qBufOffset, layout.qBufSize, layout.qBufScalarType),
        .tokensInfo = requireView(layout.tokensInfoOffset, layout.tokensInfoSize, at::kInt, "tokensInfo"),
        .fmhaTileCounter = requireView(
            layout.fmhaTileCounterOffset, layout.fmhaTileCounterSize, at::kUInt32, "fmhaTileCounter"),
        .fmhaBmm1Scale = makeWorkspaceView(workspace, layout.fmhaBmm1ScaleOffset, layout.fmhaBmm1ScaleSize, at::kFloat),
        .fmhaBmm2Scale = makeWorkspaceView(workspace, layout.fmhaBmm2ScaleOffset, layout.fmhaBmm2ScaleSize, at::kFloat),
    };
}
// Convenience overload: builds the context layout from the given dimensions and materializes it over `workspace`.
// This overload always assumes separate Q/KV input (separateQKvInput = true), so a Q buffer slot is laid out.
TrtllmGenContextWorkspaceViews TrtllmAttentionWorkspaceManager::materializeContextWorkspace(at::Tensor const& workspace,
    at::ScalarType const qDtype, int64_t const batchSize, int64_t const numTokens, int64_t const numHeads,
    int64_t const headSize, int64_t const rotaryEmbeddingDim, bool const fp8ContextFmha)
{
    constexpr bool kSeparateQKvInput = true;
    return materializeContextWorkspace(workspace,
        buildContextLayout(qDtype, batchSize, numTokens, numHeads, headSize, rotaryEmbeddingDim, kSeparateQKvInput,
            fp8ContextFmha));
}
// Materializes tensor views over `workspace` for every slot of a generation-phase layout.
//
// Optional slots are returned as std::optional and are empty when the slot has zero bytes. These are the rotary
// inverse-frequency buffer and the sparse-attention cache. Every other slot is mandatory. A mandatory slot with zero
// size (e.g. qBuf / tokensInfo when numTokens == 0) now fails with a TORCH_CHECK naming the slot. Previously,
// dereferencing the empty std::optional returned by makeWorkspaceView was undefined behavior.
//
// The views alias `workspace` without owning it, so `workspace` must outlive the returned views.
TrtllmGenGenerationWorkspaceViews TrtllmAttentionWorkspaceManager::materializeGenerationWorkspace(
    at::Tensor const& workspace, TrtllmGenGenerationWorkspaceLayout const& layout)
{
    auto const requireView
        = [&](int64_t const offset, int64_t const sizeBytes, at::ScalarType const scalarType, char const* slotName)
    {
        auto view = makeWorkspaceView(workspace, offset, sizeBytes, scalarType);
        TORCH_CHECK(view.has_value(), "Required generation workspace slot '", slotName, "' has zero size.");
        return *view;
    };

    return TrtllmGenGenerationWorkspaceViews{
        .trtllmGenWorkspace = requireView(
            layout.trtllmGenWorkspaceOffset, layout.trtllmGenWorkspaceSize, at::kByte, "trtllmGenWorkspace"),
        .cuSeqlens = requireView(layout.cuSeqlensOffset, layout.cuSeqlensSize, at::kInt, "cuSeqlens"),
        .cuKvSeqlens = requireView(layout.cuKvSeqlensOffset, layout.cuKvSeqlensSize, at::kInt, "cuKvSeqlens"),
        .rotaryInvFreqBuf
        = makeWorkspaceView(workspace, layout.rotaryInvFreqOffset, layout.rotaryInvFreqSize, at::kFloat),
        .tokensInfo = requireView(layout.tokensInfoOffset, layout.tokensInfoSize, at::kInt, "tokensInfo"),
        .qBuf = requireView(layout.qBufOffset, layout.qBufSize, layout.qBufScalarType, "qBuf"),
        .bmm1Scale = requireView(layout.bmm1ScaleOffset, layout.bmm1ScaleSize, at::kFloat, "bmm1Scale"),
        .bmm2Scale = requireView(layout.bmm2ScaleOffset, layout.bmm2ScaleSize, at::kFloat, "bmm2Scale"),
        .sparseAttnCache
        = makeWorkspaceView(workspace, layout.sparseAttnCacheOffset, layout.sparseAttnCacheSize, at::kInt),
    };
}
// Convenience overload: builds the generation layout from the given dimensions and materializes it over `workspace`.
// Sparse attention is disabled here (maxBlocksPerSequence = 0, useSparseAttention = false), so no sparse-attention
// cache slot is laid out.
TrtllmGenGenerationWorkspaceViews TrtllmAttentionWorkspaceManager::materializeGenerationWorkspace(
    at::Tensor const& workspace, at::ScalarType const qDtype, int64_t const batchBeam, int64_t const numTokens,
    int64_t const numHeads, int64_t const headSize, int64_t const rotaryEmbeddingDim, int64_t const numKvHeads)
{
    constexpr int64_t kNoBlocksPerSequence = 0;
    constexpr bool kUseSparseAttention = false;
    return materializeGenerationWorkspace(workspace,
        buildGenerationLayout(qDtype, batchBeam, numTokens, numHeads, headSize, rotaryEmbeddingDim, numKvHeads,
            kNoBlocksPerSequence, kUseSparseAttention));
}
namespace trtllm::attention
{
using tensorrt_llm::kernels::KVBlockArray;
using tensorrt_llm::kernels::MlaParams;
using tensorrt_llm::kernels::SparseAttentionParams;
using tensorrt_llm::torch_ext::KvCachePoolPointers;
using tensorrt_llm::torch_ext::buildKvCachePoolPointers;
// Phase classification for an attention call. Judging by the names (confirm against callers), the input mixes
// context and generation requests, holds only context requests, or holds only generation requests.
// Stored in a single byte (int8_t); the enumerator values are spelled out explicitly.
enum class AttentionInputType : int8_t
{
    Mixed = 0,
    ContextOnly = 1,
    GenerationOnly = 2,
};
// Type-erased interface over the dtype-specific attention runners (Runner<T, AttentionOutT>).
//
// The three configuration fields are set by whoever creates the runner; Runner::prepare reads them to size the
// generation-phase resources of an AttentionOp. They now default to 0. Previously they were left uninitialized,
// and data() or prepare() could read indeterminate values.
class RunnerBase
{
public:
    int32_t beam_width{0};
    int32_t max_num_requests{0};
    int32_t attention_window_size{0};

    // Returns the configuration as a (beam_width, max_num_requests, attention_window_size) tuple.
    // Presumably used to compare or hash runner configurations; confirm against callers.
    auto data() const
    {
        return std::make_tuple(beam_width, max_num_requests, attention_window_size);
    }

    virtual ~RunnerBase() = default;

    // Prepares `op` for enqueueing work with this runner's configuration.
    virtual void prepare(AttentionOp& op) const = 0;

    // Returns the workspace size in bytes that run() needs for the given token and window counts.
    virtual int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size,
        int const num_gen_tokens, int const max_blocks_per_sequence, int const ctx_total_kv_len = 0) const
        = 0;

    // typically, we use single qkv input, but for context MLA, we use separate qkv inputs
    virtual void run(AttentionOp& op, bool const is_context, int32_t const seq_offset, int32_t const num_seqs,
        int32_t const token_offset, int32_t const num_tokens, int32_t const predicted_tokens_per_seq,
        torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv_or_q,
        torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
        torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
        torch::Tensor host_context_lengths, std::optional<int64_t> max_context_q_len_override,
        torch::optional<torch::Tensor> kv_cache_block_offsets,
        torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
        torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
        torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
        torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq,
        torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
        torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
        torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
        std::optional<torch::Tensor> helix_position_offsets, std::optional<torch::Tensor> helix_is_inactive_rank,
        torch::optional<torch::Tensor> softmax_stats_tensor,
        std::optional<torch::Tensor> spec_decoding_generation_lengths,
        std::optional<torch::Tensor> spec_decoding_position_offsets_for_cpp,
        std::optional<torch::Tensor> spec_decoding_packed_mask,
        std::optional<torch::Tensor> spec_decoding_bl_tree_mask_offset,
        std::optional<torch::Tensor> spec_decoding_bl_tree_mask,
        std::optional<torch::Tensor> spec_bl_tree_first_sparse_mask_offset_kv,
        torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
        torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
        torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
        int32_t const num_sparse_topk, std::optional<torch::Tensor> sparse_mla_topk_lens,
        std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
        std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
        std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
        std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata,
        std::optional<torch::Tensor> flash_mla_num_splits, bool trtllm_gen_jit_warmup,
        std::optional<int64_t> compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
        std::optional<torch::Tensor> relative_attention_bias) const
        = 0;
};
template <typename T, typename AttentionOutT = T>
class Runner : public RunnerBase
{
public:
void prepare(AttentionOp& op) const override
{
    // Configure generation with this runner's beam width and request count. The maximum, cyclic and maximum-cyclic
    // window sizes are all set to attention_window_size.
    AttentionOp::EnqueueGenerationParams<T> generationParams;
    generationParams.beam_width = beam_width;
    generationParams.num_requests = max_num_requests;
    generationParams.max_attention_window_size = attention_window_size;
    generationParams.cyclic_attention_window_size = attention_window_size;
    generationParams.max_cyclic_attention_window_size = attention_window_size;
    op.prepareEnqueueGeneration<T, KVBlockArray>(generationParams);

    // The semaphore array (used by multi-block mode) is reserved unconditionally: MMHA can switch to multi-block
    // mode on its own when shared memory runs short, and heads may be split across several blocks.
    // Reserve at least one semaphore per head per request, and never fewer than the multiprocessor count.
    auto const perHeadPerRequest = op.mNumHeads * max_num_requests;
    auto const semaphoreCount = std::max(perHeadPerRequest, op.getMultiProcessorCount());
    op.reserveSemaphoreArray(semaphoreCount);
}
int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size,
    int const num_gen_tokens, int const max_blocks_per_sequence, int const ctx_total_kv_len = 0) const override
{
    // Returns the larger of the context-phase and generation-phase requirements, both sized for
    // max_num_requests requests.
    size_t const contextBytes = op.getWorkspaceSizeForContext(
        op.mType, max_num_requests, op.mMaxContextLength, 0, num_tokens, ctx_total_kv_len);
    size_t const generationBytes = op.getWorkspaceSizeForGeneration(
        op.mType, max_num_requests, max_attention_window_size, num_gen_tokens, max_blocks_per_sequence);
    size_t const requiredBytes = contextBytes > generationBytes ? contextBytes : generationBytes;
    return static_cast<int64_t>(requiredBytes);
}
void run(AttentionOp& op, bool const is_context, int32_t const seq_offset, int32_t const num_seqs,
int32_t const token_offset, int32_t const num_tokens, int32_t const predicted_tokens_per_seq,
torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv_or_q,
torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, std::optional<int64_t> max_context_q_len_override,
torch::optional<torch::Tensor> kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq,
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::optional<torch::Tensor> helix_position_offsets, std::optional<torch::Tensor> helix_is_inactive_rank,
torch::optional<torch::Tensor> softmax_stats_tensor,
std::optional<torch::Tensor> spec_decoding_generation_lengths,
std::optional<torch::Tensor> spec_decoding_position_offsets_for_cpp,
std::optional<torch::Tensor> spec_decoding_packed_mask,
std::optional<torch::Tensor> spec_decoding_bl_tree_mask_offset,
std::optional<torch::Tensor> spec_decoding_bl_tree_mask,
std::optional<torch::Tensor> spec_bl_tree_first_sparse_mask_offset_kv,
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
int32_t const num_sparse_topk, std::optional<torch::Tensor> sparse_mla_topk_lens,
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata,
std::optional<torch::Tensor> flash_mla_num_splits, bool trtllm_gen_jit_warmup,
std::optional<int64_t> compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
std::optional<torch::Tensor> relative_attention_bias) const override
{
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
T* k_ptr = nullptr;
T* v_ptr = nullptr;
AttentionOutT* context_buf = static_cast<AttentionOutT*>(output.slice(0, token_offset).data_ptr());
TORCH_CHECK(!op.mFuseFp4Quant || output_sf.has_value());
void* context_buf_sf = op.mFuseFp4Quant ? output_sf->data_ptr() : nullptr;
// Rotary inv_freq, cos_sin cache to avoid re-computing.
float const* rotary_inv_freq_ptr = nullptr;
float2 const* rotary_cos_sin_ptr = nullptr;
if (op.isRoPE())
{
if (rotary_inv_freq.has_value())
{
rotary_inv_freq_ptr = rotary_inv_freq.value().data_ptr<float>();
}
if (rotary_cos_sin.has_value())
{
rotary_cos_sin_ptr = static_cast<float2 const*>(rotary_cos_sin.value().data_ptr());
}
}
void* workspace_ptr = workspace.data_ptr();
[[maybe_unused]] MlaParams<T> mla_params;
if (op.isMLAEnabled())
{
if (is_context && op.mUseSparseAttention)
{
if (latent_cache.has_value())
{
mla_params.latent_cache = static_cast<T const*>(latent_cache->data_ptr());
}
else
{
// kv cache reuse / chunked context cases, latent_cache is not used
mla_params.latent_cache = nullptr;
}
TORCH_CHECK(q_pe.has_value());
TORCH_CHECK(q_pe->dim() == 3);
TORCH_CHECK(q_pe->strides()[2] == 1);
mla_params.q_pe = static_cast<T*>(q_pe->data_ptr());
mla_params.q_pe_ld = q_pe->strides()[1];
mla_params.q_pe_stride = q_pe->strides()[0];
}
else if (is_context)
{
if (latent_cache.has_value())
{
mla_params.latent_cache = static_cast<T const*>(latent_cache->data_ptr());
}
else
{
// kv cache reuse / chunked context cases, latent_cache is not used
mla_params.latent_cache = nullptr;
}
TORCH_CHECK(k.has_value());
TORCH_CHECK(v.has_value());
TORCH_CHECK(k->dim() == 2);
TORCH_CHECK(v->dim() == 2);
TORCH_CHECK(k->strides()[1] == 1);
TORCH_CHECK(v->strides()[1] == 1);
k_ptr = static_cast<T*>(k->slice(0, token_offset).data_ptr());
v_ptr = static_cast<T*>(v->slice(0, token_offset).data_ptr());
mla_params.k_buf = k_ptr;
mla_params.v_buf = v_ptr;
// For generation, helix position is in ropeOp
if (helix_position_offsets.has_value())
{
mla_params.helix_position_offsets = helix_position_offsets->data_ptr<int32_t>();
}
if (helix_is_inactive_rank.has_value())
{
mla_params.helix_is_inactive_rank = helix_is_inactive_rank->data_ptr<bool>();
}
}
else
{
TORCH_CHECK(latent_cache.has_value());
mla_params.latent_cache = static_cast<T const*>(latent_cache->data_ptr());
TORCH_CHECK(q_pe.has_value());
TORCH_CHECK(q_pe->dim() == 3);
TORCH_CHECK(q_pe->strides()[2] == 1);
mla_params.q_pe = static_cast<T*>(q_pe->data_ptr());
mla_params.q_pe_ld = q_pe->strides()[1];
mla_params.q_pe_stride = q_pe->strides()[0];
mla_params.seqQOffset
= cu_q_seqlens.has_value() ? reinterpret_cast<int*>(cu_q_seqlens.value().data_ptr()) : nullptr;
mla_params.cu_kv_seqlens
= cu_kv_seqlens.has_value() ? reinterpret_cast<int*>(cu_kv_seqlens.value().data_ptr()) : nullptr;
mla_params.fmha_tile_counter = fmha_scheduler_counter.has_value()
? reinterpret_cast<uint32_t*>(fmha_scheduler_counter.value().data_ptr())
: nullptr;
mla_params.bmm1_scale = mla_bmm1_scale.has_value()
? reinterpret_cast<float*>(mla_bmm1_scale.value().data_ptr())
: nullptr;
mla_params.bmm2_scale = mla_bmm2_scale.has_value()
? reinterpret_cast<float*>(mla_bmm2_scale.value().data_ptr())
: nullptr;
mla_params.quant_q_buf
= quant_q_buffer.has_value() ? reinterpret_cast<void*>(quant_q_buffer.value().data_ptr()) : nullptr;
}
mla_params.q_buf = attention_input;
mla_params.context_buf = reinterpret_cast<T*>(context_buf);
mla_params.cos_sin_cache = rotary_cos_sin_ptr;
mla_params.batch_size = num_seqs;
mla_params.acc_q_len = num_tokens;
mla_params.head_num = op.mNumHeads;
mla_params.meta = op.mMLAParams;
mla_params.workspace = workspace_ptr;
}
// Extract K/V pointers for sage attention (separate Q/K/V inputs).
else if (is_context
&& (op.mSageAttnNumEltsPerBlkQ > 0 || op.mSageAttnNumEltsPerBlkK > 0 || op.mSageAttnNumEltsPerBlkV > 0))
{
TORCH_CHECK(k.has_value() && v.has_value(), "SageAttention demands separate K and V buffers");
k_ptr = static_cast<T*>(k->slice(0, token_offset).data_ptr());
v_ptr = static_cast<T*>(v->slice(0, token_offset).data_ptr());
}
int const* context_lengths_ptr = context_lengths.slice(0, seq_offset).data_ptr<int>();
int const* sequence_lengths_ptr = sequence_length.slice(0, seq_offset).data_ptr<int>();
// Note we still need context length during generation for MMHA optimization.
// For encoder CUDA graphs compatibility, allow the caller to override the
// max context Q length so FMHA kernel launch params (mMaxSeqLenQ-driven grid
// and cluster dims) are stable across graph replays even when actual per-batch
// sequence lengths vary.
int32_t const max_context_q_len_computed
= host_context_lengths.slice(0, seq_offset, seq_offset + num_seqs).max().item<int32_t>();
int32_t const max_past_kv_length_computed
= host_past_key_value_lengths.slice(0, seq_offset, seq_offset + num_seqs).max().item<int32_t>();
if (max_context_q_len_override.has_value())
{
int32_t const override_value = static_cast<int32_t>(max_context_q_len_override.value());
TORCH_CHECK(override_value >= max_context_q_len_computed,
"max_context_q_len_override (%d) must be >= computed max context q length (%d).", override_value,
max_context_q_len_computed);
TORCH_CHECK(override_value >= max_past_kv_length_computed,
"max_context_q_len_override (%d) must be >= computed max past kv length (%d).", override_value,
max_past_kv_length_computed);
}
int32_t const max_context_q_len = max_context_q_len_override.has_value()
? static_cast<int32_t>(max_context_q_len_override.value())
: max_context_q_len_computed;
// Override the max_past_kv_length as well for encoder CUDA graph compatibility
int32_t const max_past_kv_length = max_context_q_len_override.has_value()
? static_cast<int32_t>(max_context_q_len_override.value())
: max_past_kv_length_computed;
// Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same
// unless each layer has different attention window sizes.
int const max_attention_window_size = beam_width == 1 ? attention_window_size
: cache_indirection.has_value() ? cache_indirection.value().size(2)
: attention_window_size;
// The cyclic_attention_window_size will determine the cyclic kv cache position of new tokens.
// Note that this cyclic_attention_window_size might be smaller than the actual kv cache capactity.
int const cyclic_attention_window_size = attention_window_size;
bool const can_use_one_more_block = beam_width > 1;
int max_blocks_per_sequence = 0;
int32_t pool_index = 0;
int32_t layer_idx_in_cache_pool = 0;
KVBlockArray::DataType* block_offsets = nullptr;
bool use_kv_cache = false;
KvCachePoolPointers pool_pointers;
max_blocks_per_sequence
= op.useKVCache() && kv_cache_block_offsets.has_value() ? kv_cache_block_offsets.value().size(-1) : 0;
pool_index = op.useKVCache() && host_kv_cache_pool_mapping.has_value()
? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 0}).item<int32_t>()
: 0;
layer_idx_in_cache_pool = op.useKVCache() && host_kv_cache_pool_mapping.has_value()
? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 1}).item<int32_t>()
: 0;
block_offsets = static_cast<KVBlockArray::DataType*>(op.useKVCache() && kv_cache_block_offsets.has_value()
? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr()
: nullptr);
// The cache element size in bits.
int cache_elem_bits = op.getKvCacheElemSizeInBits<T>();
auto const block_size = op.mTokensPerBlock * op.mNumKVHeads * op.mHeadSize;
auto const bytes_per_block = block_size * cache_elem_bits / 8 /*bits*/;
int32_t const kv_factor = op.isMLAEnabled() ? 1 : 2;
auto const intra_pool_offset = layer_idx_in_cache_pool * kv_factor * bytes_per_block;
// Build KV cache pool pointers from the host tensor.
use_kv_cache = op.useKVCache() && host_kv_cache_pool_pointers.has_value();
if (use_kv_cache)
{
pool_pointers = buildKvCachePoolPointers(host_kv_cache_pool_pointers.value(), pool_index, intra_pool_offset,
block_size, layer_idx_in_cache_pool, kv_factor, op.mKVCacheQuantMode.hasFp4KvCache());
}
float const* kv_scale_orig_quant_ptr = nullptr;
float const* kv_scale_quant_orig_ptr = nullptr;
if (op.mKVCacheQuantMode.hasKvCacheQuant() && kv_scale_orig_quant.has_value()
&& kv_scale_quant_orig.has_value())
{
kv_scale_orig_quant_ptr = kv_scale_orig_quant.value().data_ptr<float>();
kv_scale_quant_orig_ptr = kv_scale_quant_orig.value().data_ptr<float>();
if (op.mKVCacheQuantMode.hasFp4KvCache())
{
TORCH_CHECK(kv_scale_orig_quant.value().size(0) == 3);
TORCH_CHECK(kv_scale_quant_orig.value().size(0) == 3);
}
}
// For FP8 output, out_scale represents the output scale.
float const* out_scale_ptr = (op.mFP8ContextFMHA && !op.mFuseFp4Quant && out_scale.has_value())
? out_scale.value().data_ptr<float>()
: nullptr;
// For NVFP4 output, out_scale holds the global scale for scaling factors.
float const* out_sf_scale_ptr
= op.mFuseFp4Quant && out_scale.has_value() ? out_scale.value().data_ptr<float>() : nullptr;
// The attention_sinks is a float tensor with shape [num_heads_q]
float const* attention_sinks_ptr = nullptr;
if (attention_sinks.has_value())
{
TORCH_CHECK(
attention_sinks.value().dtype() == torch::kFloat32, "Expected attention_sinks to have float dtype");
attention_sinks_ptr = attention_sinks.value().data_ptr<float>();
}
T const* relative_attention_bias_ptr = nullptr;
int relative_attention_bias_stride = 0;
if (relative_attention_bias.has_value())
{
auto const& relative_attention_bias_tensor = relative_attention_bias.value();
TORCH_CHECK(relative_attention_bias_tensor.dim() == 2 || relative_attention_bias_tensor.dim() == 3,
"relative_attention_bias must be [num_heads, num_buckets] for implicit mode or "
"[num_heads, max_seq_len, max_seq_len] for explicit mode");
TORCH_CHECK(relative_attention_bias_tensor.is_contiguous(), "relative_attention_bias must be contiguous");
TORCH_CHECK(relative_attention_bias_tensor.scalar_type() == qkv_or_q.scalar_type(),
"relative_attention_bias dtype must match attention input dtype");
relative_attention_bias_ptr = static_cast<T const*>(relative_attention_bias_tensor.data_ptr());
relative_attention_bias_stride = static_cast<int>(relative_attention_bias_tensor.size(1));
}
// Prepare sparse attention parameters
op.mRuntimeSparseAttentionParams.sparse_kv_indices
= sparse_kv_indices.has_value() ? sparse_kv_indices.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_kv_offsets
= sparse_kv_offsets.has_value() ? sparse_kv_offsets.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_attn_indices
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_attn_offsets
= sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_attn_indices_block_size = sparse_attn_indices_block_size;
op.mRuntimeSparseAttentionParams.sparse_attn_indices_stride
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().size(-1) : 0;
op.mRuntimeSparseAttentionParams.num_sparse_topk = num_sparse_topk;
op.mRuntimeSparseAttentionParams.sparse_mla_topk_lens
= sparse_mla_topk_lens.has_value() ? sparse_mla_topk_lens.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_kv_cache_pool = nullptr;
op.mRuntimeSparseAttentionParams.sliding_window_kv_cache_pool = nullptr;
if (op.mUseSparseAttention && use_kv_cache)
{
if (host_kv_cache_pool_pointers.has_value())
{
auto* kvCachePool = reinterpret_cast<char*>(
host_kv_cache_pool_pointers.value().index({pool_index, 0}).item<int64_t>());
if (sparse_mla_topk_lens.has_value())
{
// Deepseek V4 dynamic sparse MLA always uses the SWA pool for now.
op.mRuntimeSparseAttentionParams.sliding_window_kv_cache_pool = kvCachePool;
if (compressed_kv_cache_pool_ptr.has_value())
{
op.mRuntimeSparseAttentionParams.sparse_kv_cache_pool
= reinterpret_cast<char*>(compressed_kv_cache_pool_ptr.value());
}
}
else
{
op.mRuntimeSparseAttentionParams.sparse_kv_cache_pool = kvCachePool;
}
}
}
AttentionOp::EnqueueParams<T> common_enqueue_params;
common_enqueue_params.attention_input = attention_input;
common_enqueue_params.attention_sinks = attention_sinks_ptr;
common_enqueue_params.rotary_inv_freq = rotary_inv_freq_ptr;
common_enqueue_params.rotary_cos_sin = rotary_cos_sin_ptr;
common_enqueue_params.relative_attention_bias = relative_attention_bias_ptr;
common_enqueue_params.relative_attention_bias_stride = relative_attention_bias_stride;
common_enqueue_params.max_past_kv_length = max_past_kv_length;
common_enqueue_params.max_attention_window_size = max_attention_window_size;
common_enqueue_params.cyclic_attention_window_size = cyclic_attention_window_size;
common_enqueue_params.max_cyclic_attention_window_size = cyclic_attention_window_size;
common_enqueue_params.can_use_one_more_block = can_use_one_more_block;
common_enqueue_params.kv_scale_orig_quant = kv_scale_orig_quant_ptr;
common_enqueue_params.kv_scale_quant_orig = kv_scale_quant_orig_ptr;
common_enqueue_params.attention_output_orig_quant = out_scale_ptr;
common_enqueue_params.attention_output_sf_scale = out_sf_scale_ptr;
common_enqueue_params.context_buf = context_buf;
common_enqueue_params.context_buf_sf = context_buf_sf;
common_enqueue_params.block_offsets = block_offsets;
common_enqueue_params.host_primary_pool_pointer = pool_pointers.primaryPoolPtr;
common_enqueue_params.host_secondary_pool_pointer = pool_pointers.secondaryPoolPtr;
common_enqueue_params.host_primary_block_scale_pool_pointer = pool_pointers.primaryBlockScalePoolPtr;
common_enqueue_params.host_secondary_block_scale_pool_pointer = pool_pointers.secondaryBlockScalePoolPtr;
common_enqueue_params.num_tokens = num_tokens;
common_enqueue_params.total_kv_len = total_kv_len;
common_enqueue_params.max_blocks_per_sequence = max_blocks_per_sequence;
common_enqueue_params.sequence_lengths = sequence_lengths_ptr;
common_enqueue_params.context_lengths = context_lengths_ptr;
common_enqueue_params.host_context_lengths = host_context_lengths.data_ptr<int32_t>();
common_enqueue_params.workspace = workspace_ptr;
common_enqueue_params.trtllm_gen_jit_warmup = trtllm_gen_jit_warmup;
if (is_cross)
{
// For cross attention, the KV (encoder) sequence lengths are passed in via
// `sequence_length` (already sliced into `sequence_lengths_ptr`), so reuse
// it directly instead of a redundant `encoder_input_lengths` tensor.
common_enqueue_params.encoder_input_lengths = sequence_lengths_ptr;
}
if (softmax_stats_tensor.has_value())
{
TLLM_CHECK_WITH_INFO(softmax_stats_tensor.value().scalar_type() == at::ScalarType::Float,
"softmax_stats_tensor must have float type");
TLLM_CHECK_WITH_INFO(softmax_stats_tensor.value().size(0) >= num_tokens,
"softmax_stats_tensor must have first dimension >= num_tokens");
TLLM_CHECK_WITH_INFO(softmax_stats_tensor.value().size(1) >= op.mNumHeads,
"softmax_stats_tensor must have second dimension >= num_heads");
TLLM_CHECK_WITH_INFO(
softmax_stats_tensor.value().size(2) == 2, "softmax_stats_tensor must have third dimension == 2");
common_enqueue_params.softmax_stats = static_cast<float2*>(softmax_stats_tensor.value().data_ptr());
}
// Shared helper to wire helix params into the enqueue params.
// Works for both EnqueueContextParams and EnqueueGenerationParams since both have
// helix_position_offsets and helix_is_inactive_rank fields.
auto const extractHelixParams = [&helix_position_offsets, &helix_is_inactive_rank](auto& params)
{
if (helix_position_offsets.has_value())
{
params.helix_position_offsets = helix_position_offsets->data_ptr<int32_t>();
}
if (helix_is_inactive_rank.has_value())
{
params.helix_is_inactive_rank = helix_is_inactive_rank->data_ptr<bool>();
}
};
if (is_context) // context stage
{
common_enqueue_params.input_seq_length = max_context_q_len;
AttentionOp::EnqueueContextParams<T> enqueue_params{common_enqueue_params};
enqueue_params.batch_size = num_seqs;
enqueue_params.k_ptr = k_ptr;
enqueue_params.v_ptr = v_ptr;
if (cu_q_seqlens.has_value())
{
TORCH_CHECK(cu_q_seqlens->dim() == 1, "cu_q_seqlens must be a 1-D tensor.");
TORCH_CHECK(cu_q_seqlens->is_cuda(), "cu_q_seqlens must be a CUDA tensor.");
TORCH_CHECK(cu_q_seqlens->scalar_type() == at::ScalarType::Int, "cu_q_seqlens must be int32.");
TORCH_CHECK(
cu_q_seqlens->size(0) >= num_seqs + 1, "cu_q_seqlens must have at least num_seqs + 1 elements.");
enqueue_params.cu_q_seqlens = cu_q_seqlens->data_ptr<int32_t>();
}
if (cu_kv_seqlens.has_value())
{
TORCH_CHECK(cu_kv_seqlens->dim() == 1, "cu_kv_seqlens must be a 1-D tensor.");
TORCH_CHECK(cu_kv_seqlens->is_cuda(), "cu_kv_seqlens must be a CUDA tensor.");
TORCH_CHECK(cu_kv_seqlens->scalar_type() == at::ScalarType::Int, "cu_kv_seqlens must be int32.");
TORCH_CHECK(
cu_kv_seqlens->size(0) >= num_seqs + 1, "cu_kv_seqlens must have at least num_seqs + 1 elements.");
enqueue_params.cu_kv_seqlens = cu_kv_seqlens->data_ptr<int32_t>();
}
// Pass V's actual token stride so the FMHA runner handles both
// contiguous V (AutoDeploy) and non-contiguous V (PyTorch backend
// kv.split() view) correctly.
if (v_ptr != nullptr && v.has_value())
{
enqueue_params.v_stride_in_bytes = v->strides()[0] * v->element_size();
}
if (is_cross && cross_kv.has_value())
{
auto const& cross_kv_tensor = cross_kv.value();
enqueue_params.cross_kv = static_cast<T const*>(cross_kv_tensor.data_ptr());
enqueue_params.num_encoder_tokens = static_cast<int32_t>(cross_kv_tensor.size(0));
enqueue_params.cross_kv_length
= host_past_key_value_lengths.slice(0, seq_offset, seq_offset + num_seqs).max().item<int32_t>();
}
if (op.isMLAEnabled())
{
mla_params.cache_seq_lens = sequence_lengths_ptr;
mla_params.max_input_seq_len = max_context_q_len;
enqueue_params.mla_param = &mla_params;
}
if (op.isMRoPE() && mrope_rotary_cos_sin.has_value())
{
enqueue_params.mrope_rotary_cos_sin
= static_cast<float2 const*>(mrope_rotary_cos_sin.value().data_ptr());
}
extractHelixParams(enqueue_params);
op.enqueueContext<T, KVBlockArray>(enqueue_params, stream);
}
else // generation stage
{
int32_t const batch_beam = num_seqs;
TLLM_CHECK(batch_beam % beam_width == 0);
int32_t const num_requests = batch_beam / beam_width;
TLLM_CHECK_WITH_INFO(num_tokens % num_seqs == 0,
"seq_len should be same for all generation requests, num_tokens=%d, num_seqs=%d", num_tokens, num_seqs);
int32_t const input_seq_length = num_tokens / num_seqs;
common_enqueue_params.input_seq_length = input_seq_length;
AttentionOp::EnqueueGenerationParams<T> enqueue_params{common_enqueue_params};
enqueue_params.layer_idx = op.mLayerIdx;
enqueue_params.beam_width = beam_width;
enqueue_params.num_requests = num_requests;
enqueue_params.cache_indir = beam_width == 1
? nullptr
: (cache_indirection.has_value() ? cache_indirection.value().data_ptr<int32_t>() : nullptr);
enqueue_params.semaphores = op.multiBlockSemaphores();
enqueue_params.host_past_key_value_lengths = host_past_key_value_lengths.data_ptr<int32_t>();
enqueue_params.start_token_idx_sf = token_offset;
if (op.isMRoPE() && mrope_position_deltas.has_value())
{
enqueue_params.mrope_position_deltas = mrope_position_deltas.value().data_ptr<int32_t>();
}
if (op.mIsSpecDecodingEnabled && op.mUseSpecDecoding)
{
bool useTllmGen = tensorrt_llm::common::isSM100Family();
TORCH_CHECK(spec_decoding_generation_lengths.has_value(),
"Expecting spec_decoding_generation_lengths in spec-dec mode.");
TORCH_CHECK(spec_decoding_position_offsets_for_cpp.has_value(),
"Expecting spec_decoding_position_offsets_for_cpp in spec-dec mode.");
TORCH_CHECK(
spec_decoding_packed_mask.has_value(), "Expecting spec_decoding_packed_mask in spec-dec mode.");
if (useTllmGen)
{
TORCH_CHECK(spec_decoding_bl_tree_mask_offset.has_value(),
"Expecting spec_decoding_bl_tree_mask_offset in trtllm-gen spec-dec mode.");
TORCH_CHECK(spec_decoding_bl_tree_mask.has_value(),
"Expecting spec_decoding_bl_tree_mask in trtllm-gen spec-dec mode.");
TORCH_CHECK(spec_bl_tree_first_sparse_mask_offset_kv.has_value(),
"Expecting spec_bl_tree_first_sparse_mask_offset_kv in trtllm-gen spec-dec mode.");
enqueue_params.spec_decoding_bl_tree_mask_offset
= spec_decoding_bl_tree_mask_offset->data_ptr<int64_t>();
enqueue_params.spec_decoding_bl_tree_mask = spec_decoding_bl_tree_mask->data_ptr<uint32_t>();
enqueue_params.spec_bl_tree_first_sparse_mask_offset_kv
= spec_bl_tree_first_sparse_mask_offset_kv->data_ptr<int32_t>();
}
enqueue_params.spec_decoding_generation_lengths = spec_decoding_generation_lengths->data_ptr<int32_t>();
enqueue_params.spec_decoding_position_offsets
= spec_decoding_position_offsets_for_cpp->data_ptr<int32_t>();
enqueue_params.spec_decoding_packed_mask = spec_decoding_packed_mask->data_ptr<int32_t>();
enqueue_params.spec_decoding_is_generation_length_variable = true;
TLLM_CHECK(spec_decoding_position_offsets_for_cpp->dim() == 2); // [batch_size, max_draft_len + 1]
if (useTllmGen)
{
// Blackwell uses the padded packed-mask row dim as the mask stride.
TLLM_CHECK(spec_decoding_packed_mask->dim() == 3);
enqueue_params.spec_decoding_max_generation_length = spec_decoding_packed_mask->sizes()[1];
}
else
{
enqueue_params.spec_decoding_max_generation_length
= spec_decoding_position_offsets_for_cpp->sizes()[1];
}
}
// Current mlaGeneration will using fmha to do attention, so we don't go into enqueueGeneration
if (op.isMLAEnabled())
{
if (op.mUseGenFlashMLA == true)
{
TORCH_CHECK(block_ids_per_seq.has_value());
int const* block_ids_per_seq_ptr = static_cast<int*>(block_ids_per_seq->data_ptr());
mla_params.block_ids_per_seq = block_ids_per_seq_ptr;
// Use pre-computed metadata if provided.
if (flash_mla_tile_scheduler_metadata.has_value())
{
TORCH_CHECK(flash_mla_num_splits.has_value(),
"flash_mla_num_splits must be provided when flash_mla_tile_scheduler_metadata is set.");
mla_params.flash_mla_tile_scheduler_metadata
= flash_mla_tile_scheduler_metadata->data_ptr<int>();
mla_params.flash_mla_num_splits = flash_mla_num_splits->data_ptr<int>();
}
}
mla_params.cache_seq_lens = sequence_lengths_ptr;
{
op.mlaGeneration<T>(mla_params, enqueue_params, stream);
}
}
else
{
extractHelixParams(enqueue_params);
{
op.enqueueGeneration<T, KVBlockArray>(enqueue_params, stream);
}
}
{
std::string const afterGenStr = "gen attention at layer " + std::to_string(op.mLayerIdx);
{
TLLM_CHECK_DEBUG_WITH_INFO(tensorrt_llm::runtime::utils::tensorHasInvalid(num_tokens,
output.size(1), op.mType, context_buf, stream, afterGenStr)
== false,
"Found invalid number (NaN or Inf) in " + afterGenStr);
}
}
}
sync_check_cuda_error(stream);
}
};
// Explicit instantiations of Runner emitted by this translation unit.
// Forms that omit the second template argument rely on Runner's default for it.
template class Runner<float>;
template class Runner<half>;
#ifdef ENABLE_BF16
// bfloat16 variants are only compiled when ENABLE_BF16 is defined.
template class Runner<__nv_bfloat16>;
#endif

// Forms whose second template argument is __nv_fp8_e4m3 (presumably an FP8 attention
// output type -- NOTE(review): confirm against the Runner template declaration).
template class Runner<half, __nv_fp8_e4m3>;
#ifdef ENABLE_BF16
template class Runner<__nv_bfloat16, __nv_fp8_e4m3>;
#endif
} // namespace trtllm::attention
using RunnerPtr = std::shared_ptr<torch_ext::trtllm::attention::RunnerBase>;
using torch_ext::trtllm::attention::Runner;
using torch_ext::trtllm::attention::AttentionInputType;
void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
std::optional<torch::Tensor> output_sf, std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::Tensor host_request_types,
std::optional<int64_t> max_context_q_len_override, std::optional<torch::Tensor> kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const local_layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const max_seq_len, int64_t const attention_window_size, int64_t const beam_width, int64_t const mask_type,
int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type, int64_t const rope_dim,