Skip to content

Commit 4249247

Browse files
authored
Merge pull request #37 from PrimeIntellect-ai/feat/routed-experts-delta-replay
Merge routed experts deltas with start offsets
2 parents 8f6524e + af649b5 commit 4249247

1 file changed

Lines changed: 18 additions & 7 deletions

File tree

src/routers/http/routed_experts_merge.rs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,35 +5,37 @@ use serde_json::{json, Value};
55

66
#[derive(Clone, Debug, PartialEq, Eq)]
77
struct RoutedExpertsPayload {
8+
start: usize,
89
seq_len: usize,
910
layers: usize,
1011
topk: usize,
1112
data: Vec<u8>,
1213
}
1314

1415
impl RoutedExpertsPayload {
15-
fn suffix_rows(&self, row_start: usize) -> Result<Self, String> {
16-
if row_start > self.seq_len {
16+
fn suffix_rows(&self, row_count: usize) -> Result<Self, String> {
17+
if row_count > self.seq_len {
1718
return Err(format!(
18-
"decode routed_experts has {} rows, expected at least {row_start}",
19+
"decode routed_experts has {} rows, expected at least {row_count}",
1920
self.seq_len
2021
));
2122
}
2223
let row_size = self.layers * self.topk;
23-
let byte_start = row_start * row_size;
24+
let byte_start = row_count * row_size;
2425
let data = self
2526
.data
2627
.get(byte_start..)
2728
.ok_or_else(|| {
2829
format!(
29-
"decode routed_experts has {} rows, expected at least {row_start}",
30+
"decode routed_experts has {} rows, expected at least {row_count}",
3031
self.seq_len
3132
)
3233
})?
3334
.to_vec();
3435

3536
Ok(Self {
36-
seq_len: self.seq_len - row_start,
37+
start: self.start + row_count,
38+
seq_len: self.seq_len - row_count,
3739
layers: self.layers,
3840
topk: self.topk,
3941
data,
@@ -47,12 +49,12 @@ impl RoutedExpertsPayload {
4749
self.seq_len, self.layers, self.topk, other.seq_len, other.layers, other.topk,
4850
));
4951
}
50-
5152
let mut data = Vec::with_capacity(self.data.len() + other.data.len());
5253
data.extend_from_slice(&self.data);
5354
data.extend_from_slice(&other.data);
5455

5556
Ok(Self {
57+
start: self.start,
5658
seq_len: self.seq_len + other.seq_len,
5759
layers: self.layers,
5860
topk: self.topk,
@@ -129,6 +131,12 @@ fn decode_routed_experts_value(value: &Value, name: &str) -> Result<RoutedExpert
129131
.get("data")
130132
.and_then(Value::as_str)
131133
.ok_or_else(|| format!("{name} data must be a base64 string"))?;
134+
let start = payload
135+
.get("start")
136+
.and_then(Value::as_u64)
137+
.ok_or_else(|| format!("{name} start must be a non-negative integer"))?;
138+
let start =
139+
usize::try_from(start).map_err(|error| format!("{name} start parse failed: {error}"))?;
132140
let (seq_len, layers, topk) = parse_shape(payload.get("shape"), name)?;
133141
let bytes = STANDARD
134142
.decode(data_payload)
@@ -145,6 +153,7 @@ fn decode_routed_experts_value(value: &Value, name: &str) -> Result<RoutedExpert
145153
}
146154

147155
Ok(RoutedExpertsPayload {
156+
start,
148157
seq_len,
149158
layers,
150159
topk,
@@ -177,6 +186,7 @@ fn encode_routed_experts_payload(payload: &RoutedExpertsPayload) -> Value {
177186
json!({
178187
"data": STANDARD.encode(&payload.data),
179188
"shape": [payload.seq_len, payload.layers, payload.topk],
189+
"start": payload.start,
180190
})
181191
}
182192

@@ -187,6 +197,7 @@ mod tests {
187197

188198
fn uint8_payload(seq_len: usize, layers: usize, topk: usize, data: &[u8]) -> Value {
189199
let payload = RoutedExpertsPayload {
200+
start: 0,
190201
seq_len,
191202
layers,
192203
topk,

0 commit comments

Comments
 (0)