|
| 1 | +"""R3/v1 binary deserializer for router-replay payloads. |
| 2 | +
|
| 3 | +Implements the inverse of the packed binary format produced by the tracing |
| 4 | +gateway's ``r3_serializer.serialize_r3``. See that module for the full |
| 5 | +header specification. |
| 6 | +
|
| 7 | +The main entry point is :func:`decompress_and_parse_r3`, which accepts the |
| 8 | +base64-encoded compressed blob returned by the gateway's |
| 9 | +``/v1/traces/pointwise?include_payloads=true`` endpoint and produces |
| 10 | +per-token routing matrices in the same ``List[Optional[str]]`` format used |
| 11 | +by the direct inference path (``DeploymentSampler.sample_with_tokens()``). |
| 12 | +""" |
| 13 | + |
| 14 | +from __future__ import annotations |
| 15 | + |
| 16 | +import base64 |
| 17 | +import logging |
| 18 | +import math |
| 19 | +import struct |
| 20 | +from enum import IntEnum |
| 21 | +from typing import Any, Dict, List, Optional, Tuple |
| 22 | + |
| 23 | +import zstandard as zstd |
| 24 | + |
logger = logging.getLogger(__name__)

# Fixed-size little-endian header that prefixes every R3/v1 payload.
# Fields, in wire order:
#   4s  magic                 (must equal MAGIC)
#   B   version
#   B   selector_mode
#   B   routing_dtype
#   B   flags
#   I   total_token_count
#   I   replayed_token_count
#   H   num_moe_layers
#   H   top_k
#   I   replay_start_token
#   I   selector_byte_length
#   Q   matrix_byte_length
MAGIC = b"R3V1"
HEADER_FORMAT = "<4sBBBBIIHHIIQ"
HEADER_SIZE = struct.Struct(HEADER_FORMAT).size  # 36 bytes
| 30 | + |
| 31 | + |
| 32 | +class _SelectorMode(IntEnum): |
| 33 | + ALL = 0 |
| 34 | + SUFFIX = 1 |
| 35 | + BITMAP = 2 |
| 36 | + |
| 37 | + |
| 38 | +class _RoutingDtype(IntEnum): |
| 39 | + UINT8 = 1 |
| 40 | + UINT16 = 2 |
| 41 | + |
| 42 | + @property |
| 43 | + def byte_width(self) -> int: |
| 44 | + return self.value |
| 45 | + |
| 46 | + |
# Lower-cased member names, keyed by enum member, used as the human-readable
# values reported in the parsed metadata.
_SELECTOR_MODE_NAMES = {
    member: member.name.lower() for member in _SelectorMode
}
_ROUTING_DTYPE_NAMES = {
    member: member.name.lower() for member in _RoutingDtype
}
| 49 | + |
| 50 | + |
def _parse_header(raw: bytes) -> Dict[str, Any]:
    """Decode and validate the fixed-size R3/v1 header at the start of *raw*.

    Args:
        raw: Decompressed payload bytes; only the first ``HEADER_SIZE`` bytes
            are examined.

    Returns:
        A dict holding every header field except the magic and version,
        keyed by field name, in wire order.

    Raises:
        ValueError: If *raw* is shorter than the header, the magic bytes do
            not match ``MAGIC``, or the header version is not 1.
    """
    if len(raw) < HEADER_SIZE:
        raise ValueError(
            f"Payload too short for r3/v1 header: {len(raw)} < {HEADER_SIZE}"
        )

    # The length check above guarantees the buffer holds a full header.
    magic, version, *payload_fields = struct.unpack_from(HEADER_FORMAT, raw)

    if magic != MAGIC:
        raise ValueError(f"Bad R3 magic: {magic!r}")
    if version != 1:
        raise ValueError(f"Unsupported R3 header version: {version}")

    # Names of the fields that follow magic/version, in wire order.
    field_names = (
        "selector_mode",
        "routing_dtype",
        "flags",
        "total_token_count",
        "replayed_token_count",
        "num_moe_layers",
        "top_k",
        "replay_start_token",
        "selector_byte_length",
        "matrix_byte_length",
    )
    return dict(zip(field_names, payload_fields))
