|
| 1 | +"""R3/v1 binary deserializer for router-replay payloads. |
| 2 | +
|
| 3 | +Implements the inverse of the packed binary format produced by the tracing |
| 4 | +gateway's ``r3_serializer.serialize_r3``. See that module for the full |
| 5 | +header specification. |
| 6 | +
|
| 7 | +The main entry point is :func:`decompress_and_parse_r3`, which accepts the |
| 8 | +base64-encoded compressed blob returned by the gateway's |
| 9 | +``/v1/traces/pointwise?include_payloads=true`` endpoint and produces |
| 10 | +per-token routing matrices in the same ``List[Optional[str]]`` format used |
| 11 | +by the direct inference path (``DeploymentSampler.sample_with_tokens()``). |
| 12 | +""" |
| 13 | + |
| 14 | +from __future__ import annotations |
| 15 | + |
| 16 | +import base64 |
| 17 | +import struct |
| 18 | +from enum import IntEnum |
| 19 | +from typing import Any, Dict, List, Optional, Tuple |
| 20 | + |
| 21 | +import zstandard as zstd |
| 22 | + |
| 23 | +MAGIC = b"R3V1" |
| 24 | +HEADER_FORMAT = "<4sBBBBIIIIQ" |
| 25 | +HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes |
| 26 | +BITS_PER_BYTE = 8 |
| 27 | + |
| 28 | + |
| 29 | +class _SelectorMode(IntEnum): |
| 30 | + ALL = 0 |
| 31 | + SUFFIX = 1 |
| 32 | + BITMAP = 2 |
| 33 | + |
| 34 | + |
| 35 | +class _RoutingDtype(IntEnum): |
| 36 | + UINT8 = 1 |
| 37 | + UINT16 = 2 |
| 38 | + |
| 39 | + |
| 40 | +_SELECTOR_MODE_NAMES = {v: v.name.lower() for v in _SelectorMode} |
| 41 | +_ROUTING_DTYPE_NAMES = {v: v.name.lower() for v in _RoutingDtype} |
| 42 | + |
| 43 | + |
| 44 | +def _parse_header(raw: bytes) -> Dict[str, Any]: |
| 45 | + if len(raw) < HEADER_SIZE: |
| 46 | + raise ValueError( |
| 47 | + f"Payload too short for r3/v1 header: {len(raw)} < {HEADER_SIZE}" |
| 48 | + ) |
| 49 | + |
| 50 | + ( |
| 51 | + magic, |
| 52 | + version, |
| 53 | + selector_mode, |
| 54 | + routing_dtype, |
| 55 | + flags, |
| 56 | + total_token_count, |
| 57 | + replayed_token_count, |
| 58 | + replay_start_token, |
| 59 | + selector_byte_length, |
| 60 | + matrix_byte_length, |
| 61 | + ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) |
| 62 | + |
| 63 | + if magic != MAGIC: |
| 64 | + raise ValueError(f"Bad R3 magic: {magic!r}") |
| 65 | + if version != 1: |
| 66 | + raise ValueError(f"Unsupported R3 header version: {version}") |
| 67 | + |
| 68 | + return { |
| 69 | + "selector_mode": selector_mode, |
| 70 | + "routing_dtype": routing_dtype, |
| 71 | + "flags": flags, |
| 72 | + "total_token_count": total_token_count, |
| 73 | + "replayed_token_count": replayed_token_count, |
| 74 | + "replay_start_token": replay_start_token, |
| 75 | + "selector_byte_length": selector_byte_length, |
| 76 | + "matrix_byte_length": matrix_byte_length, |
| 77 | + } |
| 78 | + |
| 79 | + |
| 80 | +def _read_bitmap_positions( |
| 81 | + selector_bytes: bytes, total_token_count: int |
| 82 | +) -> List[int]: |
| 83 | + """Return sorted token indices where the bitmap bit is set.""" |
| 84 | + positions: List[int] = [] |
| 85 | + for i in range(total_token_count): |
| 86 | + byte_idx = i // BITS_PER_BYTE |
| 87 | + bit_idx = i % BITS_PER_BYTE |
| 88 | + if byte_idx < len(selector_bytes) and (selector_bytes[byte_idx] >> bit_idx) & 1: |
| 89 | + positions.append(i) |
| 90 | + return positions |
| 91 | + |
| 92 | + |
| 93 | +def decompress_and_parse_r3( |
| 94 | + data_b64: str, |
| 95 | +) -> Tuple[List[Optional[str]], Dict[str, Any]]: |
| 96 | + """Decompress and unpack an R3/v1 payload into per-token routing matrices. |
| 97 | +
|
| 98 | + Args: |
| 99 | + data_b64: Base64-encoded zstd-compressed R3 binary blob, as returned |
| 100 | + by the tracing gateway in ``payloads.router_replay.data``. |
| 101 | +
|
| 102 | + Returns: |
| 103 | + A tuple of ``(routing_matrices, metadata)`` where: |
| 104 | +
|
| 105 | + - ``routing_matrices`` is a ``List[Optional[str]]`` of length |
| 106 | + ``total_token_count``. Each present position contains a |
| 107 | + base64-encoded routing matrix (matching the format returned by |
| 108 | + the direct inference path); absent positions are ``None``. |
| 109 | + - ``metadata`` is a dict with keys ``routing_dtype``, |
| 110 | + ``selector_mode``, ``total_token_count``, ``replayed_token_count``, |
| 111 | + ``replay_start_token``. |
| 112 | + """ |
| 113 | + compressed = base64.b64decode(data_b64) |
| 114 | + |
| 115 | + # ZstdCompressor.compress() embeds the uncompressed size in the frame |
| 116 | + # header by default, so the library can auto-allocate the output buffer. |
| 117 | + decompressor = zstd.ZstdDecompressor() |
| 118 | + raw = decompressor.decompress(compressed) |
| 119 | + |
| 120 | + header = _parse_header(raw) |
| 121 | + |
| 122 | + selector_mode = header["selector_mode"] |
| 123 | + routing_dtype = header["routing_dtype"] |
| 124 | + total_token_count = header["total_token_count"] |
| 125 | + replayed_token_count = header["replayed_token_count"] |
| 126 | + replay_start_token = header["replay_start_token"] |
| 127 | + selector_byte_length = header["selector_byte_length"] |
| 128 | + matrix_byte_length = header["matrix_byte_length"] |
| 129 | + |
| 130 | + metadata: Dict[str, Any] = { |
| 131 | + "routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)), |
| 132 | + "selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)), |
| 133 | + "total_token_count": total_token_count, |
| 134 | + "replayed_token_count": replayed_token_count, |
| 135 | + "replay_start_token": replay_start_token, |
| 136 | + } |
| 137 | + |
| 138 | + if replayed_token_count == 0: |
| 139 | + return [None] * total_token_count, metadata |
| 140 | + |
| 141 | + # Per-token matrix byte size is implicit in the payload: all replayed |
| 142 | + # tokens share the same matrix length, so we can recover it from the |
| 143 | + # matrix section total length divided by the replayed-token count. |
| 144 | + if matrix_byte_length % replayed_token_count != 0: |
| 145 | + raise ValueError( |
| 146 | + f"matrix_byte_length ({matrix_byte_length}) is not a multiple of " |
| 147 | + f"replayed_token_count ({replayed_token_count}); cannot split " |
| 148 | + "into per-token matrices" |
| 149 | + ) |
| 150 | + matrix_elem_size = matrix_byte_length // replayed_token_count |
| 151 | + |
| 152 | + body = raw[HEADER_SIZE:] |
| 153 | + expected_body_length = selector_byte_length + matrix_byte_length |
| 154 | + if len(body) < expected_body_length: |
| 155 | + raise ValueError( |
| 156 | + f"Payload body too short for selector and matrix sections: " |
| 157 | + f"{len(body)} < {expected_body_length}" |
| 158 | + ) |
| 159 | + |
| 160 | + selector_bytes = body[:selector_byte_length] |
| 161 | + matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length] |
| 162 | + |
| 163 | + if selector_mode == _SelectorMode.ALL: |
| 164 | + replayed_positions = list(range(total_token_count)) |
| 165 | + elif selector_mode == _SelectorMode.SUFFIX: |
| 166 | + replayed_positions = list( |
| 167 | + range(replay_start_token, replay_start_token + replayed_token_count) |
| 168 | + ) |
| 169 | + elif selector_mode == _SelectorMode.BITMAP: |
| 170 | + replayed_positions = _read_bitmap_positions(selector_bytes, total_token_count) |
| 171 | + else: |
| 172 | + raise ValueError(f"Unknown selector_mode: {selector_mode}") |
| 173 | + |
| 174 | + if len(replayed_positions) != replayed_token_count: |
| 175 | + raise ValueError( |
| 176 | + f"Selector produced {len(replayed_positions)} replayed positions, " |
| 177 | + f"but header replayed_token_count is {replayed_token_count}" |
| 178 | + ) |
| 179 | + |
| 180 | + # Split matrix bytes into per-token chunks and base64-encode each one |
| 181 | + matrices: List[Optional[str]] = [None] * total_token_count |
| 182 | + for idx, pos in enumerate(replayed_positions): |
| 183 | + start = idx * matrix_elem_size |
| 184 | + end = start + matrix_elem_size |
| 185 | + matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii") |
| 186 | + |
| 187 | + return matrices, metadata |
0 commit comments