diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 8cdfb9e4f489..1b1b42ed2987 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -259,6 +259,13 @@ New Features Improvements --------------------- +* GITHUB#15948: Improve BayesianScoreQuery and LogOddsFusionQuery with base rate prior, + weighted Logarithmic Opinion Pooling, and auto parameter estimation. Add + BayesianScoreEstimator for estimating sigmoid calibration parameters from corpus + statistics. Add base rate prior to BayesianScoreQuery for log-odds space shifting. + Add per-signal weights and logit normalization to LogOddsFusionQuery. + (Jaepil Jeong) + * GITHUB#15823: Implement method to add all stream elements into a PriorityQueue. Call PriorityQueue#addAll with mapped stream in DisjunctionMaxBulkScorer's constructor. (Zhou Hui) diff --git a/lucene/core/src/java/org/apache/lucene/search/BayesianScoreEstimator.java b/lucene/core/src/java/org/apache/lucene/search/BayesianScoreEstimator.java new file mode 100644 index 000000000000..29981f9306a7 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/BayesianScoreEstimator.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.StoredFields; +import org.apache.lucene.index.Term; +import org.apache.lucene.util.ArrayUtil; + +/** + * Estimates {@link BayesianScoreQuery} parameters (alpha, beta, base rate) from corpus statistics + * via pseudo-query sampling. + * + *

The estimation algorithm: + * + *

    + *
  1. Sample N documents randomly from the index + *
  2. For each document, create a pseudo-query from its first few tokens in the target field + *
  3. Run each pseudo-query via BM25 and collect the score distribution + *
  4. Estimate: beta = median(scores), alpha = 1 / std(scores) + *
  5. Estimate base rate: mean fraction of documents scoring above the 95th percentile + *
