1414
def routing_dtype_id_to_np(dtype_id: int):
    """Translate a routing dtype id into the matching NumPy scalar type.

    Args:
        dtype_id: Integer tag for the stored element type. ``1`` selects
            ``np.uint8`` and ``2`` selects ``np.int16``.

    Returns:
        The NumPy type for ``dtype_id``. Any id other than 1 or 2 falls
        back to ``np.int32``.
    """
    if dtype_id == 1:
        # Id 1 is the one-byte *unsigned* encoding; the earlier signed
        # ``np.int8`` mapping must not linger here.
        return np.uint8
    if dtype_id == 2:
        return np.int16
    return np.int32
@@ -39,8 +39,8 @@ def __init__(
3939 self .num_experts = num_experts
4040 self .kv_cache_size = kv_cache_size
4141
42- self .dtype = torch .int8 if num_experts <= 127 else torch .int16
43- dtype_bytes = 1 if self .dtype == torch .int8 else 2
42+ self .dtype = torch .uint8 if num_experts <= 255 else torch .int16
43+ dtype_bytes = 1 if self .dtype == torch .uint8 else 2
4444
4545 # Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory.
4646 # Written after forward() via flush_to_routing_buffer(), read on request finish.
@@ -57,7 +57,7 @@ def __init__(
5757 torch .zeros ((max_capture_tokens , num_moe_layers , topk ), dtype = self .dtype , device = "cuda" ) for _ in range (2 )
5858 ]
5959
60- dtype_name = "int8 " if self .dtype == torch .int8 else "int16"
60+ dtype_name = "uint8 " if self .dtype == torch .uint8 else "int16"
6161 logger .info (
6262 f"RoutingCaptureManager initialized: { num_moe_layers } MoE layers, topk={ topk } , "
6363 f"routing_buffer(cpu)={ routing_buffer_size / 1024 / 1024 :.2f} MB, "
@@ -66,11 +66,11 @@ def __init__(
6666
@property
def np_dtype(self):
    """Return the NumPy type matching ``self.dtype``.

    Returns:
        ``np.uint8`` when ``self.dtype`` is ``torch.uint8``; ``np.int16``
        for any other value of ``self.dtype``.
    """
    # Only the unsigned one-byte case is special-cased; everything else
    # is reported as int16.
    return np.uint8 if self.dtype == torch.uint8 else np.int16
7070
@property
def dtype_id(self) -> int:
    """Return the integer tag identifying ``self.dtype``.

    Returns:
        ``1`` when ``self.dtype`` is ``torch.uint8``; ``2`` for any other
        value of ``self.dtype``.
    """
    # Tag 1 must correspond to uint8 so that it agrees with the NumPy
    # type reported for the same storage dtype.
    return 1 if self.dtype == torch.uint8 else 2
7474
7575 def capture (self , moe_layer_index : int , topk_ids : torch .Tensor , microbatch_index : int = 0 ) -> None :
7676 num_tokens = topk_ids .shape [0 ]
0 commit comments