|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
| 18 | +// Regression for #63913: the ANN index writer must train the FAISS quantizer |
| 19 | +// EXACTLY ONCE per index build, not once per buffered chunk. |
| 20 | +// |
| 21 | +// Background: AnnIndexColumnWriter buffers vectors and flushes them one chunk at |
| 22 | +// a time (ann_index_build_chunk_size). The buggy code called train() on every |
| 23 | +// chunk. For a PQ (product-quantization) index, train() re-fits the codebook on |
| 24 | +// the latest chunk, but vectors from earlier chunks were already add()ed and |
| 25 | +// encoded under the previous codebook. After the final chunk re-trains, those |
| 26 | +// earlier codes no longer match the stored codebook, so they decode to garbage |
| 27 | +// distances at query time -> recall collapses on any segment that spans more |
| 28 | +// than one chunk. |
| 29 | +// |
| 30 | +// This test shrinks ann_index_build_chunk_size so a single 20k-row segment spans |
| 31 | +// 10 chunks, builds an IVF+PQ index, and asserts recall@10 (vs exact brute-force |
| 32 | +// l2_distance) stays high. On a buggy BE this recall drops to ~0.1 and the test |
| 33 | +// fails; on the fixed BE it stays high. An IVF+FLAT table loaded from the same |
| 34 | +// data is used as a positive control (FLAT has no codebook, so it is unaffected |
| 35 | +// and must reach near-exact recall) -- this proves the harness can achieve high |
| 36 | +// recall, so a low PQ recall is specifically the train-reentry bug. |
| 37 | +// |
| 38 | +// nonConcurrent: it temporarily changes a global BE config. |
| 39 | +suite("ann_ivf_pq_train_once_recall", "nonConcurrent") { |
| 40 | + def dim = 32 |
| 41 | + def nRows = 20000 |
| 42 | + def chunkSize = 2000 // 20000 / 2000 = 10 chunks per segment -> bug triggers hard |
| 43 | + def nlist = 64 |
| 44 | + def topk = 10 |
| 45 | + def nQueries = 30 |
| 46 | + def rnd = new Random(42) // fixed seed -> reproducible |
| 47 | + |
| 48 | + // -- generate i.i.d. gaussian base vectors as a stream-load CSV (id|[v0,...]) -- |
| 49 | + // The vector itself contains commas, so use '|' as the column separator. |
| 50 | + def sb = new StringBuilder() |
| 51 | + for (int i = 0; i < nRows; i++) { |
| 52 | + sb.append(i).append('|').append('[') |
| 53 | + for (int d = 0; d < dim; d++) { |
| 54 | + float v = (float) rnd.nextGaussian() |
| 55 | + if (d > 0) sb.append(',') |
| 56 | + sb.append(String.format(Locale.US, '%.6f', v)) // Locale.US: never a comma decimal |
| 57 | + } |
| 58 | + sb.append(']').append('\n') |
| 59 | + } |
| 60 | + def csv = sb.toString() |
| 61 | + |
| 62 | + // -- query vectors (independent random) -- |
| 63 | + def queries = new float[nQueries][dim] |
| 64 | + for (int q = 0; q < nQueries; q++) { |
| 65 | + for (int d = 0; d < dim; d++) { |
| 66 | + queries[q][d] = (float) rnd.nextGaussian() |
| 67 | + } |
| 68 | + } |
| 69 | + |
| 70 | + def vecLiteral = { float[] v -> |
| 71 | + def s = new StringBuilder('[') |
| 72 | + for (int d = 0; d < v.length; d++) { |
| 73 | + if (d > 0) s.append(',') |
| 74 | + s.append(String.format(Locale.US, '%.6f', v[d])) |
| 75 | + } |
| 76 | + s.append(']') |
| 77 | + return s.toString() |
| 78 | + } |
| 79 | + |
| 80 | + def idsOf = { String q -> |
| 81 | + def rows = sql q |
| 82 | + return rows.collect { (it[0] as long) } as Set |
| 83 | + } |
| 84 | + |
| 85 | + // recall@topk averaged over all queries: approx (uses index) vs exact (brute force) |
| 86 | + def measureRecall = { String table -> |
| 87 | + double total = 0.0d |
| 88 | + for (int q = 0; q < nQueries; q++) { |
| 89 | + def lit = vecLiteral(queries[q]) |
| 90 | + def approx = idsOf("select id from ${table} order by l2_distance_approximate(vec, ${lit}) limit ${topk}".toString()) |
| 91 | + def exact = idsOf("select id from ${table} order by l2_distance(vec, ${lit}) limit ${topk}".toString()) |
| 92 | + total += (approx.intersect(exact).size() / (double) topk) |
| 93 | + } |
| 94 | + return total / nQueries |
| 95 | + } |
| 96 | + |
| 97 | + def loadCsv = { String table -> |
| 98 | + streamLoad { |
| 99 | + table "${table}" |
| 100 | + set 'column_separator', '|' |
| 101 | + set 'columns', 'id, vec' |
| 102 | + inputStream new ByteArrayInputStream(csv.getBytes("UTF-8")) |
| 103 | + time 120000 |
| 104 | + check { result, exception, startTime, endTime -> |
| 105 | + if (exception != null) { |
| 106 | + throw exception |
| 107 | + } |
| 108 | + def json = parseJson(result) |
| 109 | + assertEquals("success", json.Status.toLowerCase()) |
| 110 | + assertEquals(nRows, json.NumberLoadedRows) |
| 111 | + assertEquals(0, json.NumberFilteredRows) |
| 112 | + } |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + sql "set enable_common_expr_pushdown = true" |
| 117 | + sql "set enable_ann_index_result_cache = false" // avoid cache masking real index behavior |
| 118 | + // Scan all lists so IVF coarse-quantization adds no approximation: this isolates |
| 119 | + // PQ-codebook correctness, which is what the bug breaks. |
| 120 | + sql "set ivf_nprobe = ${nlist}" |
| 121 | + |
| 122 | + setBeConfigTemporary([ann_index_build_chunk_size: chunkSize]) { |
| 123 | + // ================= IVF + PQ : the path the bug corrupts ================= |
| 124 | + sql "drop table if exists ann_pq_train_once" |
| 125 | + sql """ |
| 126 | + create table ann_pq_train_once ( |
| 127 | + id int not null, |
| 128 | + vec array<float> not null, |
| 129 | + index ann_idx (vec) using ann properties ( |
| 130 | + 'index_type' = 'ivf', |
| 131 | + 'metric_type'= 'l2_distance', |
| 132 | + 'dim' = '${dim}', |
| 133 | + 'nlist' = '${nlist}', |
| 134 | + 'quantizer' = 'pq', |
| 135 | + 'pq_m' = '16', |
| 136 | + 'pq_nbits' = '8') |
| 137 | + ) engine=olap |
| 138 | + duplicate key(id) |
| 139 | + distributed by hash(id) buckets 1 |
| 140 | + properties ('replication_num' = '1'); |
| 141 | + """ |
| 142 | + loadCsv("ann_pq_train_once") |
| 143 | + |
| 144 | + // Guard: the approximate query MUST be pushed into the ANN index. Otherwise |
| 145 | + // it degenerates to exact distances, recall would be a trivial 1.0, and the |
| 146 | + // test would silently stop guarding the bug. |
| 147 | + explain { |
| 148 | + sql "select id from ann_pq_train_once order by l2_distance_approximate(vec, ${vecLiteral(queries[0])}) limit ${topk}".toString() |
| 149 | + contains "ANN SORT INFO" |
| 150 | + } |
| 151 | + |
| 152 | + double pqRecall = measureRecall("ann_pq_train_once") |
| 153 | + logger.info("[#63913] IVF+PQ multi-chunk recall@${topk} = ${pqRecall} (chunks per segment = ${nRows / chunkSize})") |
| 154 | + // Fixed build: typically ~0.8-0.95. Buggy build (per-chunk retrain): ~0.1. |
| 155 | + // Threshold 0.5 sits in the wide gap between them. |
| 156 | + assertTrue(pqRecall >= 0.5d, |
| 157 | + ("IVF+PQ recall@${topk} = ${pqRecall} is too low. The PQ codebook was likely " + |
| 158 | + "re-trained on every chunk so earlier chunks decode against the wrong codebook " + |
| 159 | + "(regression of #63913 'train ANN index once'). Expected >= 0.5.").toString()) |
| 160 | + |
| 161 | + // ============ IVF + FLAT : positive control, must be unaffected ============ |
| 162 | + sql "drop table if exists ann_flat_control" |
| 163 | + sql """ |
| 164 | + create table ann_flat_control ( |
| 165 | + id int not null, |
| 166 | + vec array<float> not null, |
| 167 | + index ann_idx (vec) using ann properties ( |
| 168 | + 'index_type' = 'ivf', |
| 169 | + 'metric_type'= 'l2_distance', |
| 170 | + 'dim' = '${dim}', |
| 171 | + 'nlist' = '${nlist}', |
| 172 | + 'quantizer' = 'flat') |
| 173 | + ) engine=olap |
| 174 | + duplicate key(id) |
| 175 | + distributed by hash(id) buckets 1 |
| 176 | + properties ('replication_num' = '1'); |
| 177 | + """ |
| 178 | + loadCsv("ann_flat_control") |
| 179 | + |
| 180 | + double flatRecall = measureRecall("ann_flat_control") |
| 181 | + logger.info("[#63913] IVF+FLAT control recall@${topk} = ${flatRecall}") |
| 182 | + // FLAT has no codebook; with nprobe == nlist it is exact. If this is low, the |
| 183 | + // problem is the environment/harness, not the bug under test. |
| 184 | + assertTrue(flatRecall >= 0.95d, |
| 185 | + ("IVF+FLAT control recall@${topk} = ${flatRecall} is unexpectedly low; this points " + |
| 186 | + "to an environment/harness issue rather than the train-once bug.").toString()) |
| 187 | + } |
| 188 | + |
| 189 | + sql "drop table if exists ann_pq_train_once" |
| 190 | + sql "drop table if exists ann_flat_control" |
| 191 | +} |
0 commit comments