+ * + * @lucene.experimental + */ +public class BayesianScoreEstimator { + + /** Estimated parameters for {@link BayesianScoreQuery}. */ + public record Parameters(float alpha, float beta, float baseRate) {} + + private static final int DEFAULT_N_SAMPLES = 50; + private static final int DEFAULT_TOKENS_PER_QUERY = 5; + private static final double PERCENTILE_THRESHOLD = 0.95; + private static final float BASE_RATE_MIN = 1e-6f; + private static final float BASE_RATE_MAX = 0.5f; + + private BayesianScoreEstimator() {} + + /** + * Estimates BayesianScoreQuery parameters from the given index. + * + * @param searcher the index searcher to sample from + * @param field the text field to create pseudo-queries for + * @param nSamples number of documents to sample (default 50) + * @param tokensPerQuery number of tokens per pseudo-query (default 5) + * @param seed random seed for reproducible sampling + * @return estimated alpha, beta, and base rate + * @throws IOException if an I/O error occurs reading the index + */ + public static Parameters estimate( + IndexSearcher searcher, String field, int nSamples, int tokensPerQuery, long seed) + throws IOException { + IndexReader reader = searcher.getIndexReader(); + int maxDoc = reader.maxDoc(); + if (maxDoc == 0) { + return new Parameters(1.0f, 0.0f, 0.01f); + } + + nSamples = Math.min(nSamples, maxDoc); + Random rng = new Random(seed); + + // Sample document IDs + int[] sampledDocs = sampleDocIds(maxDoc, nSamples, rng); + + // Create pseudo-queries and collect scores + List allScoreArrays = new ArrayList<>(); + List baseRateFractions = new ArrayList<>(); + StoredFields storedFields = reader.storedFields(); + + for (int docId : sampledDocs) { + String fieldValue = storedFields.document(docId).get(field); + if (fieldValue == null || fieldValue.isEmpty()) { + continue; + } + + // Extract first N tokens as pseudo-query terms + String[] tokens = tokenize(fieldValue, tokensPerQuery); + if (tokens.length == 0) { + continue; + } + + // Build a BooleanQuery from the tokens + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + for (String token : tokens) { + builder.add(new TermQuery(new Term(field, token)), BooleanClause.Occur.SHOULD); + } + Query pseudoQuery = builder.build(); + + // Collect all scores + float[] scores = collectScores(searcher, pseudoQuery, maxDoc); + if (scores.length == 0) { + continue; + } + allScoreArrays.add(scores); + + // Base rate: fraction of docs above 95th percentile + float[] sorted = scores.clone(); + Arrays.sort(sorted); + int pIdx = (int) (sorted.length * PERCENTILE_THRESHOLD); + pIdx = Math.min(pIdx, sorted.length - 1); + float threshold = sorted[pIdx]; + int highCount = 0; + for (float s : scores) { + if (s >= threshold) { + highCount++; + } + } + baseRateFractions.add((float) highCount / maxDoc); + } + + if (allScoreArrays.isEmpty()) { + return new Parameters(1.0f, 0.0f, 0.01f); + } + + // Flatten all scores for global statistics + int totalScores = 0; + for (float[] arr : allScoreArrays) { + totalScores += arr.length; + } + float[] allScores = new float[totalScores]; + int offset = 0; + for (float[] arr : allScoreArrays) { + System.arraycopy(arr, 0, allScores, offset, arr.length); + offset += arr.length; + } + + // beta = median + Arrays.sort(allScores); + float beta = allScores[allScores.length / 2]; + + // alpha = 1 / std + double mean = 0; + for (float s : allScores) { + mean += s; + } + mean /= allScores.length; + double variance = 0; + for (float s : allScores) { + double diff = s - mean; + variance += diff * diff; + } + variance /= allScores.length; + double std = Math.sqrt(variance); + float alpha = std > 0 ? (float) (1.0 / std) : 1.0f; + + // base rate = mean of per-query fractions, clamped + float baseRate = 0; + for (float f : baseRateFractions) { + baseRate += f; + } + baseRate /= baseRateFractions.size(); + baseRate = Math.clamp(baseRate, BASE_RATE_MIN, BASE_RATE_MAX); + + return new Parameters(alpha, beta, baseRate); + } + + /** + * Estimates parameters with default settings (50 samples, 5 tokens per query, seed 42). + * + * @param searcher the index searcher + * @param field the text field + * @return estimated parameters + * @throws IOException if an I/O error occurs + */ + public static Parameters estimate(IndexSearcher searcher, String field) throws IOException { + return estimate(searcher, field, DEFAULT_N_SAMPLES, DEFAULT_TOKENS_PER_QUERY, 42); + } + + private static int[] sampleDocIds(int maxDoc, int nSamples, Random rng) { + // Fisher-Yates partial shuffle for sampling without replacement + int[] all = new int[maxDoc]; + for (int i = 0; i < maxDoc; i++) { + all[i] = i; + } + int n = Math.min(nSamples, maxDoc); + for (int i = 0; i < n; i++) { + int j = i + rng.nextInt(maxDoc - i); + int tmp = all[i]; + all[i] = all[j]; + all[j] = tmp; + } + return ArrayUtil.copyOfSubArray(all, 0, n); + } + + private static String[] tokenize(String text, int maxTokens) { + // Simple whitespace tokenization with lowercasing + String[] parts = text.toLowerCase(java.util.Locale.ROOT).split("\\s+"); + int n = Math.min(parts.length, maxTokens); + List tokens = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + String token = parts[i].replaceAll("[^a-z0-9]", ""); + if (token.isEmpty() == false) { + tokens.add(token); + } + } + return tokens.toArray(new String[0]); + } + + private static float[] collectScores(IndexSearcher searcher, Query query, int maxDoc) + throws IOException { + int topN = Math.min(maxDoc, 10000); + TopDocs topDocs = searcher.search(query, topN); + float[] scores = new float[topDocs.scoreDocs.length]; + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + scores[i] = topDocs.scoreDocs[i].score; + } + return scores; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/BayesianScoreQuery.java b/lucene/core/src/java/org/apache/lucene/search/BayesianScoreQuery.java index fbb7af8fb6ca..87ce0350ff97 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BayesianScoreQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BayesianScoreQuery.java @@ -30,7 +30,13 @@ * {@link LogOddsFusionQuery}. * *

The alpha parameter controls the sigmoid steepness (score sensitivity), and beta controls the - * midpoint (decision boundary). These can be set manually or estimated from the score distribution. + * midpoint (decision boundary). These can be set manually or estimated from the score distribution + * via {@link BayesianScoreEstimator}. + * + *

An optional base rate encodes the corpus-level prior probability that a random document is + * relevant to a random query. When set, the posterior is computed in log-odds space: {@code + * sigmoid(alpha * (score - beta) + logit(baseRate))}. This shifts scores down for rare-relevance + * corpora, improving calibration. * * @lucene.experimental */ @@ -39,15 +45,20 @@ public final class BayesianScoreQuery extends Query { private final Query query; private final float alpha; private final float beta; + private final float baseRate; + private final float logitBaseRate; /** - * Creates a BayesianScoreQuery. + * Creates a BayesianScoreQuery with base rate. * * @param query the inner query whose scores will be transformed * @param alpha sigmoid steepness (must be positive and finite) * @param beta sigmoid midpoint (must be finite) + * @param baseRate corpus-level relevance prior in (0, 1), or 0 to disable. When positive, adds + * logit(baseRate) to the log-odds before sigmoid, shifting scores to account for the rarity + * of relevant documents. */ - public BayesianScoreQuery(Query query, float alpha, float beta) { + public BayesianScoreQuery(Query query, float alpha, float beta, float baseRate) { this.query = Objects.requireNonNull(query); if (Float.isFinite(alpha) == false || alpha <= 0) { throw new IllegalArgumentException("alpha must be a positive finite value, got " + alpha); @@ -55,8 +66,28 @@ public BayesianScoreQuery(Query query, float alpha, float beta) { if (Float.isFinite(beta) == false) { throw new IllegalArgumentException("beta must be a finite value, got " + beta); } + if (baseRate < 0 || baseRate >= 1) { + throw new IllegalArgumentException("baseRate must be in [0, 1), got " + baseRate); + } this.alpha = alpha; this.beta = beta; + this.baseRate = baseRate; + if (baseRate > 0) { + this.logitBaseRate = (float) Math.log(baseRate / (1.0 - baseRate)); + } else { + this.logitBaseRate = 0f; + } + } + + /** + * Creates a BayesianScoreQuery without base rate. + * + * @param query the inner query whose scores will be transformed + * @param alpha sigmoid steepness (must be positive and finite) + * @param beta sigmoid midpoint (must be finite) + */ + public BayesianScoreQuery(Query query, float alpha, float beta) { + this(query, alpha, beta, 0f); } /** Returns the wrapped query. */ @@ -74,6 +105,11 @@ public float getBeta() { return beta; } + /** Returns the base rate, or 0 if not set. */ + public float getBaseRate() { + return baseRate; + } + static float sigmoid(float x) { if (x >= 0) { return (float) (1.0 / (1.0 + Math.exp(-x))); @@ -100,7 +136,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return rewritten; } if (rewritten != query) { - return new BayesianScoreQuery(rewritten, alpha, beta); + return new BayesianScoreQuery(rewritten, alpha, beta, baseRate); } return super.rewrite(indexSearcher); } @@ -112,7 +148,11 @@ public void visit(QueryVisitor visitor) { @Override public String toString(String field) { - return "BayesianScore(" + query.toString(field) + ", alpha=" + alpha + ", beta=" + beta + ")"; + String base = "BayesianScore(" + query.toString(field) + ", alpha=" + alpha + ", beta=" + beta; + if (baseRate > 0) { + base += ", baseRate=" + baseRate; + } + return base + ")"; } @Override @@ -123,7 +163,8 @@ public boolean equals(Object other) { private boolean equalsTo(BayesianScoreQuery other) { return query.equals(other.query) && Float.floatToIntBits(alpha) == Float.floatToIntBits(other.alpha) - && Float.floatToIntBits(beta) == Float.floatToIntBits(other.beta); + && Float.floatToIntBits(beta) == Float.floatToIntBits(other.beta) + && Float.floatToIntBits(baseRate) == Float.floatToIntBits(other.baseRate); } @Override @@ -132,6 +173,7 @@ public int hashCode() { h = 31 * h + query.hashCode(); h = 31 * h + Float.floatToIntBits(alpha); h = 31 * h + Float.floatToIntBits(beta); + h = 31 * h + Float.floatToIntBits(baseRate); return h; } @@ -155,7 +197,18 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio return innerExpl; } float innerScore = innerExpl.getValue().floatValue(); - float transformed = sigmoid(alpha * (innerScore - beta)); + float logOdds = alpha * (innerScore - beta) + logitBaseRate; + float transformed = sigmoid(logOdds); + if (baseRate > 0) { + return Explanation.match( + transformed, + "sigmoid calibration with base rate, computed as" + + " sigmoid(alpha * (score - beta) + logit(baseRate)) from:", + innerExpl, + Explanation.match(alpha, "alpha, sigmoid steepness"), + Explanation.match(beta, "beta, sigmoid midpoint"), + Explanation.match(baseRate, "baseRate, corpus-level relevance prior")); + } return Explanation.match( transformed, "sigmoid calibration, computed as sigmoid(alpha * (score - beta)) from:", @@ -204,7 +257,7 @@ private class BayesianScoreScorer extends FilterScorer { @Override public float score() throws IOException { float innerScore = in.score(); - return sigmoid(alpha * (innerScore - beta)); + return sigmoid(alpha * (innerScore - beta) + logitBaseRate); } @Override @@ -216,19 +269,19 @@ public int advanceShallow(int target) throws IOException { public float getMaxScore(int upTo) throws IOException { float innerMax = in.getMaxScore(upTo); // sigmoid is monotone, so max(sigmoid(f(x))) = sigmoid(max(f(x))) - return sigmoid(alpha * (innerMax - beta)); + return sigmoid(alpha * (innerMax - beta) + logitBaseRate); } @Override public void setMinCompetitiveScore(float minScore) throws IOException { // Invert the sigmoid to get the minimum inner score needed: - // minScore = sigmoid(alpha * (innerScore - beta)) - // => alpha * (innerScore - beta) = logit(minScore) - // => innerScore = logit(minScore) / alpha + beta + // minScore = sigmoid(alpha * (innerScore - beta) + logitBaseRate) + // => alpha * (innerScore - beta) + logitBaseRate = logit(minScore) + // => innerScore = (logit(minScore) - logitBaseRate) / alpha + beta if (minScore > 0f && minScore < 1f) { float clamped = Math.max(1e-7f, Math.min(1f - 1e-7f, minScore)); float logitMin = (float) Math.log(clamped / (1f - clamped)); - float innerMin = logitMin / alpha + beta; + float innerMin = (logitMin - logitBaseRate) / alpha + beta; in.setMinCompetitiveScore(Math.max(0f, innerMin)); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionQuery.java b/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionQuery.java index ad1d26179ccf..5fb575d5f288 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionQuery.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Iterator; @@ -43,6 +44,11 @@ *

The alpha parameter controls the confidence scaling exponent. The default alpha=0.5 implements * the sqrt(n) scaling law from "From Bayesian Inference to Neural Computation". * + *

Optional per-signal weights enable weighted Log-OP (Logarithmic Opinion Pooling) where each + * signal's log-odds contribution is scaled by its reliability weight. Weights must be non-negative + * and sum to 1. When weights are provided, the scoring formula becomes: {@code sigmoid(n^alpha * + * sum(w_i * softplus(logit(p_i))))} instead of the uniform mean. + * * @see LogOddsFusionScorer * @lucene.experimental */ @@ -51,31 +57,101 @@ public final class LogOddsFusionQuery extends Query implements Iterable { private final Multiset clauses = new Multiset<>(); private final List orderedClauses; private final float alpha; + private final float[] signalWeights; + private final float[] logitMin; + private final float[] logitMax; /** - * Creates a new LogOddsFusionQuery. + * Creates a new LogOddsFusionQuery with per-signal weights and optional logit normalization. * * @param clauses the sub-queries to combine * @param alpha confidence scaling exponent (0.5 = sqrt(n) law) - * @throws IllegalArgumentException if alpha is not in [0, 1] + * @param weights per-signal weights (must be non-negative, finite, and sum to 1.0), or null for + * uniform weighting + * @param logitMin per-signal logit lower bounds for normalization, or null to use softplus gating + * @param logitMax per-signal logit upper bounds for normalization, or null to use softplus gating + * @throws IllegalArgumentException if alpha is not in [0, 1], or weights/bounds are invalid */ - public LogOddsFusionQuery(Collection clauses, float alpha) { + public LogOddsFusionQuery( + Collection clauses, + float alpha, + float[] weights, + float[] logitMin, + float[] logitMax) { Objects.requireNonNull(clauses, "Collection of Queries must not be null"); if (Float.isNaN(alpha) || alpha < 0 || alpha > 1) { throw new IllegalArgumentException("alpha must be in [0, 1], got " + alpha); } + if (weights != null) { + if (weights.length != clauses.size()) { + throw new IllegalArgumentException( + "weights length " + weights.length + " must equal clauses size " + clauses.size()); + } + float sum = 0; + for (float w : weights) { + if (Float.isFinite(w) == false || w < 0) { + throw new IllegalArgumentException("weights must be non-negative and finite, got " + w); + } + sum += w; + } + if (Math.abs(sum - 1.0f) > 1e-3f) { + throw new IllegalArgumentException("weights must sum to 1.0, got " + sum); + } + this.signalWeights = weights.clone(); + } else { + this.signalWeights = null; + } + if (logitMin != null && logitMax != null) { + if (logitMin.length != clauses.size()) { + throw new IllegalArgumentException( + "logitMin length " + logitMin.length + " must equal clauses size " + clauses.size()); + } + if (logitMax.length != clauses.size()) { + throw new IllegalArgumentException( + "logitMax length " + logitMax.length + " must equal clauses size " + clauses.size()); + } + this.logitMin = logitMin.clone(); + this.logitMax = logitMax.clone(); + } else { + this.logitMin = null; + this.logitMax = null; + } this.alpha = alpha; this.clauses.addAll(clauses); this.orderedClauses = new ArrayList<>(clauses); } /** - * Creates a new LogOddsFusionQuery with default alpha=0.5 (sqrt(n) scaling law). + * Creates a new LogOddsFusionQuery with per-signal weights (softplus gating, no normalization). + * + * @param clauses the sub-queries to combine + * @param alpha confidence scaling exponent (0.5 = sqrt(n) law) + * @param weights per-signal weights, or null for uniform weighting + * @throws IllegalArgumentException if alpha is not in [0, 1], or weights are invalid + */ + public LogOddsFusionQuery(Collection clauses, float alpha, float[] weights) { + this(clauses, alpha, weights, null, null); + } + + /** + * Creates a new LogOddsFusionQuery with uniform weighting and softplus gating. + * + * @param clauses the sub-queries to combine + * @param alpha confidence scaling exponent (0.5 = sqrt(n) law) + * @throws IllegalArgumentException if alpha is not in [0, 1] + */ + public LogOddsFusionQuery(Collection clauses, float alpha) { + this(clauses, alpha, null, null, null); + } + + /** + * Creates a new LogOddsFusionQuery with default alpha=0.5, uniform weighting, and softplus + * gating. * * @param clauses the sub-queries to combine */ public LogOddsFusionQuery(Collection clauses) { - this(clauses, 0.5f); + this(clauses, 0.5f, null, null, null); } @Override @@ -93,6 +169,16 @@ public float getAlpha() { return alpha; } + /** + * Returns a copy of the per-signal weights, or null if uniform weighting is used. + * + *

When non-null, the i-th element is the weight for the i-th clause in the order returned by + * {@link #getClauses()}. + */ + public float[] getWeights() { + return signalWeights != null ? signalWeights.clone() : null; + } + /** Weight for LogOddsFusionQuery. */ protected class LogOddsFusionWeight extends Weight { @@ -102,7 +188,7 @@ protected class LogOddsFusionWeight extends Weight { public LogOddsFusionWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { super(LogOddsFusionQuery.this); - for (Query clauseQuery : clauses) { + for (Query clauseQuery : orderedClauses) { weights.add(searcher.createWeight(clauseQuery, scoreMode, boost)); } this.scoreMode = scoreMode; @@ -123,10 +209,21 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { List scorerSuppliers = new ArrayList<>(); - for (Weight w : weights) { - ScorerSupplier ss = w.scorerSupplier(context); + List activeWeightsList = signalWeights != null ? new ArrayList<>() : null; + List activeLogitMinList = logitMin != null ? new ArrayList<>() : null; + List activeLogitMaxList = logitMax != null ? new ArrayList<>() : null; + + for (int i = 0; i < weights.size(); i++) { + ScorerSupplier ss = weights.get(i).scorerSupplier(context); if (ss != null) { scorerSuppliers.add(ss); + if (activeWeightsList != null) { + activeWeightsList.add(signalWeights[i]); + } + if (activeLogitMinList != null) { + activeLogitMinList.add(logitMin[i]); + activeLogitMaxList.add(logitMax[i]); + } } } @@ -136,6 +233,10 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti return scorerSuppliers.get(0); } else { final int totalClauses = clauses.size(); + final float[] activeWeights = toFloatArray(activeWeightsList); + final float[] activeMin = toFloatArray(activeLogitMinList); + final float[] activeMax = toFloatArray(activeLogitMaxList); + return new ScorerSupplier() { private long cost = -1; @@ -146,7 +247,15 @@ public Scorer get(long leadCost) throws IOException { for (ScorerSupplier ss : scorerSuppliers) { scorers.add(ss.get(leadCost)); } - return new LogOddsFusionScorer(scorers, totalClauses, alpha, scoreMode, leadCost); + return new LogOddsFusionScorer( + scorers, + totalClauses, + alpha, + activeWeights, + activeMin, + activeMax, + scoreMode, + leadCost); } @Override @@ -191,27 +300,44 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio double logitSum = 0; int totalClauses = weights.size(); - for (Weight wt : weights) { - Explanation e = wt.explain(context, doc); + for (int i = 0; i < weights.size(); i++) { + Explanation e = weights.get(i).explain(context, doc); if (e.isMatch()) { match = true; subsOnMatch.add(e); float subScore = e.getValue().floatValue(); - logitSum += LogOddsFusionScorer.softplus(LogOddsFusionScorer.logit(subScore)); + float rawLogit = LogOddsFusionScorer.logit(subScore); + float gated; + if (logitMin != null) { + float range = logitMax[i] - logitMin[i]; + gated = range > 0 ? Math.clamp((rawLogit - logitMin[i]) / range, 0f, 1f) : 0.5f; + } else { + gated = LogOddsFusionScorer.softplus(rawLogit); + } + if (signalWeights != null) { + logitSum += signalWeights[i] * gated; + } else { + logitSum += gated; + } } else if (match == false) { subsOnNoMatch.add(e); } } if (match) { - // Non-matching contribute logit(0.5) = 0 - float meanLogit = (float) (logitSum / totalClauses); float scalingFactor = (float) Math.pow(totalClauses, alpha); - float scaledLogit = meanLogit * scalingFactor; + float scaledLogit = 0f; + String description; + if (signalWeights != null) { + scaledLogit = (float) logitSum * scalingFactor; + description = + "weighted log-odds fusion, computed as sigmoid(weightedLogit * n^alpha) from:"; + } else { + scaledLogit = (float) (logitSum / totalClauses) * scalingFactor; + description = "log-odds fusion, computed as sigmoid(meanLogit * n^alpha) from:"; + } float score = LogOddsFusionScorer.sigmoid(scaledLogit); - - return Explanation.match( - score, "log-odds fusion, computed as sigmoid(meanLogit * n^alpha) from:", subsOnMatch); + return Explanation.match(score, description, subsOnMatch); } else { return Explanation.noMatch("No matching clause", subsOnNoMatch); } @@ -231,19 +357,19 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { } if (clauses.size() == 1) { - return clauses.iterator().next(); + return orderedClauses.get(0); } boolean actuallyRewritten = false; List rewrittenClauses = new ArrayList<>(); - for (Query sub : clauses) { + for (Query sub : orderedClauses) { Query rewrittenSub = sub.rewrite(indexSearcher); actuallyRewritten |= rewrittenSub != sub; rewrittenClauses.add(rewrittenSub); } if (actuallyRewritten) { - return new LogOddsFusionQuery(rewrittenClauses, alpha); + return new LogOddsFusionQuery(rewrittenClauses, alpha, signalWeights, logitMin, logitMax); } return super.rewrite(indexSearcher); @@ -259,15 +385,20 @@ public void visit(QueryVisitor visitor) { @Override public String toString(String field) { - return this.orderedClauses.stream() - .map( - subquery -> { - if (subquery instanceof BooleanQuery) { - return "(" + subquery.toString(field) + ")"; - } - return subquery.toString(field); - }) - .collect(Collectors.joining(" & ", "LogOdds(", ")^" + alpha)); + String base = + this.orderedClauses.stream() + .map( + subquery -> { + if (subquery instanceof BooleanQuery) { + return "(" + subquery.toString(field) + ")"; + } + return subquery.toString(field); + }) + .collect(Collectors.joining(" & ", "LogOdds(", ")^" + alpha)); + if (signalWeights != null) { + return base + " w=" + Arrays.toString(signalWeights); + } + return base; } @Override @@ -276,7 +407,11 @@ public boolean equals(Object other) { } private boolean equalsTo(LogOddsFusionQuery other) { - return alpha == other.alpha && Objects.equals(clauses, other.clauses); + return alpha == other.alpha + && Objects.equals(clauses, other.clauses) + && Arrays.equals(signalWeights, other.signalWeights) + && Arrays.equals(logitMin, other.logitMin) + && Arrays.equals(logitMax, other.logitMax); } @Override @@ -284,6 +419,20 @@ public int hashCode() { int h = classHash(); h = 31 * h + Float.floatToIntBits(alpha); h = 31 * h + Objects.hashCode(clauses); + h = 31 * h + Arrays.hashCode(signalWeights); + h = 31 * h + Arrays.hashCode(logitMin); + h = 31 * h + Arrays.hashCode(logitMax); return h; } + + private static float[] toFloatArray(List list) { + if (list == null) { + return null; + } + float[] result = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + result[i] = list.get(i); + } + return result; + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionScorer.java b/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionScorer.java index cb8ee8aa46de..bfbc401bc42f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionScorer.java @@ -17,6 +17,7 @@ package org.apache.lucene.search; import java.io.IOException; +import java.util.IdentityHashMap; import java.util.List; /** @@ -50,6 +51,10 @@ final class LogOddsFusionScorer extends DisjunctionScorer { private final List subScorers; private final int totalClauses; private final float scalingFactor; + private final float[] signalWeights; + private final float[] logitMin; + private final float[] logitMax; + private final IdentityHashMap scorerIndexMap; private final DisjunctionScoreBlockBoundaryPropagator disjunctionBlockPropagator; @@ -59,16 +64,42 @@ final class LogOddsFusionScorer extends DisjunctionScorer { * @param subScorers the sub scorers to combine * @param totalClauses the total number of clauses (including non-matching) * @param alpha confidence scaling exponent (0.5 = sqrt(n) law) + * @param signalWeights per-signal weights parallel to subScorers (null for uniform weighting). + * When provided, the scoring formula uses weighted sum instead of mean. Weights must be + * non-negative and should sum to 1. + * @param logitMin per-signal logit lower bounds for normalization (null to use softplus gating). + * When provided together with logitMax, logit values are normalized to [0, 1] instead of + * applying softplus. This ensures non-negative contributions while preserving learned signal + * scale calibration. + * @param logitMax per-signal logit upper bounds for normalization (null to use softplus gating) * @param scoreMode the score mode * @param leadCost the lead cost for iteration */ LogOddsFusionScorer( - List subScorers, int totalClauses, float alpha, ScoreMode scoreMode, long leadCost) + List subScorers, + int totalClauses, + float alpha, + float[] signalWeights, + float[] logitMin, + float[] logitMax, + ScoreMode scoreMode, + long leadCost) throws IOException { super(subScorers, scoreMode, leadCost); this.subScorers = subScorers; this.totalClauses = totalClauses; this.scalingFactor = (float) Math.pow(totalClauses, alpha); + this.signalWeights = signalWeights; + this.logitMin = logitMin; + this.logitMax = logitMax; + if (signalWeights != null) { + this.scorerIndexMap = new IdentityHashMap<>(subScorers.size()); + for (int i = 0; i < subScorers.size(); i++) { + this.scorerIndexMap.put(subScorers.get(i), i); + } + } else { + this.scorerIndexMap = null; + } if (scoreMode == ScoreMode.TOP_SCORES) { this.disjunctionBlockPropagator = new DisjunctionScoreBlockBoundaryPropagator(subScorers); } else { @@ -108,19 +139,40 @@ static float softplus(float x) { return (float) Math.log1p(Math.exp(x)); } + /** Applies gating to a logit value: normalization if bounds are set, softplus otherwise. */ + private float gateLogit(float rawLogit, int signalIndex) { + if (logitMin != null) { + float range = logitMax[signalIndex] - logitMin[signalIndex]; + if (range > 0) { + return Math.clamp((rawLogit - logitMin[signalIndex]) / range, 0f, 1f); + } + return 0.5f; + } + return softplus(rawLogit); + } + @Override protected float score(DisiWrapper topList) throws IOException { double logitSum = 0; for (DisiWrapper w = topList; w != null; w = w.next) { float subScore = w.scorable.score(); - logitSum += softplus(logit(subScore)); + int idx = scorerIndexMap != null ? scorerIndexMap.get(w.scorer) : -1; + float gated = gateLogit(logit(subScore), idx >= 0 ? idx : 0); + if (scorerIndexMap != null) { + logitSum += signalWeights[idx] * gated; + } else { + logitSum += gated; + } + } + // Non-matching sub-scorers contribute 0. + // With weights: sum(w_i * gated_i) already accounts for the 1/n factor. + // Without weights: divide by totalClauses to compute the mean. + float scaledLogit = 0f; + if (signalWeights != null) { + scaledLogit = (float) logitSum * scalingFactor; + } else { + scaledLogit = (float) (logitSum / totalClauses) * scalingFactor; } - // Non-matching sub-scorers contribute logit(0.5) = 0, softplus(0) = log(2) ~ 0.693. - // But we do NOT add this for non-matching scorers: they contribute 0, not softplus(0). - // This is the key distinction: a match always contributes softplus(logit) > 0, - // while a non-match contributes exactly 0. - float meanLogit = (float) (logitSum / totalClauses); - float scaledLogit = meanLogit * scalingFactor; return sigmoid(scaledLogit); } @@ -134,17 +186,27 @@ public int advanceShallow(int target) throws IOException { @Override public float getMaxScore(int upTo) throws IOException { + // Safe upper bound: gateLogit is monotone in p (both softplus and normalize are monotone), + // weights are non-negative, sum of upper bounds >= sum of actuals, and sigmoid is monotone. double maxLogitSum = 0; - for (Scorer scorer : subScorers) { + for (int i = 0; i < subScorers.size(); i++) { + Scorer scorer = subScorers.get(i); if (scorer.docID() <= upTo) { float maxSubScore = scorer.getMaxScore(upTo); - maxLogitSum += softplus(logit(maxSubScore)); + float gated = gateLogit(logit(maxSubScore), i); + if (signalWeights != null) { + maxLogitSum += signalWeights[i] * gated; + } else { + maxLogitSum += gated; + } } } - // Safe upper bound: softplus(logit) is monotone in p, sum of upper bounds >= sum of actuals, - // sigmoid is monotone - float meanLogit = (float) (maxLogitSum / totalClauses); - float scaledLogit = meanLogit * scalingFactor; + float scaledLogit = 0f; + if (signalWeights != null) { + scaledLogit = (float) maxLogitSum * scalingFactor; + } else { + scaledLogit = (float) (maxLogitSum / totalClauses) * scalingFactor; + } return sigmoid(scaledLogit); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBayesianScoreQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestBayesianScoreQuery.java index 6cc086ceb300..de77c732f8f6 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestBayesianScoreQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestBayesianScoreQuery.java @@ -257,4 +257,130 @@ public void testDifferentAlphaBeta() throws Exception { // Same ranking order (sigmoid is monotone regardless of alpha/beta) assertEquals("same top doc", gentleHits[0].doc, steepHits[0].doc); } + + // ---- Base rate tests ---- + + public void testBaseRateLowersScores() throws Exception { + Query inner = new TermQuery(new Term("body", "alpha")); + + BayesianScoreQuery noBaseRate = new BayesianScoreQuery(inner, 0.5f, 3.0f); + BayesianScoreQuery withBaseRate = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f); + + ScoreDoc[] hitsNo = searcher.search(noBaseRate, 10).scoreDocs; + ScoreDoc[] hitsBR = searcher.search(withBaseRate, 10).scoreDocs; + + assertTrue("both should have hits", hitsNo.length > 0 && hitsBR.length > 0); + + // Same ranking order (baseRate is a constant shift in log-odds, preserves monotonicity) + assertEquals("same top doc", hitsNo[0].doc, hitsBR[0].doc); + + // Base rate < 0.5 adds negative logit, so scores should be lower + for (int i = 0; i < Math.min(hitsNo.length, hitsBR.length); i++) { + if (hitsNo[i].doc == hitsBR[i].doc) { + assertTrue( + "base rate 0.01 should lower score: " + hitsBR[i].score + " < " + hitsNo[i].score, + hitsBR[i].score < hitsNo[i].score); + } + } + } + + public void testBaseRateScoresInRange() throws Exception { + Query inner = new TermQuery(new Term("body", "alpha")); + BayesianScoreQuery bsq = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.001f); + + ScoreDoc[] hits = searcher.search(bsq, 10).scoreDocs; + for (ScoreDoc hit : hits) { + assertTrue("score in (0,1): " + hit.score, hit.score > 0 && hit.score < 1); + } + } + + public void testBaseRateMaxScoreCorrectness() throws Exception { + Query inner = new TermQuery(new Term("body", "alpha")); + BayesianScoreQuery bsq = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f); + CheckHits.checkTopScores(random(), bsq, searcher); + } + + public void testBaseRateExplanation() throws Exception { + Query inner = new TermQuery(new Term("body", "alpha")); + BayesianScoreQuery bsq = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.05f); + + Weight w = searcher.createWeight(searcher.rewrite(bsq), ScoreMode.COMPLETE, 1); + LeafReaderContext ctx = searcher.getIndexReader().leaves().get(0); + Explanation expl = w.explain(ctx, 0); + assertTrue("should match", expl.isMatch()); + assertTrue("should mention base rate", expl.getDescription().contains("base rate")); + } + + public void testBaseRateQueryUtils() throws Exception { + Query inner = new TermQuery(new Term("body", "alpha")); + BayesianScoreQuery bsq = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f); + QueryUtils.check(random(), bsq, searcher); + } + + public void testBaseRateEqualsAndHashCode() { + Query inner = new TermQuery(new Term("body", "alpha")); + BayesianScoreQuery a = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f); + BayesianScoreQuery b = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f); + BayesianScoreQuery c = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.05f); + BayesianScoreQuery d = new BayesianScoreQuery(inner, 0.5f, 3.0f); + + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + assertNotEquals(a, c); + assertNotEquals(a, d); + } + + public void testIllegalBaseRate() { + Query inner = new TermQuery(new Term("body", "alpha")); + expectThrows( + IllegalArgumentException.class, () -> new BayesianScoreQuery(inner, 0.5f, 3.0f, -0.1f)); + expectThrows( + IllegalArgumentException.class, () -> new BayesianScoreQuery(inner, 0.5f, 3.0f, 1.0f)); + } + + // ---- Auto-estimation tests ---- + + public void testEstimatorReturnsFiniteValues() throws Exception { + BayesianScoreEstimator.Parameters params = BayesianScoreEstimator.estimate(searcher, "body"); + + assertTrue( + "alpha should be positive and finite", + params.alpha() > 0 && Float.isFinite(params.alpha())); + assertTrue("beta should be finite", Float.isFinite(params.beta())); + assertTrue("baseRate in (0, 0.5]", params.baseRate() > 0 && params.baseRate() <= 0.5f); + } + + public void testEstimatedParametersProduceValidScores() throws Exception { + BayesianScoreEstimator.Parameters params = BayesianScoreEstimator.estimate(searcher, "body"); + + Query inner = new TermQuery(new Term("body", "alpha")); + BayesianScoreQuery bsq = + new BayesianScoreQuery(inner, params.alpha(), params.beta(), params.baseRate()); + + ScoreDoc[] hits = searcher.search(bsq, 10).scoreDocs; + assertTrue("should have hits", hits.length > 0); + for (ScoreDoc hit : hits) { + assertTrue("score in (0,1): " + hit.score, hit.score > 0 && hit.score < 1); + } + } + + public void testEstimatedMaxScoreCorrectness() throws Exception { + BayesianScoreEstimator.Parameters params = BayesianScoreEstimator.estimate(searcher, "body"); + + Query inner = new TermQuery(new Term("body", "alpha")); + BayesianScoreQuery bsq = + new BayesianScoreQuery(inner, params.alpha(), params.beta(), params.baseRate()); + CheckHits.checkTopScores(random(), bsq, searcher); + } + + public void testEstimatorReproducibleWithSeed() throws Exception { + BayesianScoreEstimator.Parameters p1 = + BayesianScoreEstimator.estimate(searcher, "body", 20, 3, 123); + BayesianScoreEstimator.Parameters p2 = + BayesianScoreEstimator.estimate(searcher, "body", 20, 3, 123); + + assertEquals("same seed should produce same alpha", p1.alpha(), p2.alpha(), 0f); + assertEquals("same seed should produce same beta", p1.beta(), p2.beta(), 0f); + assertEquals("same seed should produce same baseRate", p1.baseRate(), p2.baseRate(), 0f); + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLogOddsFusionQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestLogOddsFusionQuery.java index d3f1f9c653e2..e5d1b59eff90 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestLogOddsFusionQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestLogOddsFusionQuery.java @@ -18,6 +18,7 @@ import java.util.Arrays; import java.util.Collections; +import java.util.List; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnFloatVectorField; @@ -668,6 +669,210 @@ public void testVectorNot() throws Exception { } } + // ---- Weighted fusion tests ---- + + public void testWeightedFusion() throws Exception { + Query q1 = bayesian(new TermQuery(new Term("body", "alpha"))); + Query q2 = bayesian(new TermQuery(new Term("body", "beta"))); + + LogOddsFusionQuery loq = + new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f}); + QueryUtils.check(random(), loq, searcher); + + ScoreDoc[] hits = searcher.search(loq, 10).scoreDocs; + assertTrue("should have at least 1 hit", hits.length >= 1); + for (ScoreDoc hit : hits) { + assertTrue("score should be > 0", hit.score > 0); + assertTrue("score should be < 1", hit.score < 1); + } + } + + public void testWeightedFusionAffectsRanking() throws Exception { + Query qAlpha = bayesian(new TermQuery(new Term("body", "alpha"))); + Query qBeta = bayesian(new TermQuery(new Term("body", "beta"))); + + // doc0: alpha beta gamma -> matches both + // doc1: alpha gamma delta -> matches alpha only + // doc2: beta gamma delta -> matches beta only + + // Heavy weight on alpha: doc1 (alpha-only) should rank above doc2 (beta-only) + LogOddsFusionQuery alphaHeavy = + new LogOddsFusionQuery(Arrays.asList(qAlpha, qBeta), 0.5f, new float[] {0.9f, 0.1f}); + ScoreDoc[] hitsAlpha = searcher.search(alphaHeavy, 10).scoreDocs; + + // Heavy weight on beta: doc2 (beta-only) should rank above doc1 (alpha-only) + LogOddsFusionQuery betaHeavy = + new LogOddsFusionQuery(Arrays.asList(qAlpha, qBeta), 0.5f, new float[] {0.1f, 0.9f}); + ScoreDoc[] hitsBeta = searcher.search(betaHeavy, 10).scoreDocs; + + // Find scores of doc1 and doc2 in each case + float doc1AlphaHeavy = 0, doc2AlphaHeavy = 0; + float doc1BetaHeavy = 0, doc2BetaHeavy = 0; + for (ScoreDoc hit : hitsAlpha) { + if (hit.doc == 1) doc1AlphaHeavy = hit.score; + if (hit.doc == 2) doc2AlphaHeavy = hit.score; + } + for (ScoreDoc hit : hitsBeta) { + if (hit.doc == 1) doc1BetaHeavy = hit.score; + if (hit.doc == 2) doc2BetaHeavy = hit.score; + } + + assertTrue("alpha-heavy: doc1 (alpha) > doc2 (beta)", doc1AlphaHeavy > doc2AlphaHeavy); + assertTrue("beta-heavy: doc2 (beta) > doc1 (alpha)", doc2BetaHeavy > doc1BetaHeavy); + } + + public void testWeightedExplanation() throws Exception { + Query q1 = bayesian(new TermQuery(new Term("body", "alpha"))); + Query q2 = bayesian(new TermQuery(new Term("body", "beta"))); + LogOddsFusionQuery loq = + new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.6f, 0.4f}); + + Weight w = searcher.createWeight(searcher.rewrite(loq), ScoreMode.COMPLETE, 1); + LeafReaderContext context = searcher.getIndexReader().leaves().get(0); + + Explanation expl = w.explain(context, 0); + assertTrue("doc0 should match", expl.isMatch()); + assertTrue("should mention weighted", expl.getDescription().contains("weighted")); + } + + public void testWeightedEqualsAndHashCode() { + Query q1 = bayesian(new TermQuery(new Term("body", "alpha"))); + Query q2 = bayesian(new TermQuery(new Term("body", "beta"))); + + LogOddsFusionQuery a = + new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f}); + LogOddsFusionQuery b = + new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f}); + LogOddsFusionQuery c = + new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.3f, 0.7f}); + LogOddsFusionQuery d = new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f); + + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + assertNotEquals(a, c); + assertNotEquals(a, d); + } + + public void testWeightedMaxScoreCorrectness() throws Exception { + Query q1 = bayesian(new TermQuery(new Term("body", "alpha"))); + Query q2 = bayesian(new TermQuery(new Term("body", "beta"))); + LogOddsFusionQuery loq = + new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f}); + CheckHits.checkTopScores(random(), loq, searcher); + } + + public void testWeightedToString() { + Query q1 = new TermQuery(new Term("body", "alpha")); + Query q2 = new TermQuery(new Term("body", "beta")); + LogOddsFusionQuery loq = + new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f}); + String str = loq.toString("body"); + assertTrue("should contain LogOdds", str.contains("LogOdds")); + assertTrue("should contain w=", str.contains("w=")); + assertTrue("should contain 0.7", str.contains("0.7")); + } + + public void testWeightedRewrite() throws Exception { + Query sub1 = bayesian(new TermQuery(new Term("body", "alpha"))); + Query sub2 = bayesian(new TermQuery(new Term("body", "beta"))); + LogOddsFusionQuery loq = + new LogOddsFusionQuery(Arrays.asList(sub1, sub2), 0.5f, new float[] {0.6f, 0.4f}); + + Query rewritten = searcher.rewrite(loq); + assertTrue( + "weighted rewrite should produce LogOddsFusionQuery", + rewritten instanceof LogOddsFusionQuery); + LogOddsFusionQuery rewrittenLoq = (LogOddsFusionQuery) rewritten; + assertNotNull("weights should be preserved", rewrittenLoq.getWeights()); + assertEquals(0.6f, rewrittenLoq.getWeights()[0], 1e-6f); + assertEquals(0.4f, rewrittenLoq.getWeights()[1], 1e-6f); + } + + public void testIllegalWeights() { + Query q1 = new TermQuery(new Term("body", "alpha")); + Query q2 = new TermQuery(new Term("body", "beta")); + List clauses = Arrays.asList(q1, q2); + + // Wrong length + expectThrows( + IllegalArgumentException.class, + () -> new LogOddsFusionQuery(clauses, 0.5f, new float[] {1.0f})); + + // Negative weight + expectThrows( + IllegalArgumentException.class, + () -> new LogOddsFusionQuery(clauses, 0.5f, new float[] {-0.1f, 1.1f})); + + // NaN weight + expectThrows( + IllegalArgumentException.class, + () -> new LogOddsFusionQuery(clauses, 0.5f, new float[] {Float.NaN, 0.5f})); + + // Sum != 1.0 + expectThrows( + IllegalArgumentException.class, + () -> new LogOddsFusionQuery(clauses, 0.5f, new float[] {0.3f, 0.3f})); + } + + public void testWeightedQueryUtils() throws Exception { + Query q1 = bayesian(new TermQuery(new Term("body", "alpha"))); + Query q2 = bayesian(new TermQuery(new Term("body", "beta"))); + LogOddsFusionQuery loq = + new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f}); + QueryUtils.check(random(), loq, searcher); + } + + public void testWeightedThreeWayCombination() throws Exception { + Query q1 = bayesian(new TermQuery(new Term("body", "alpha"))); + Query q2 = bayesian(new TermQuery(new Term("body", "beta"))); + Query q3 = bayesian(new TermQuery(new Term("body", "gamma"))); + LogOddsFusionQuery loq = + new LogOddsFusionQuery(Arrays.asList(q1, q2, q3), 0.5f, new float[] {0.5f, 0.3f, 0.2f}); + QueryUtils.check(random(), loq, searcher); + + ScoreDoc[] hits = searcher.search(loq, 10).scoreDocs; + assertTrue("should have hits", hits.length > 0); + for (ScoreDoc hit : hits) { + assertTrue("score in (0,1)", hit.score > 0 && hit.score < 1); + } + } + + // ---- Logit normalization tests ---- + + public void testNormalizedFusionProducesValidScores() throws Exception { + Query q1 = bayesian(new TermQuery(new Term("body", "alpha"))); + Query q2 = bayesian(new TermQuery(new Term("body", "beta"))); + + // logit bounds: signal 0 range [-3, 3], signal 1 range [-1, 1] + LogOddsFusionQuery loq = + new LogOddsFusionQuery( + Arrays.asList(q1, q2), + 0.5f, + new float[] {0.6f, 0.4f}, + new float[] {-3f, -1f}, + new float[] {3f, 1f}); + + ScoreDoc[] hits = searcher.search(loq, 10).scoreDocs; + assertTrue("should have hits", hits.length > 0); + for (ScoreDoc hit : hits) { + assertTrue("score in (0,1): " + hit.score, hit.score > 0 && hit.score < 1); + } + } + + public void testNormalizedMaxScoreCorrectness() throws Exception { + Query q1 = bayesian(new TermQuery(new Term("body", "alpha"))); + Query q2 = bayesian(new TermQuery(new Term("body", "beta"))); + + LogOddsFusionQuery loq = + new LogOddsFusionQuery( + Arrays.asList(q1, q2), + 0.5f, + new float[] {0.7f, 0.3f}, + new float[] {-3f, -1f}, + new float[] {3f, 1f}); + CheckHits.checkTopScores(random(), loq, searcher); + } + /** L2-normalize a float vector. */ private static float[] normalize(float[] v) { double norm = 0;