@@ -1006,6 +1006,191 @@ pub struct WireTokenAgreementResult {
10061006
10071007fn default_ta_backend ( ) -> String { "stub" . to_string ( ) }
10081008
1009+ // ═══════════════════════════════════════════════════════════════════════════
1010+ // D0.3 — WireSweep: the streaming cross-product sweep surface (Phase 0 stub)
1011+ //
1012+ // Per .claude/plans/codec-sweep-via-lab-infra-v1.md § D0.3:
1013+ // client POSTs a single `WireSweepRequest` containing a `WireSweepGrid`
1014+ // (cross-product of codec-params axes). Server enumerates the grid,
1015+ // calls D0.1 (calibrate) + D0.2 (token-agreement) per grid point, streams
1016+ // each pair as a `WireSweepResult` via SSE / gRPC stream, and optionally
1017+ // appends each row to a Lance fragment.
1018+ //
1019+ // Phase 0: DTOs + grid enumerator + in-memory result vector land NOW.
1020+ // Phase 3 D3.1: SSE streaming handler + Lance fragment writer land.
1021+ //
1022+ // The enumerator is the load-bearing piece: it turns a sweep YAML file
1023+ // into N `WireCodecParams` candidates, each hashed to a JIT kernel via
1024+ // `CodecParams::kernel_signature()`. First hit compiles (~5–20 ms); all
1025+ // subsequent candidates with the same signature hit the cache. The grid
1026+ // is the input surface; kernel_signature is the cache key. That's the
1027+ // full operational loop.
1028+ // ═══════════════════════════════════════════════════════════════════════════
1029+
1030+ /// Which measurements the sweep should request per grid point.
1031+ /// Serialized as lowercase snake_case strings in the YAML / JSON request.
1032+ #[ derive( Debug , Clone , Copy , PartialEq , Eq , Serialize , Deserialize ) ]
1033+ #[ serde( rename_all = "snake_case" ) ]
1034+ pub enum WireMeasure {
1035+ /// Held-out reconstruction error (L2 relative).
1036+ ReconstructionErrorHeldOut ,
1037+ /// Held-out reconstruction ICC (Spearman rho against F32 ground truth).
1038+ ReconstructionIccHeldOut ,
1039+ /// Top-1 token agreement vs Passthrough baseline (I11 cert gate, D0.2).
1040+ TokenAgreementTop1 ,
1041+ /// Top-5 token agreement vs Passthrough baseline.
1042+ TokenAgreementTop5 ,
1043+ /// Per-layer MSE between candidate and reference hidden states.
1044+ PerLayerMse ,
1045+ }
1046+
1047+ /// The cross-product sweep grid. Each field is a vec; the enumerated grid
1048+ /// is the Cartesian product.
1049+ ///
1050+ /// Cardinality = |subspaces| × |centroids| × |residual_depths| × |rotations|
1051+ /// × |distances| × |lane_widths|. Clients SHOULD keep the product ≤ a few
1052+ /// hundred to fit in one JIT kernel cache warm-up round.
1053+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
1054+ pub struct WireSweepGrid {
1055+ #[ serde( default = "default_subspaces_axis" ) ]
1056+ pub subspaces : Vec < u32 > ,
1057+ #[ serde( default = "default_centroids_axis" ) ]
1058+ pub centroids : Vec < u32 > ,
1059+ #[ serde( default = "default_residual_depths_axis" ) ]
1060+ pub residual_depths : Vec < u8 > ,
1061+ #[ serde( default = "default_rotations_axis" ) ]
1062+ pub rotations : Vec < WireRotation > ,
1063+ #[ serde( default = "default_distances_axis" ) ]
1064+ pub distances : Vec < WireDistance > ,
1065+ #[ serde( default = "default_lane_widths_axis" ) ]
1066+ pub lane_widths : Vec < WireLaneWidth > ,
1067+ #[ serde( default = "default_residual_centroids" ) ]
1068+ pub residual_centroids : u32 ,
1069+ #[ serde( default = "default_calibration_rows" ) ]
1070+ pub calibration_rows : u32 ,
1071+ #[ serde( default ) ]
1072+ pub measurement_rows : u32 ,
1073+ #[ serde( default = "default_seed" ) ]
1074+ pub seed : u64 ,
1075+ }
1076+
1077+ fn default_subspaces_axis ( ) -> Vec < u32 > { vec ! [ 6 ] }
1078+ fn default_centroids_axis ( ) -> Vec < u32 > { vec ! [ 256 ] }
1079+ fn default_residual_depths_axis ( ) -> Vec < u8 > { vec ! [ 0 ] }
1080+ fn default_rotations_axis ( ) -> Vec < WireRotation > { vec ! [ WireRotation :: Identity ] }
1081+ fn default_distances_axis ( ) -> Vec < WireDistance > { vec ! [ WireDistance :: AdcU8 ] }
1082+ fn default_lane_widths_axis ( ) -> Vec < WireLaneWidth > { vec ! [ WireLaneWidth :: F32x16 ] }
1083+ fn default_residual_centroids ( ) -> u32 { 256 }
1084+
1085+ impl WireSweepGrid {
1086+ /// Product of all axis lengths.
1087+ pub fn cardinality ( & self ) -> usize {
1088+ self . subspaces . len ( )
1089+ * self . centroids . len ( )
1090+ * self . residual_depths . len ( )
1091+ * self . rotations . len ( )
1092+ * self . distances . len ( )
1093+ * self . lane_widths . len ( )
1094+ }
1095+
1096+ /// Materialize the full Cartesian product as a vec of `WireCodecParams`.
1097+ ///
1098+ /// Materialized (not lazy) because typical sweeps are ≤ ~200 candidates
1099+ /// and the client wants to validate the whole grid before streaming.
1100+ /// Phase 3 D3.1 streams the RESULTS; the grid itself is always upfront.
1101+ pub fn enumerate ( & self ) -> Vec < WireCodecParams > {
1102+ let mut out = Vec :: with_capacity ( self . cardinality ( ) ) ;
1103+ for & subs in & self . subspaces {
1104+ for & cent in & self . centroids {
1105+ for & depth in & self . residual_depths {
1106+ for rot in & self . rotations {
1107+ for & dist in & self . distances {
1108+ for & lw in & self . lane_widths {
1109+ out. push ( WireCodecParams {
1110+ subspaces : subs,
1111+ centroids : cent,
1112+ residual : WireResidualSpec {
1113+ depth,
1114+ centroids : self . residual_centroids ,
1115+ } ,
1116+ lane_width : lw,
1117+ pre_rotation : rot. clone ( ) ,
1118+ distance : dist,
1119+ calibration_rows : self . calibration_rows ,
1120+ measurement_rows : self . measurement_rows ,
1121+ seed : self . seed ,
1122+ } ) ;
1123+ }
1124+ }
1125+ }
1126+ }
1127+ }
1128+ }
1129+ out
1130+ }
1131+ }
1132+
1133+ /// `POST /v1/shader/sweep` request. Client submits one grid + a measure
1134+ /// set; server enumerates + calibrates + token-agreements each grid point.
1135+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
1136+ pub struct WireSweepRequest {
1137+ pub tensor_path : String ,
1138+ pub grid : WireSweepGrid ,
1139+ #[ serde( default = "default_measure_set" ) ]
1140+ pub measure : Vec < WireMeasure > ,
1141+ /// Optional Lance fragment path to append each result row to. None =
1142+ /// stream-only (no persistent log).
1143+ #[ serde( default ) ]
1144+ pub log_to_lance : Option < String > ,
1145+ /// Human-readable label for this sweep run (ends up in the Lance row
1146+ /// metadata + the SSE stream header).
1147+ #[ serde( default ) ]
1148+ pub label : String ,
1149+ }
1150+
1151+ fn default_measure_set ( ) -> Vec < WireMeasure > {
1152+ vec ! [
1153+ WireMeasure :: ReconstructionIccHeldOut ,
1154+ WireMeasure :: TokenAgreementTop1 ,
1155+ ]
1156+ }
1157+
1158+ /// One grid-point result, streamed by the sweep handler. Carries the
1159+ /// candidate that produced it + optional per-measure payloads.
1160+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
1161+ pub struct WireSweepResult {
1162+ /// Zero-based grid index (0 .. grid.cardinality()).
1163+ pub grid_index : u32 ,
1164+ /// The candidate codec params this row measured.
1165+ pub candidate : WireCodecParams ,
1166+ /// `CodecParams::kernel_signature()` of the kernel that was executed.
1167+ /// Multiple grid points may share a signature (JIT cache hit).
1168+ pub kernel_hash : u64 ,
1169+ /// Populated when the measure set contained Reconstruction* variants.
1170+ #[ serde( default ) ]
1171+ pub calibrate : Option < WireCalibrateResponse > ,
1172+ /// Populated when the measure set contained TokenAgreement* / PerLayerMse.
1173+ #[ serde( default ) ]
1174+ pub token_agreement : Option < WireTokenAgreementResult > ,
1175+ /// Phase 0 honesty flag — mirrors WireTokenAgreementResult.stub.
1176+ /// `true` means this row carries stub numbers, not real measurements.
1177+ #[ serde( default ) ]
1178+ pub stub : bool ,
1179+ }
1180+
1181+ /// `POST /v1/shader/sweep` response for batch (non-streaming) clients.
1182+ /// Streaming clients receive one `WireSweepResult` per SSE event instead.
1183+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
1184+ pub struct WireSweepResponse {
1185+ pub label : String ,
1186+ pub cardinality : u32 ,
1187+ pub results : Vec < WireSweepResult > ,
1188+ pub elapsed_ms : u64 ,
1189+ /// Path the results were appended to, if `log_to_lance` was set.
1190+ #[ serde( default ) ]
1191+ pub lance_fragment_path : Option < String > ,
1192+ }
1193+
10091194#[ cfg( test) ]
10101195mod tests {
10111196 use super :: * ;
@@ -1312,6 +1497,103 @@ mod tests {
13121497 assert_eq ! ( b, WireBaseline :: Passthrough ) ;
13131498 }
13141499
1500+ // ═════════════════════════════════════════════════════════════════════
1501+ // D0.3 — WireSweep tests (grid cardinality + enumerate + serde)
1502+ // ═════════════════════════════════════════════════════════════════════
1503+
1504+ #[ test]
1505+ fn sweep_grid_cardinality_is_product_of_axes ( ) {
1506+ let grid = WireSweepGrid {
1507+ subspaces : vec ! [ 6 ] ,
1508+ centroids : vec ! [ 256 , 512 , 1024 ] ,
1509+ residual_depths : vec ! [ 0 , 1 , 2 ] ,
1510+ rotations : vec ! [ WireRotation :: Identity , WireRotation :: Hadamard { dim: 4096 } ] ,
1511+ distances : vec ! [ WireDistance :: AdcU8 ] ,
1512+ lane_widths : vec ! [ WireLaneWidth :: F32x16 , WireLaneWidth :: BF16x32 ] ,
1513+ residual_centroids : 256 ,
1514+ calibration_rows : 2048 ,
1515+ measurement_rows : 0 ,
1516+ seed : 42 ,
1517+ } ;
1518+ // 1 × 3 × 3 × 2 × 1 × 2 = 36
1519+ assert_eq ! ( grid. cardinality( ) , 36 ) ;
1520+ let params = grid. enumerate ( ) ;
1521+ assert_eq ! ( params. len( ) , 36 ) ;
1522+ }
1523+
1524+ #[ test]
1525+ fn sweep_grid_enumerate_produces_all_unique_signatures ( ) {
1526+ let grid = WireSweepGrid {
1527+ subspaces : vec ! [ 6 ] ,
1528+ centroids : vec ! [ 256 , 1024 ] ,
1529+ residual_depths : vec ! [ 0 , 1 ] ,
1530+ rotations : vec ! [ WireRotation :: Identity ] ,
1531+ distances : vec ! [ WireDistance :: AdcU8 ] ,
1532+ lane_widths : vec ! [ WireLaneWidth :: F32x16 ] ,
1533+ residual_centroids : 256 ,
1534+ calibration_rows : 2048 ,
1535+ measurement_rows : 0 ,
1536+ seed : 42 ,
1537+ } ;
1538+ let params = grid. enumerate ( ) ;
1539+ let mut sigs: std:: collections:: HashSet < u64 > = std:: collections:: HashSet :: new ( ) ;
1540+ for p in & params {
1541+ let cp: CodecParams = p. clone ( ) . try_into ( ) . expect ( "valid codec params" ) ;
1542+ sigs. insert ( cp. kernel_signature ( ) ) ;
1543+ }
1544+ // 4 candidates, all with distinct (centroids × residual_depth) combos →
1545+ // 4 distinct kernel signatures (kernel_signature hashes IR-shaping fields)
1546+ assert_eq ! ( sigs. len( ) , 4 ) ;
1547+ }
1548+
1549+ #[ test]
1550+ fn sweep_grid_defaults_produce_single_candidate ( ) {
1551+ // All axes default to single-element vecs → cardinality 1.
1552+ let json = r#"{}"# ;
1553+ let grid: WireSweepGrid = serde_json:: from_str ( json) . unwrap ( ) ;
1554+ assert_eq ! ( grid. cardinality( ) , 1 ) ;
1555+ assert_eq ! ( grid. subspaces, vec![ 6 ] ) ;
1556+ assert_eq ! ( grid. centroids, vec![ 256 ] ) ;
1557+ }
1558+
1559+ #[ test]
1560+ fn sweep_request_round_trips_json ( ) {
1561+ let req = WireSweepRequest {
1562+ tensor_path : "models/qwen3-tts-0.6b/q_proj.safetensors" . to_string ( ) ,
1563+ grid : WireSweepGrid {
1564+ subspaces : vec ! [ 6 ] ,
1565+ centroids : vec ! [ 256 , 1024 ] ,
1566+ residual_depths : vec ! [ 0 ] ,
1567+ rotations : vec ! [ WireRotation :: Identity ] ,
1568+ distances : vec ! [ WireDistance :: AdcU8 ] ,
1569+ lane_widths : vec ! [ WireLaneWidth :: F32x16 ] ,
1570+ residual_centroids : 256 ,
1571+ calibration_rows : 2048 ,
1572+ measurement_rows : 0 ,
1573+ seed : 42 ,
1574+ } ,
1575+ measure : vec ! [
1576+ WireMeasure :: ReconstructionIccHeldOut ,
1577+ WireMeasure :: TokenAgreementTop1 ,
1578+ ] ,
1579+ log_to_lance : Some ( "logs/sweep_phase1.lance" . to_string ( ) ) ,
1580+ label : "phase1_initial_cross_product" . to_string ( ) ,
1581+ } ;
1582+ let json = serde_json:: to_string ( & req) . unwrap ( ) ;
1583+ let decoded: WireSweepRequest = serde_json:: from_str ( & json) . unwrap ( ) ;
1584+ assert_eq ! ( decoded. grid. cardinality( ) , 2 ) ;
1585+ assert_eq ! ( decoded. measure. len( ) , 2 ) ;
1586+ assert_eq ! ( decoded. label, "phase1_initial_cross_product" ) ;
1587+ }
1588+
1589+ #[ test]
1590+ fn sweep_measure_serializes_snake_case ( ) {
1591+ let m = WireMeasure :: ReconstructionIccHeldOut ;
1592+ assert_eq ! ( serde_json:: to_string( & m) . unwrap( ) , "\" reconstruction_icc_held_out\" " ) ;
1593+ let m: WireMeasure = serde_json:: from_str ( "\" token_agreement_top1\" " ) . unwrap ( ) ;
1594+ assert_eq ! ( m, WireMeasure :: TokenAgreementTop1 ) ;
1595+ }
1596+
13151597 #[ test]
13161598 fn wire_calibrate_request_back_compat_legacy_fields ( ) {
13171599 // Legacy payload (no `params`) still parses; defaults preserved.
0 commit comments