Skip to content

Commit 3408ca9

Browse files
committed
fix r3
1 parent b522c16 commit 3408ca9

File tree

3 files changed

+8
-25
lines changed

3 files changed

+8
-25
lines changed

lightllm/common/basemodel/routing_manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def routing_dtype_id_to_np(dtype_id: int):
1616
if dtype_id == 1:
17-
return np.int8
17+
return np.uint8
1818
elif dtype_id == 2:
1919
return np.int16
2020
return np.int32
@@ -39,8 +39,8 @@ def __init__(
3939
self.num_experts = num_experts
4040
self.kv_cache_size = kv_cache_size
4141

42-
self.dtype = torch.int8 if num_experts <= 127 else torch.int16
43-
dtype_bytes = 1 if self.dtype == torch.int8 else 2
42+
self.dtype = torch.uint8 if num_experts <= 255 else torch.int16
43+
dtype_bytes = 1 if self.dtype == torch.uint8 else 2
4444

4545
# Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory.
4646
# Written after forward() via flush_to_routing_buffer(), read on request finish.
@@ -57,7 +57,7 @@ def __init__(
5757
torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2)
5858
]
5959

60-
dtype_name = "int8" if self.dtype == torch.int8 else "int16"
60+
dtype_name = "uint8" if self.dtype == torch.uint8 else "int16"
6161
logger.info(
6262
f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
6363
f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, "
@@ -66,11 +66,11 @@ def __init__(
6666

6767
@property
6868
def np_dtype(self):
69-
return np.int8 if self.dtype == torch.int8 else np.int16
69+
return np.uint8 if self.dtype == torch.uint8 else np.int16
7070

7171
@property
7272
def dtype_id(self) -> int:
73-
return 1 if self.dtype == torch.int8 else 2
73+
return 1 if self.dtype == torch.uint8 else 2
7474

7575
def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
7676
num_tokens = topk_ids.shape[0]

lightllm/server/core/objs/req.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import math
33
import ctypes
4-
import base64
54
import numpy as np
65
import time
76
from .sampling_params import SamplingParams
@@ -289,7 +288,7 @@ def get_routing_metadata(self, num_moe_layers: int, topk: int, dtype_id: int = 1
289288
return {
290289
"shape": list(routing_data.shape),
291290
"dtype": str(routing_data.dtype),
292-
"data": base64.b64encode(routing_data.tobytes()).decode("ascii"),
291+
"data": list(routing_data.tobytes()),
293292
}
294293
except Exception as e:
295294
logger.warning(f"Failed to read routing data for req {self.request_id}: {e}")

test/test_api/test_r3.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import sys
22
import argparse
33
import requests
4-
import base64
54
import numpy as np
65

76

@@ -52,8 +51,7 @@ def test_routing_export(url: str = "http://localhost:8000"):
5251
shape = routing_info["shape"]
5352
dtype_str = routing_info["dtype"]
5453
dtype = np.dtype(dtype_str)
55-
data = base64.b64decode(routing_info["data"])
56-
routing_array = np.frombuffer(data, dtype=dtype).reshape(shape)
54+
routing_array = np.frombuffer(bytes(routing_info["data"]), dtype=dtype).reshape(shape)
5755

5856
print(f"\n{'=' * 50}")
5957
print("ROUTING CAPTURE SUCCESS!")
@@ -64,20 +62,6 @@ def test_routing_export(url: str = "http://localhost:8000"):
6462
print(f"Num tokens: {shape[1]}")
6563
print(f"Top-K: {shape[2]}")
6664

67-
# Verify dtype is int8 (for models with ≤127 experts) or int16
68-
if dtype_str not in ("int8", "int16"):
69-
print(f"\nERROR: Expected dtype int8 or int16, got {dtype_str}")
70-
print("This suggests dtype optimization is not working correctly.")
71-
return False
72-
print(f"\nDtype check PASSED: {dtype_str} (compact representation)")
73-
74-
# Compute payload size savings
75-
int32_size = np.prod(shape) * 4
76-
actual_size = len(data)
77-
savings = (1 - actual_size / int32_size) * 100
78-
print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)")
79-
80-
print(f"\nSample routing (first layer, first 5 tokens):")
8165
num_tokens_to_show = shape[1]
8266
for i in range(num_tokens_to_show):
8367
print(f" Token {i}: experts {routing_array[0, i, :].tolist()}")

0 commit comments

Comments
 (0)