@@ -157,22 +157,6 @@ def effective_workspace(self) -> Optional[torch.Tensor]:
157157 """Attention-kernel workspace, switching to the CUDA-graph copy under capture."""
158158 return self .cuda_graph_workspace if self .is_cuda_graph else self .workspace
159159
160- @property
161- def helix_tensor_params (self ) -> List [Optional [torch .Tensor ]]:
162- """``[helix_position_offsets, helix_is_inactive_rank]`` — the positional
163- helix tensor list expected by the C++ attention op."""
164- return [self .helix_position_offsets , self .helix_is_inactive_rank ]
165-
166- @property
167- def spec_decoding_bool_params (self ) -> List [bool ]:
168- """``[is_spec_decoding_enabled, use_spec_decoding, is_spec_dec_tree]`` —
169- the positional bool list expected by the C++ attention op."""
170- return [
171- self .is_spec_decoding_enabled ,
172- self .use_spec_decoding ,
173- self .is_spec_dec_tree ,
174- ]
175-
176160 @property
177161 def spec_decoding_position_offsets_for_cpp (self ) -> Optional [torch .Tensor ]:
178162 """``spec_decoding_position_offsets`` reshaped to the 2D layout the C++
@@ -1051,22 +1035,6 @@ def generate_spec_decoding_generation_length(self, runtime_draft_len):
10511035 def is_sm_version_trtllm_gen_kernel (self , sm ):
10521036 return not (sm < 100 or sm in [120 , 121 ])
10531037
1054- @property
1055- def spec_decoding_tensor_params (self ) -> List [Optional [torch .Tensor ]]:
1056- """Positional spec-decoding tensor list for the C++ attention op.
1057- Includes three Blackwell-tree mask tensors on SM versions that take
1058- the trtllm-gen kernel."""
1059- params = [
1060- self .spec_decoding_generation_lengths ,
1061- self .spec_decoding_position_offsets_for_cpp ,
1062- self .spec_decoding_packed_mask ,
1063- ]
1064- if self .is_sm_version_trtllm_gen_kernel (sm = get_sm_version ()):
1065- params .append (self .spec_decoding_bl_tree_mask_offset )
1066- params .append (self .spec_decoding_bl_tree_mask )
1067- params .append (self .spec_bl_tree_first_sparse_mask_offset_kv )
1068- return params
1069-
10701038
10711039class TrtllmAttention (AttentionBackend [TrtllmAttentionMetadata ]):
10721040
@@ -1332,35 +1300,36 @@ def create_output(self, q, *, is_quantize_output: bool,
13321300 ]
13331301
13341302 @property
1335- def rotary_embedding_dim (self ) -> int :
1303+ def rope_dim (self ) -> int :
13361304 return self .rope_params .dim
13371305
13381306 @property
1339- def rotary_embedding_base (self ) -> float :
1307+ def rope_base (self ) -> float :
13401308 return self .rope_params .theta
13411309
13421310 @property
1343- def rotary_embedding_scale_type (self ) -> int :
1311+ def rope_scale_type (self ) -> int :
13441312 return int (self .rope_params .scale_type )
13451313
13461314 @property
1347- def rotary_embedding_scales (self ) -> List [float ]:
1348- """``[scale, short_m_scale, long_m_scale]`` — the positional RoPE-scale
1349- list expected by the C++ attention op."""
1350- return [
1351- self .rope_params .scale ,
1352- self .rope_params .short_m_scale ,
1353- self .rope_params .long_m_scale ,
1354- ]
1315+ def rope_scale (self ) -> float :
1316+ return self .rope_params .scale
13551317
13561318 @property
1357- def rotary_embedding_max_position_info (self ) -> List [int ]:
1358- """``[max_positions, original_max_positions]`` — the positional
1359- RoPE-positions list expected by the C++ attention op."""
1360- return [
1361- self .rope_params .max_positions ,
1362- self .rope_params .original_max_positions ,
1363- ]
1319+ def rope_short_m_scale (self ) -> float :
1320+ return self .rope_params .short_m_scale
1321+
1322+ @property
1323+ def rope_long_m_scale (self ) -> float :
1324+ return self .rope_params .long_m_scale
1325+
1326+ @property
1327+ def rope_max_positions (self ) -> int :
1328+ return self .rope_params .max_positions
1329+
1330+ @property
1331+ def rope_original_max_positions (self ) -> int :
1332+ return self .rope_params .original_max_positions
13641333
13651334 @property
13661335 def skip_softmax_threshold_scale_factor_prefill (self ) -> Optional [float ]:
@@ -1530,10 +1499,21 @@ def _run(
15301499 max_num_requests = metadata .max_num_requests ,
15311500 beam_width = metadata .beam_width ,
15321501 use_paged_context_fmha = metadata .use_paged_context_fmha ,
1533- helix_tensor_params = metadata .helix_tensor_params ,
1534- spec_decoding_bool_params = metadata .spec_decoding_bool_params ,
1535- spec_decoding_tensor_params = metadata .
1536- spec_decoding_tensor_params ,
1502+ helix_position_offsets = metadata .helix_position_offsets ,
1503+ helix_is_inactive_rank = metadata .helix_is_inactive_rank ,
1504+ is_spec_decoding_enabled = metadata .is_spec_decoding_enabled ,
1505+ use_spec_decoding = metadata .use_spec_decoding ,
1506+ is_spec_dec_tree = metadata .is_spec_dec_tree ,
1507+ spec_decoding_generation_lengths = metadata .
1508+ spec_decoding_generation_lengths ,
1509+ spec_decoding_position_offsets_for_cpp = metadata .
1510+ spec_decoding_position_offsets_for_cpp ,
1511+ spec_decoding_packed_mask = metadata .spec_decoding_packed_mask ,
1512+ spec_decoding_bl_tree_mask_offset = metadata .
1513+ spec_decoding_bl_tree_mask_offset ,
1514+ spec_decoding_bl_tree_mask = metadata .spec_decoding_bl_tree_mask ,
1515+ spec_bl_tree_first_sparse_mask_offset_kv = metadata .
1516+ spec_bl_tree_first_sparse_mask_offset_kv ,
15371517 num_sparse_topk = metadata .num_sparse_topk ,
15381518 flash_mla_tile_scheduler_metadata = metadata .
15391519 flash_mla_tile_scheduler_metadata ,
@@ -1584,12 +1564,14 @@ def _run(
15841564 quant_mode = self .quant_mode ,
15851565 q_scaling = self .q_scaling ,
15861566 position_embedding_type = self .position_embedding_type ,
1587- rotary_embedding_dim = self .rotary_embedding_dim ,
1588- rotary_embedding_base = self .rotary_embedding_base ,
1589- rotary_embedding_scale_type = self .rotary_embedding_scale_type ,
1590- rotary_embedding_scales = self .rotary_embedding_scales ,
1591- rotary_embedding_max_position_info = self .
1592- rotary_embedding_max_position_info ,
1567+ rope_dim = self .rope_dim ,
1568+ rope_base = self .rope_base ,
1569+ rope_scale_type = self .rope_scale_type ,
1570+ rope_scale = self .rope_scale ,
1571+ rope_short_m_scale = self .rope_short_m_scale ,
1572+ rope_long_m_scale = self .rope_long_m_scale ,
1573+ rope_max_positions = self .rope_max_positions ,
1574+ rope_original_max_positions = self .rope_original_max_positions ,
15931575 is_mla_enable = self .is_mla_enable ,
15941576 q_lora_rank = self .q_lora_rank ,
15951577 kv_lora_rank = self .kv_lora_rank ,
0 commit comments