@@ -945,63 +945,63 @@ def update_use_cudagraph(self, argument: bool):
945945 argument = self .use_cudagraph
946946
947947
948- class MobaAttentionConfig :
948+ class PlasAttentionConfig :
949949 def __init__ (
950950 self ,
951951 args ,
952952 ):
953- self .moba_encoder_top_k_left : int = None
954- self .moba_encoder_top_k_right : int = None
955- "The sparse topk of encoder attention is located at [moba_encoder_top_k_left, moba_encoder top_k_right]"
956- self .moba_decoder_top_k_left : int = None
957- self .moba_decoder_top_k_right : int = None
958- "The sparse topk of decoder attention is located at [moba_decoder_top_k_left, moba_decoder top_k_right]"
959- self .moba_use_encoder_seq_limit : int = None
960- "When the number of encdoer token is less than moba_use_encoder_seq_limit , it is not sparse"
961- self .moba_use_decoder_seq_limit : int = None
962- "When the number of decdoer token is less than moba_use_decoder_seq_limit , it is not sparse"
963- self .moba_block_size : int = 128
964- self .mlp_weight_name : str = "moba_mlp_weight .safetensors"
965- self .moba_max_seq_length : int = 128 * 1024
953+ self .plas_encoder_top_k_left : int = None
954+ self .plas_encoder_top_k_right : int = None
955+ "The sparse topk of encoder attention is located at [plas_encoder_top_k_left, plas_encoder top_k_right]"
956+ self .plas_decoder_top_k_left : int = None
957+ self .plas_decoder_top_k_right : int = None
958+ "The sparse topk of decoder attention is located at [plas_decoder_top_k_left, plas_decoder top_k_right]"
959+ self .plas_use_encoder_seq_limit : int = None
960+ "When the number of encdoer token is less than plas_use_encoder_seq_limit , it is not sparse"
961+ self .plas_use_decoder_seq_limit : int = None
962+ "When the number of decdoer token is less than plas_use_decoder_seq_limit , it is not sparse"
963+ self .plas_block_size : int = 128
964+ self .mlp_weight_name : str = "plas_attention_mlp_weight .safetensors"
965+ self .plas_max_seq_length : int = 128 * 1024
966966 if args is not None :
967967 for key , value in args .items ():
968968 if hasattr (self , key ):
969969 setattr (self , key , value )
970- if self .moba_use_encoder_seq_limit is None and self .moba_encoder_top_k_left is not None :
971- self .moba_use_encoder_seq_limit = self .moba_encoder_top_k_left * self .moba_block_size
972- if self .moba_use_decoder_seq_limit is None and self .moba_decoder_top_k_left is not None :
973- self .moba_use_decoder_seq_limit = self .moba_decoder_top_k_left * self .moba_block_size
970+ if self .plas_use_encoder_seq_limit is None and self .plas_encoder_top_k_left is not None :
971+ self .plas_use_encoder_seq_limit = self .plas_encoder_top_k_left * self .plas_block_size
972+ if self .plas_use_decoder_seq_limit is None and self .plas_decoder_top_k_left is not None :
973+ self .plas_use_decoder_seq_limit = self .plas_decoder_top_k_left * self .plas_block_size
974974 self .check_legality_parameters ()
975975
976976 def check_legality_parameters (
977977 self ,
978978 ) -> None :
979- if self .moba_encoder_top_k_left is not None :
980- assert self .moba_encoder_top_k_left > 0 , "moba_encoder_top_k_left must large than 0"
979+ if self .plas_encoder_top_k_left is not None :
980+ assert self .plas_encoder_top_k_left > 0 , "plas_encoder_top_k_left must large than 0"
981981
982- if self .moba_encoder_top_k_right is not None :
983- assert self .moba_encoder_top_k_right > 0 , "moba_encoder_top_k_right must large than 0"
982+ if self .plas_encoder_top_k_right is not None :
983+ assert self .plas_encoder_top_k_right > 0 , "plas_encoder_top_k_right must large than 0"
984984 assert (
985- self .moba_encoder_top_k_right >= self .moba_encoder_top_k_left
986- ), "moba_encoder_top_k_right must large than moba_encoder_top_k_left "
985+ self .plas_encoder_top_k_right >= self .plas_encoder_top_k_left
986+ ), "plas_encoder_top_k_right must large than plas_encoder_top_k_left "
987987
988- if self .moba_decoder_top_k_left is not None :
989- assert self .moba_decoder_top_k_left > 0 , "moba_decoder_top_k_left must large than 0"
988+ if self .plas_decoder_top_k_left is not None :
989+ assert self .plas_decoder_top_k_left > 0 , "plas_decoder_top_k_left must large than 0"
990990
991- if self .moba_decoder_top_k_right is not None :
992- assert self .moba_decoder_top_k_right > 0 , "moba_decoder_top_k_right must large than 0"
991+ if self .plas_decoder_top_k_right is not None :
992+ assert self .plas_decoder_top_k_right > 0 , "plas_decoder_top_k_right must large than 0"
993993 assert (
994- self .moba_decoder_top_k_right >= self .moba_decoder_top_k_left
995- ), "moba_decoder_top_k_right must large than moba_decoder_top_k_left "
994+ self .plas_decoder_top_k_right >= self .plas_decoder_top_k_left
995+ ), "plas_decoder_top_k_right must large than plas_decoder_top_k_left "
996996
997- if self .moba_use_encoder_seq_limit is not None and self .moba_encoder_top_k_left is not None :
998- assert self .moba_use_encoder_seq_limit >= self .moba_encoder_top_k_left * self .moba_block_size
999- if self .moba_use_decoder_seq_limit is not None and self .moba_decoder_top_k_left is not None :
1000- assert self .moba_use_decoder_seq_limit >= self .moba_decoder_top_k_left * self .moba_block_size
997+ if self .plas_use_encoder_seq_limit is not None and self .plas_encoder_top_k_left is not None :
998+ assert self .plas_use_encoder_seq_limit >= self .plas_encoder_top_k_left * self .plas_block_size
999+ if self .plas_use_decoder_seq_limit is not None and self .plas_decoder_top_k_left is not None :
1000+ assert self .plas_use_decoder_seq_limit >= self .plas_decoder_top_k_left * self .plas_block_size
10011001
10021002 def to_json_string (self ):
10031003 """
1004- Convert moba_attention_config to json string.
1004+ Convert plas_attention_config to json string.
10051005 """
10061006 return json .dumps ({key : value for key , value in self .__dict__ .items () if value is not None })
10071007
@@ -1396,7 +1396,7 @@ def __init__(
13961396 decoding_config : DecodingConfig = None ,
13971397 quant_config : QuantConfigBase = None ,
13981398 graph_opt_config : GraphOptimizationConfig = None ,
1399- moba_attention_config : MobaAttentionConfig = None ,
1399+ plas_attention_config : PlasAttentionConfig = None ,
14001400 speculative_config : SpeculativeConfig = None ,
14011401 tokenizer : str = None ,
14021402 max_model_len : int = 8192 ,
@@ -1427,7 +1427,7 @@ def __init__(
14271427 self .early_stop_config : Optional [EarlyStopConfig ] = early_stop_config
14281428 self .decoding_config : DecodingConfig = decoding_config # type: ignore
14291429 self .cache_config : CacheConfig = cache_config # type: ignore
1430- self .moba_attention_config : Optional [MobaAttentionConfig ] = moba_attention_config
1430+ self .plas_attention_config : Optional [PlasAttentionConfig ] = plas_attention_config
14311431 # Initialize cuda graph capture list
14321432 if self .graph_opt_config .cudagraph_capture_sizes is None :
14331433 self .graph_opt_config ._set_cudagraph_sizes (max_num_seqs = self .scheduler_config .max_num_seqs )
0 commit comments