Skip to content

Commit 0a19205

Browse files
authored
[None][refactor] Flatten thop.attention sequence kwargs + rename rotary_embedding_* to rope_* (#14569)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent b1dfd30 commit 0a19205

8 files changed

Lines changed: 300 additions & 293 deletions

File tree

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -144,16 +144,20 @@ void initBindings(nb::module_& m)
144144
nb::arg("num_heads"), nb::arg("num_kv_heads"), nb::arg("head_size"), nb::arg("tokens_per_block").none(),
145145
nb::arg("max_num_requests"), nb::arg("max_context_length"), nb::arg("attention_window_size"),
146146
nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"), nb::arg("q_scaling"),
147-
nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"), nb::arg("rotary_embedding_base"),
148-
nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
149-
nb::arg("rotary_embedding_max_position_info"), nb::arg("use_paged_context_fmha"),
147+
nb::arg("position_embedding_type"), nb::arg("rope_dim"), nb::arg("rope_base"), nb::arg("rope_scale_type"),
148+
nb::arg("rope_scale"), nb::arg("rope_short_m_scale"), nb::arg("rope_long_m_scale"),
149+
nb::arg("rope_max_positions"), nb::arg("rope_original_max_positions"), nb::arg("use_paged_context_fmha"),
150150
nb::arg("attention_input_type").none(), nb::arg("is_mla_enable"),
151151
nb::arg("chunked_prefill_buffer_batch_size").none(), nb::arg("q_lora_rank").none(),
152152
nb::arg("kv_lora_rank").none(), nb::arg("qk_nope_head_dim").none(), nb::arg("qk_rope_head_dim").none(),
153153
nb::arg("v_head_dim").none(), nb::arg("rope_append").none(), nb::arg("mrope_rotary_cos_sin").none(),
154-
nb::arg("mrope_position_deltas").none(), nb::arg("helix_tensor_params"), nb::arg("attention_chunk_size").none(),
155-
nb::arg("softmax_stats_tensor").none(), nb::arg("spec_decoding_bool_params"),
156-
nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_kv_indices").none(),
154+
nb::arg("mrope_position_deltas").none(), nb::arg("helix_position_offsets").none(),
155+
nb::arg("helix_is_inactive_rank").none(), nb::arg("attention_chunk_size").none(),
156+
nb::arg("softmax_stats_tensor").none(), nb::arg("is_spec_decoding_enabled"), nb::arg("use_spec_decoding"),
157+
nb::arg("is_spec_dec_tree"), nb::arg("spec_decoding_generation_lengths").none(),
158+
nb::arg("spec_decoding_position_offsets_for_cpp").none(), nb::arg("spec_decoding_packed_mask").none(),
159+
nb::arg("spec_decoding_bl_tree_mask_offset").none(), nb::arg("spec_decoding_bl_tree_mask").none(),
160+
nb::arg("spec_bl_tree_first_sparse_mask_offset_kv").none(), nb::arg("sparse_kv_indices").none(),
157161
nb::arg("sparse_kv_offsets").none(), nb::arg("sparse_attn_indices").none(),
158162
nb::arg("sparse_attn_offsets").none(), nb::arg("sparse_attn_indices_block_size"),
159163
nb::arg("num_sparse_topk") = std::nullopt, nb::arg("sparse_mla_topk_lens") = std::nullopt,

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 82 additions & 97 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/thop/attentionOp.h

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -59,17 +59,23 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
5959
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
6060
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
6161
int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode, double const q_scaling,
62-
int64_t const position_embedding_type, int64_t const rotary_embedding_dim, double const rotary_embedding_base,
63-
int64_t const rotary_embedding_scale_type, std::vector<double> rotary_embedding_scales,
64-
std::vector<int64_t> rotary_embedding_max_position_info, bool const use_paged_context_fmha,
65-
std::optional<int64_t> attention_input_type, bool is_mla_enable,
62+
int64_t const position_embedding_type, int64_t const rope_dim, double const rope_base,
63+
int64_t const rope_scale_type, double const rope_scale, double const rope_short_m_scale,
64+
double const rope_long_m_scale, int64_t const rope_max_positions, int64_t const rope_original_max_positions,
65+
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
6666
std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,
6767
std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
6868
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim, std::optional<bool> rope_append,
6969
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
70-
std::vector<std::optional<torch::Tensor>> helix_tensor_params, std::optional<int64_t> attention_chunk_size,
71-
std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
72-
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
70+
std::optional<torch::Tensor> helix_position_offsets, std::optional<torch::Tensor> helix_is_inactive_rank,
71+
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
72+
bool const is_spec_decoding_enabled, bool const use_spec_decoding, bool const is_spec_dec_tree,
73+
std::optional<torch::Tensor> spec_decoding_generation_lengths,
74+
std::optional<torch::Tensor> spec_decoding_position_offsets_for_cpp,
75+
std::optional<torch::Tensor> spec_decoding_packed_mask,
76+
std::optional<torch::Tensor> spec_decoding_bl_tree_mask_offset,
77+
std::optional<torch::Tensor> spec_decoding_bl_tree_mask,
78+
std::optional<torch::Tensor> spec_bl_tree_first_sparse_mask_offset_kv,
7379
std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
7480
std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
7581
int64_t const sparse_attn_indices_block_size, std::optional<int64_t> num_sparse_topk,

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 42 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -157,22 +157,6 @@ def effective_workspace(self) -> Optional[torch.Tensor]:
157157
"""Attention-kernel workspace, switching to the CUDA-graph copy under capture."""
158158
return self.cuda_graph_workspace if self.is_cuda_graph else self.workspace
159159

160-
@property
161-
def helix_tensor_params(self) -> List[Optional[torch.Tensor]]:
162-
"""``[helix_position_offsets, helix_is_inactive_rank]`` — the positional
163-
helix tensor list expected by the C++ attention op."""
164-
return [self.helix_position_offsets, self.helix_is_inactive_rank]
165-
166-
@property
167-
def spec_decoding_bool_params(self) -> List[bool]:
168-
"""``[is_spec_decoding_enabled, use_spec_decoding, is_spec_dec_tree]`` —
169-
the positional bool list expected by the C++ attention op."""
170-
return [
171-
self.is_spec_decoding_enabled,
172-
self.use_spec_decoding,
173-
self.is_spec_dec_tree,
174-
]
175-
176160
@property
177161
def spec_decoding_position_offsets_for_cpp(self) -> Optional[torch.Tensor]:
178162
"""``spec_decoding_position_offsets`` reshaped to the 2D layout the C++
@@ -1051,22 +1035,6 @@ def generate_spec_decoding_generation_length(self, runtime_draft_len):
10511035
def is_sm_version_trtllm_gen_kernel(self, sm):
10521036
return not (sm < 100 or sm in [120, 121])
10531037

1054-
@property
1055-
def spec_decoding_tensor_params(self) -> List[Optional[torch.Tensor]]:
1056-
"""Positional spec-decoding tensor list for the C++ attention op.
1057-
Includes three Blackwell-tree mask tensors on SM versions that take
1058-
the trtllm-gen kernel."""
1059-
params = [
1060-
self.spec_decoding_generation_lengths,
1061-
self.spec_decoding_position_offsets_for_cpp,
1062-
self.spec_decoding_packed_mask,
1063-
]
1064-
if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
1065-
params.append(self.spec_decoding_bl_tree_mask_offset)
1066-
params.append(self.spec_decoding_bl_tree_mask)
1067-
params.append(self.spec_bl_tree_first_sparse_mask_offset_kv)
1068-
return params
1069-
10701038

10711039
class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
10721040

@@ -1332,35 +1300,36 @@ def create_output(self, q, *, is_quantize_output: bool,
13321300
]
13331301

13341302
@property
1335-
def rotary_embedding_dim(self) -> int:
1303+
def rope_dim(self) -> int:
13361304
return self.rope_params.dim
13371305

13381306
@property
1339-
def rotary_embedding_base(self) -> float:
1307+
def rope_base(self) -> float:
13401308
return self.rope_params.theta
13411309

13421310
@property
1343-
def rotary_embedding_scale_type(self) -> int:
1311+
def rope_scale_type(self) -> int:
13441312
return int(self.rope_params.scale_type)
13451313

13461314
@property
1347-
def rotary_embedding_scales(self) -> List[float]:
1348-
"""``[scale, short_m_scale, long_m_scale]`` — the positional RoPE-scale
1349-
list expected by the C++ attention op."""
1350-
return [
1351-
self.rope_params.scale,
1352-
self.rope_params.short_m_scale,
1353-
self.rope_params.long_m_scale,
1354-
]
1315+
def rope_scale(self) -> float:
1316+
return self.rope_params.scale
13551317

13561318
@property
1357-
def rotary_embedding_max_position_info(self) -> List[int]:
1358-
"""``[max_positions, original_max_positions]`` — the positional
1359-
RoPE-positions list expected by the C++ attention op."""
1360-
return [
1361-
self.rope_params.max_positions,
1362-
self.rope_params.original_max_positions,
1363-
]
1319+
def rope_short_m_scale(self) -> float:
1320+
return self.rope_params.short_m_scale
1321+
1322+
@property
1323+
def rope_long_m_scale(self) -> float:
1324+
return self.rope_params.long_m_scale
1325+
1326+
@property
1327+
def rope_max_positions(self) -> int:
1328+
return self.rope_params.max_positions
1329+
1330+
@property
1331+
def rope_original_max_positions(self) -> int:
1332+
return self.rope_params.original_max_positions
13641333

13651334
@property
13661335
def skip_softmax_threshold_scale_factor_prefill(self) -> Optional[float]:
@@ -1530,10 +1499,21 @@ def _run(
15301499
max_num_requests=metadata.max_num_requests,
15311500
beam_width=metadata.beam_width,
15321501
use_paged_context_fmha=metadata.use_paged_context_fmha,
1533-
helix_tensor_params=metadata.helix_tensor_params,
1534-
spec_decoding_bool_params=metadata.spec_decoding_bool_params,
1535-
spec_decoding_tensor_params=metadata.
1536-
spec_decoding_tensor_params,
1502+
helix_position_offsets=metadata.helix_position_offsets,
1503+
helix_is_inactive_rank=metadata.helix_is_inactive_rank,
1504+
is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
1505+
use_spec_decoding=metadata.use_spec_decoding,
1506+
is_spec_dec_tree=metadata.is_spec_dec_tree,
1507+
spec_decoding_generation_lengths=metadata.
1508+
spec_decoding_generation_lengths,
1509+
spec_decoding_position_offsets_for_cpp=metadata.
1510+
spec_decoding_position_offsets_for_cpp,
1511+
spec_decoding_packed_mask=metadata.spec_decoding_packed_mask,
1512+
spec_decoding_bl_tree_mask_offset=metadata.
1513+
spec_decoding_bl_tree_mask_offset,
1514+
spec_decoding_bl_tree_mask=metadata.spec_decoding_bl_tree_mask,
1515+
spec_bl_tree_first_sparse_mask_offset_kv=metadata.
1516+
spec_bl_tree_first_sparse_mask_offset_kv,
15371517
num_sparse_topk=metadata.num_sparse_topk,
15381518
flash_mla_tile_scheduler_metadata=metadata.
15391519
flash_mla_tile_scheduler_metadata,
@@ -1584,12 +1564,14 @@ def _run(
15841564
quant_mode=self.quant_mode,
15851565
q_scaling=self.q_scaling,
15861566
position_embedding_type=self.position_embedding_type,
1587-
rotary_embedding_dim=self.rotary_embedding_dim,
1588-
rotary_embedding_base=self.rotary_embedding_base,
1589-
rotary_embedding_scale_type=self.rotary_embedding_scale_type,
1590-
rotary_embedding_scales=self.rotary_embedding_scales,
1591-
rotary_embedding_max_position_info=self.
1592-
rotary_embedding_max_position_info,
1567+
rope_dim=self.rope_dim,
1568+
rope_base=self.rope_base,
1569+
rope_scale_type=self.rope_scale_type,
1570+
rope_scale=self.rope_scale,
1571+
rope_short_m_scale=self.rope_short_m_scale,
1572+
rope_long_m_scale=self.rope_long_m_scale,
1573+
rope_max_positions=self.rope_max_positions,
1574+
rope_original_max_positions=self.rope_original_max_positions,
15931575
is_mla_enable=self.is_mla_enable,
15941576
q_lora_rank=self.q_lora_rank,
15951577
kv_lora_rank=self.kv_lora_rank,

0 commit comments

Comments
 (0)