| 89 | + |
| 90 | + |
| 91 | +def _read_bitmap_positions( |
| 92 | + selector_bytes: bytes, total_token_count: int |
| 93 | +) -> List[int]: |
| 94 | + """Return sorted token indices where the bitmap bit is set.""" |
| 95 | + positions: List[int] = [] |
| 96 | + for i in range(total_token_count): |
| 97 | + byte_idx = i >> 3 |
| 98 | + bit_idx = i & 7 |
| 99 | + if byte_idx < len(selector_bytes) and (selector_bytes[byte_idx] >> bit_idx) & 1: |
| 100 | + positions.append(i) |
| 101 | + return positions |
| 102 | + |
| 103 | + |
def decompress_and_parse_r3(
    data_b64: str,
) -> Tuple[List[Optional[str]], Dict[str, Any]]:
    """Decompress and unpack an R3/v1 payload into per-token routing matrices.

    Args:
        data_b64: Base64-encoded zstd-compressed R3 binary blob, as returned
            by the tracing gateway in ``payloads.router_replay.data``.

    Returns:
        A tuple of ``(routing_matrices, metadata)`` where:

        - ``routing_matrices`` is a ``List[Optional[str]]`` of length
          ``total_token_count``. Each present position contains a
          base64-encoded routing matrix (matching the format returned by
          the direct inference path); absent positions are ``None``.
        - ``metadata`` is a dict with keys ``num_moe_layers``, ``top_k``,
          ``routing_dtype``, ``selector_mode``, ``total_token_count``,
          ``replayed_token_count``, ``replay_start_token``.

    Raises:
        ValueError: If the header is malformed (see ``_parse_header``), the
            ``routing_dtype`` or ``selector_mode`` code is unknown, or a
            suffix selector extends past ``total_token_count``.

    If the matrix section is shorter than the selected positions require, a
    warning is logged and the remaining positions are left as ``None``.
    """
    compressed = base64.b64decode(data_b64)

    # NOTE(review): the output cap is a 20x heuristic. Per the python-zstandard
    # documentation, max_output_size appears to matter only when the frame does
    # not embed its content size; confirm the gateway always embeds it, or that
    # payloads compressing better than 20:1 cannot occur.
    decompressor = zstd.ZstdDecompressor()
    raw = decompressor.decompress(compressed, max_output_size=len(compressed) * 20)

    header = _parse_header(raw)

    selector_mode = header["selector_mode"]
    routing_dtype = header["routing_dtype"]
    total_token_count = header["total_token_count"]
    replayed_token_count = header["replayed_token_count"]
    num_moe_layers = header["num_moe_layers"]
    top_k = header["top_k"]
    replay_start_token = header["replay_start_token"]
    selector_byte_length = header["selector_byte_length"]
    matrix_byte_length = header["matrix_byte_length"]

    # Validate both enum-coded fields once, up front, so an unknown code always
    # produces the same explicit error regardless of which branch runs below.
    try:
        dtype = _RoutingDtype(routing_dtype)
    except ValueError:
        raise ValueError(f"Unknown routing_dtype: {routing_dtype}") from None
    try:
        mode = _SelectorMode(selector_mode)
    except ValueError:
        raise ValueError(f"Unknown selector_mode: {selector_mode}") from None

    # Size in bytes of one token's routing matrix.
    matrix_elem_size = num_moe_layers * top_k * dtype.byte_width

    body = raw[HEADER_SIZE:]
    selector_bytes = body[:selector_byte_length]
    matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length]

    replayed_positions: List[int]
    if matrix_elem_size == 0:
        # Zero-sized matrices carry no data, so no position is populated.
        replayed_positions = []
    elif mode is _SelectorMode.ALL:
        replayed_positions = list(range(total_token_count))
    elif mode is _SelectorMode.SUFFIX:
        suffix_end = replay_start_token + replayed_token_count
        # Reject an inconsistent header here rather than failing later with a
        # bare IndexError when writing past the end of the result list.
        if suffix_end > total_token_count:
            raise ValueError(
                f"R3 suffix selector out of range: tokens "
                f"[{replay_start_token}, {suffix_end}) exceed "
                f"total_token_count {total_token_count}"
            )
        replayed_positions = list(range(replay_start_token, suffix_end))
    else:
        # Only BITMAP remains; positions it yields are < total_token_count.
        replayed_positions = _read_bitmap_positions(selector_bytes, total_token_count)

    # Split matrix bytes into per-token chunks and base64-encode each one
    matrices: List[Optional[str]] = [None] * total_token_count
    for idx, pos in enumerate(replayed_positions):
        start = idx * matrix_elem_size
        end = start + matrix_elem_size
        if end > len(matrix_bytes):
            # Best effort: keep what was decoded so far and stop.
            logger.warning(
                "R3 matrix data truncated at token %d (position %d): "
                "expected %d bytes but only %d remaining",
                idx, pos, matrix_elem_size, len(matrix_bytes) - start,
            )
            break
        matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii")

    metadata: Dict[str, Any] = {
        "num_moe_layers": num_moe_layers,
        "top_k": top_k,
        "routing_dtype": _ROUTING_DTYPE_NAMES[dtype],
        "selector_mode": _SELECTOR_MODE_NAMES[mode],
        "total_token_count": total_token_count,
        "replayed_token_count": replayed_token_count,
        "replay_start_token": replay_start_token,
    }

    return matrices, metadata
0 commit comments