@@ -5,35 +5,37 @@ use serde_json::{json, Value};
55
66#[ derive( Clone , Debug , PartialEq , Eq ) ]
77struct RoutedExpertsPayload {
8+ start : usize ,
89 seq_len : usize ,
910 layers : usize ,
1011 topk : usize ,
1112 data : Vec < u8 > ,
1213}
1314
1415impl RoutedExpertsPayload {
15- fn suffix_rows ( & self , row_start : usize ) -> Result < Self , String > {
16- if row_start > self . seq_len {
16+ fn suffix_rows ( & self , row_count : usize ) -> Result < Self , String > {
17+ if row_count > self . seq_len {
1718 return Err ( format ! (
18- "decode routed_experts has {} rows, expected at least {row_start }" ,
19+ "decode routed_experts has {} rows, expected at least {row_count }" ,
1920 self . seq_len
2021 ) ) ;
2122 }
2223 let row_size = self . layers * self . topk ;
23- let byte_start = row_start * row_size;
24+ let byte_start = row_count * row_size;
2425 let data = self
2526 . data
2627 . get ( byte_start..)
2728 . ok_or_else ( || {
2829 format ! (
29- "decode routed_experts has {} rows, expected at least {row_start }" ,
30+ "decode routed_experts has {} rows, expected at least {row_count }" ,
3031 self . seq_len
3132 )
3233 } ) ?
3334 . to_vec ( ) ;
3435
3536 Ok ( Self {
36- seq_len : self . seq_len - row_start,
37+ start : self . start + row_count,
38+ seq_len : self . seq_len - row_count,
3739 layers : self . layers ,
3840 topk : self . topk ,
3941 data,
@@ -47,12 +49,12 @@ impl RoutedExpertsPayload {
4749 self . seq_len, self . layers, self . topk, other. seq_len, other. layers, other. topk,
4850 ) ) ;
4951 }
50-
5152 let mut data = Vec :: with_capacity ( self . data . len ( ) + other. data . len ( ) ) ;
5253 data. extend_from_slice ( & self . data ) ;
5354 data. extend_from_slice ( & other. data ) ;
5455
5556 Ok ( Self {
57+ start : self . start ,
5658 seq_len : self . seq_len + other. seq_len ,
5759 layers : self . layers ,
5860 topk : self . topk ,
@@ -129,6 +131,12 @@ fn decode_routed_experts_value(value: &Value, name: &str) -> Result<RoutedExpert
129131 . get ( "data" )
130132 . and_then ( Value :: as_str)
131133 . ok_or_else ( || format ! ( "{name} data must be a base64 string" ) ) ?;
134+ let start = payload
135+ . get ( "start" )
136+ . and_then ( Value :: as_u64)
137+ . ok_or_else ( || format ! ( "{name} start must be a non-negative integer" ) ) ?;
138+ let start =
139+ usize:: try_from ( start) . map_err ( |error| format ! ( "{name} start parse failed: {error}" ) ) ?;
132140 let ( seq_len, layers, topk) = parse_shape ( payload. get ( "shape" ) , name) ?;
133141 let bytes = STANDARD
134142 . decode ( data_payload)
@@ -145,6 +153,7 @@ fn decode_routed_experts_value(value: &Value, name: &str) -> Result<RoutedExpert
145153 }
146154
147155 Ok ( RoutedExpertsPayload {
156+ start,
148157 seq_len,
149158 layers,
150159 topk,
@@ -177,6 +186,7 @@ fn encode_routed_experts_payload(payload: &RoutedExpertsPayload) -> Value {
177186 json ! ( {
178187 "data" : STANDARD . encode( & payload. data) ,
179188 "shape" : [ payload. seq_len, payload. layers, payload. topk] ,
189+ "start" : payload. start,
180190 } )
181191}
182192
@@ -187,6 +197,7 @@ mod tests {
187197
188198 fn uint8_payload ( seq_len : usize , layers : usize , topk : usize , data : & [ u8 ] ) -> Value {
189199 let payload = RoutedExpertsPayload {
200+ start : 0 ,
190201 seq_len,
191202 layers,
192203 topk,
0 commit comments