+ List<String> tokens = new ArrayList<>(n);
+ for (int i = 0; i < n; i++) {
+ String token = parts[i].replaceAll("[^a-z0-9]", "");
+ if (token.isEmpty() == false) {
+ tokens.add(token);
+ }
+ }
+ return tokens.toArray(new String[0]);
+ }
+
+ private static float[] collectScores(IndexSearcher searcher, Query query, int maxDoc)
+ throws IOException {
+ int topN = Math.min(maxDoc, 10000);
+ TopDocs topDocs = searcher.search(query, topN);
+ float[] scores = new float[topDocs.scoreDocs.length];
+ for (int i = 0; i < topDocs.scoreDocs.length; i++) {
+ scores[i] = topDocs.scoreDocs[i].score;
+ }
+ return scores;
+ }
+}
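For reviewers, a minimal sketch of how the estimator above is wired up end to end. The estimate(searcher, "body") signature and the Parameters accessors are taken from the tests later in this patch; the IndexSearcher is assumed to already exist:

    // Sketch only: estimate calibration parameters from the index, then query with them.
    BayesianScoreEstimator.Parameters params = BayesianScoreEstimator.estimate(searcher, "body");
    Query inner = new TermQuery(new Term("body", "alpha"));
    Query calibrated =
        new BayesianScoreQuery(inner, params.alpha(), params.beta(), params.baseRate());
    TopDocs top = searcher.search(calibrated, 10); // hit scores are calibrated probabilities in (0, 1)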
diff --git a/lucene/core/src/java/org/apache/lucene/search/BayesianScoreQuery.java b/lucene/core/src/java/org/apache/lucene/search/BayesianScoreQuery.java
index fbb7af8fb6ca..87ce0350ff97 100644
--- a/lucene/core/src/java/org/apache/lucene/search/BayesianScoreQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/BayesianScoreQuery.java
@@ -30,7 +30,13 @@
* {@link LogOddsFusionQuery}.
*
* The alpha parameter controls the sigmoid steepness (score sensitivity), and beta controls the
- * midpoint (decision boundary). These can be set manually or estimated from the score distribution.
+ * midpoint (decision boundary). These can be set manually or estimated from the score distribution
+ * via {@link BayesianScoreEstimator}.
+ *
+ * <p>An optional base rate encodes the corpus-level prior probability that a random document is
+ * relevant to a random query. When set, the posterior is computed in log-odds space: {@code
+ * sigmoid(alpha * (score - beta) + logit(baseRate))}. This shifts scores down for rare-relevance
+ * corpora, improving calibration.
*
* @lucene.experimental
*/
@@ -39,15 +45,20 @@ public final class BayesianScoreQuery extends Query {
private final Query query;
private final float alpha;
private final float beta;
+ private final float baseRate;
+ private final float logitBaseRate;
/**
- * Creates a BayesianScoreQuery.
+ * Creates a BayesianScoreQuery with base rate.
*
* @param query the inner query whose scores will be transformed
* @param alpha sigmoid steepness (must be positive and finite)
* @param beta sigmoid midpoint (must be finite)
+ * @param baseRate corpus-level relevance prior in (0, 1), or 0 to disable. When positive, adds
+ * logit(baseRate) to the log-odds before sigmoid, shifting scores to account for the rarity
+ * of relevant documents.
*/
- public BayesianScoreQuery(Query query, float alpha, float beta) {
+ public BayesianScoreQuery(Query query, float alpha, float beta, float baseRate) {
this.query = Objects.requireNonNull(query);
if (Float.isFinite(alpha) == false || alpha <= 0) {
throw new IllegalArgumentException("alpha must be a positive finite value, got " + alpha);
@@ -55,8 +66,28 @@ public BayesianScoreQuery(Query query, float alpha, float beta) {
if (Float.isFinite(beta) == false) {
throw new IllegalArgumentException("beta must be a finite value, got " + beta);
}
+ if (Float.isNaN(baseRate) || baseRate < 0 || baseRate >= 1) {
+ throw new IllegalArgumentException("baseRate must be in [0, 1), got " + baseRate);
+ }
this.alpha = alpha;
this.beta = beta;
+ this.baseRate = baseRate;
+ if (baseRate > 0) {
+ this.logitBaseRate = (float) Math.log(baseRate / (1.0 - baseRate));
+ } else {
+ this.logitBaseRate = 0f;
+ }
+ }
+
+ /**
+ * Creates a BayesianScoreQuery without base rate.
+ *
+ * @param query the inner query whose scores will be transformed
+ * @param alpha sigmoid steepness (must be positive and finite)
+ * @param beta sigmoid midpoint (must be finite)
+ */
+ public BayesianScoreQuery(Query query, float alpha, float beta) {
+ this(query, alpha, beta, 0f);
}
/** Returns the wrapped query. */
@@ -74,6 +105,11 @@ public float getBeta() {
return beta;
}
+ /** Returns the base rate, or 0 if not set. */
+ public float getBaseRate() {
+ return baseRate;
+ }
+
static float sigmoid(float x) {
if (x >= 0) {
return (float) (1.0 / (1.0 + Math.exp(-x)));
@@ -100,7 +136,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return rewritten;
}
if (rewritten != query) {
- return new BayesianScoreQuery(rewritten, alpha, beta);
+ return new BayesianScoreQuery(rewritten, alpha, beta, baseRate);
}
return super.rewrite(indexSearcher);
}
@@ -112,7 +148,11 @@ public void visit(QueryVisitor visitor) {
@Override
public String toString(String field) {
- return "BayesianScore(" + query.toString(field) + ", alpha=" + alpha + ", beta=" + beta + ")";
+ String base = "BayesianScore(" + query.toString(field) + ", alpha=" + alpha + ", beta=" + beta;
+ if (baseRate > 0) {
+ base += ", baseRate=" + baseRate;
+ }
+ return base + ")";
}
@Override
@@ -123,7 +163,8 @@ public boolean equals(Object other) {
private boolean equalsTo(BayesianScoreQuery other) {
return query.equals(other.query)
&& Float.floatToIntBits(alpha) == Float.floatToIntBits(other.alpha)
- && Float.floatToIntBits(beta) == Float.floatToIntBits(other.beta);
+ && Float.floatToIntBits(beta) == Float.floatToIntBits(other.beta)
+ && Float.floatToIntBits(baseRate) == Float.floatToIntBits(other.baseRate);
}
@Override
@@ -132,6 +173,7 @@ public int hashCode() {
h = 31 * h + query.hashCode();
h = 31 * h + Float.floatToIntBits(alpha);
h = 31 * h + Float.floatToIntBits(beta);
+ h = 31 * h + Float.floatToIntBits(baseRate);
return h;
}
@@ -155,7 +197,18 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
return innerExpl;
}
float innerScore = innerExpl.getValue().floatValue();
- float transformed = sigmoid(alpha * (innerScore - beta));
+ float logOdds = alpha * (innerScore - beta) + logitBaseRate;
+ float transformed = sigmoid(logOdds);
+ if (baseRate > 0) {
+ return Explanation.match(
+ transformed,
+ "sigmoid calibration with base rate, computed as"
+ + " sigmoid(alpha * (score - beta) + logit(baseRate)) from:",
+ innerExpl,
+ Explanation.match(alpha, "alpha, sigmoid steepness"),
+ Explanation.match(beta, "beta, sigmoid midpoint"),
+ Explanation.match(baseRate, "baseRate, corpus-level relevance prior"));
+ }
return Explanation.match(
transformed,
"sigmoid calibration, computed as sigmoid(alpha * (score - beta)) from:",
@@ -204,7 +257,7 @@ private class BayesianScoreScorer extends FilterScorer {
@Override
public float score() throws IOException {
float innerScore = in.score();
- return sigmoid(alpha * (innerScore - beta));
+ return sigmoid(alpha * (innerScore - beta) + logitBaseRate);
}
@Override
@@ -216,19 +269,19 @@ public int advanceShallow(int target) throws IOException {
public float getMaxScore(int upTo) throws IOException {
float innerMax = in.getMaxScore(upTo);
// sigmoid is monotone, so max(sigmoid(f(x))) = sigmoid(max(f(x)))
- return sigmoid(alpha * (innerMax - beta));
+ return sigmoid(alpha * (innerMax - beta) + logitBaseRate);
}
@Override
public void setMinCompetitiveScore(float minScore) throws IOException {
// Invert the sigmoid to get the minimum inner score needed:
- // minScore = sigmoid(alpha * (innerScore - beta))
- // => alpha * (innerScore - beta) = logit(minScore)
- // => innerScore = logit(minScore) / alpha + beta
+ // minScore = sigmoid(alpha * (innerScore - beta) + logitBaseRate)
+ // => alpha * (innerScore - beta) + logitBaseRate = logit(minScore)
+ // => innerScore = (logit(minScore) - logitBaseRate) / alpha + beta
if (minScore > 0f && minScore < 1f) {
float clamped = Math.max(1e-7f, Math.min(1f - 1e-7f, minScore));
float logitMin = (float) Math.log(clamped / (1f - clamped));
- float innerMin = logitMin / alpha + beta;
+ float innerMin = (logitMin - logitBaseRate) / alpha + beta;
in.setMinCompetitiveScore(Math.max(0f, innerMin));
}
}
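A worked example of the base-rate shift: with alpha=0.5, beta=3.0, baseRate=0.01, we get logit(0.01) = ln(0.01 / 0.99) ~ -4.6, so an inner score of 5.0 maps to sigmoid(0.5 * (5.0 - 3.0) - 4.6) ~ 0.027 rather than sigmoid(1.0) ~ 0.73 without the prior. A self-contained sketch of that arithmetic (values are illustrative, not part of the patch):

    // Mirrors the posterior formula used by BayesianScoreQuery.
    public final class BaseRateSketch {
      static float sigmoid(float x) { return (float) (1.0 / (1.0 + Math.exp(-x))); }

      public static void main(String[] args) {
        float alpha = 0.5f, beta = 3.0f, baseRate = 0.01f, innerScore = 5.0f;
        float logitBaseRate = (float) Math.log(baseRate / (1.0 - baseRate)); // ~ -4.595
        float withPrior = sigmoid(alpha * (innerScore - beta) + logitBaseRate);
        float withoutPrior = sigmoid(alpha * (innerScore - beta));
        System.out.println(withPrior + " vs " + withoutPrior); // ~0.027 vs ~0.731
      }
    }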
diff --git a/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionQuery.java b/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionQuery.java
index ad1d26179ccf..5fb575d5f288 100644
--- a/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionQuery.java
@@ -18,6 +18,7 @@
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
@@ -43,6 +44,11 @@
* <p>The alpha parameter controls the confidence scaling exponent. The default alpha=0.5 implements
* the sqrt(n) scaling law from "From Bayesian Inference to Neural Computation".
*
+ * <p>Optional per-signal weights enable weighted Log-OP (Logarithmic Opinion Pooling) where each
+ * signal's log-odds contribution is scaled by its reliability weight. Weights must be non-negative
+ * and sum to 1. When weights are provided, the scoring formula becomes: {@code sigmoid(n^alpha *
+ * sum(w_i * softplus(logit(p_i))))} instead of the uniform mean.
+ *
* @see LogOddsFusionScorer
* @lucene.experimental
*/
@@ -51,31 +57,101 @@ public final class LogOddsFusionQuery extends Query implements Iterable<Query> {
private final Multiset<Query> clauses = new Multiset<>();
private final List<Query> orderedClauses;
private final float alpha;
+ private final float[] signalWeights;
+ private final float[] logitMin;
+ private final float[] logitMax;
/**
- * Creates a new LogOddsFusionQuery.
+ * Creates a new LogOddsFusionQuery with per-signal weights and optional logit normalization.
*
* @param clauses the sub-queries to combine
* @param alpha confidence scaling exponent (0.5 = sqrt(n) law)
- * @throws IllegalArgumentException if alpha is not in [0, 1]
+ * @param weights per-signal weights (must be non-negative, finite, and sum to 1.0), or null for
+ * uniform weighting
+ * @param logitMin per-signal logit lower bounds for normalization, or null to use softplus gating
+ * @param logitMax per-signal logit upper bounds for normalization, or null to use softplus gating
+ * @throws IllegalArgumentException if alpha is not in [0, 1], or weights/bounds are invalid
*/
- public LogOddsFusionQuery(Collection<? extends Query> clauses, float alpha) {
+ public LogOddsFusionQuery(
+ Collection<? extends Query> clauses,
+ float alpha,
+ float[] weights,
+ float[] logitMin,
+ float[] logitMax) {
Objects.requireNonNull(clauses, "Collection of Queries must not be null");
if (Float.isNaN(alpha) || alpha < 0 || alpha > 1) {
throw new IllegalArgumentException("alpha must be in [0, 1], got " + alpha);
}
+ if (weights != null) {
+ if (weights.length != clauses.size()) {
+ throw new IllegalArgumentException(
+ "weights length " + weights.length + " must equal clauses size " + clauses.size());
+ }
+ float sum = 0;
+ for (float w : weights) {
+ if (Float.isFinite(w) == false || w < 0) {
+ throw new IllegalArgumentException("weights must be non-negative and finite, got " + w);
+ }
+ sum += w;
+ }
+ if (Math.abs(sum - 1.0f) > 1e-3f) {
+ throw new IllegalArgumentException("weights must sum to 1.0, got " + sum);
+ }
+ this.signalWeights = weights.clone();
+ } else {
+ this.signalWeights = null;
+ }
+ if (logitMin != null && logitMax != null) {
+ if (logitMin.length != clauses.size()) {
+ throw new IllegalArgumentException(
+ "logitMin length " + logitMin.length + " must equal clauses size " + clauses.size());
+ }
+ if (logitMax.length != clauses.size()) {
+ throw new IllegalArgumentException(
+ "logitMax length " + logitMax.length + " must equal clauses size " + clauses.size());
+ }
+ this.logitMin = logitMin.clone();
+ this.logitMax = logitMax.clone();
+ } else {
+ this.logitMin = null;
+ this.logitMax = null;
+ }
this.alpha = alpha;
this.clauses.addAll(clauses);
this.orderedClauses = new ArrayList<>(clauses);
}
/**
- * Creates a new LogOddsFusionQuery with default alpha=0.5 (sqrt(n) scaling law).
+ * Creates a new LogOddsFusionQuery with per-signal weights (softplus gating, no normalization).
+ *
+ * @param clauses the sub-queries to combine
+ * @param alpha confidence scaling exponent (0.5 = sqrt(n) law)
+ * @param weights per-signal weights, or null for uniform weighting
+ * @throws IllegalArgumentException if alpha is not in [0, 1], or weights are invalid
+ */
+ public LogOddsFusionQuery(Collection<? extends Query> clauses, float alpha, float[] weights) {
+ this(clauses, alpha, weights, null, null);
+ }
+
+ /**
+ * Creates a new LogOddsFusionQuery with uniform weighting and softplus gating.
+ *
+ * @param clauses the sub-queries to combine
+ * @param alpha confidence scaling exponent (0.5 = sqrt(n) law)
+ * @throws IllegalArgumentException if alpha is not in [0, 1]
+ */
+ public LogOddsFusionQuery(Collection<? extends Query> clauses, float alpha) {
+ this(clauses, alpha, null, null, null);
+ }
+
+ /**
+ * Creates a new LogOddsFusionQuery with default alpha=0.5, uniform weighting, and softplus
+ * gating.
*
* @param clauses the sub-queries to combine
*/
public LogOddsFusionQuery(Collection<? extends Query> clauses) {
- this(clauses, 0.5f);
+ this(clauses, 0.5f, null, null, null);
}
@Override
@@ -93,6 +169,16 @@ public float getAlpha() {
return alpha;
}
+ /**
+ * Returns a copy of the per-signal weights, or null if uniform weighting is used.
+ *
+ * <p>When non-null, the i-th element is the weight for the i-th clause in the order returned by
+ * {@link #getClauses()}.
+ */
+ public float[] getWeights() {
+ return signalWeights != null ? signalWeights.clone() : null;
+ }
+
/** Weight for LogOddsFusionQuery. */
protected class LogOddsFusionWeight extends Weight {
@@ -102,7 +188,7 @@ protected class LogOddsFusionWeight extends Weight {
public LogOddsFusionWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
super(LogOddsFusionQuery.this);
- for (Query clauseQuery : clauses) {
+ for (Query clauseQuery : orderedClauses) {
weights.add(searcher.createWeight(clauseQuery, scoreMode, boost));
}
this.scoreMode = scoreMode;
@@ -123,10 +209,21 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException {
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
List<ScorerSupplier> scorerSuppliers = new ArrayList<>();
- for (Weight w : weights) {
- ScorerSupplier ss = w.scorerSupplier(context);
+ List<Float> activeWeightsList = signalWeights != null ? new ArrayList<>() : null;
+ List<Float> activeLogitMinList = logitMin != null ? new ArrayList<>() : null;
+ List<Float> activeLogitMaxList = logitMax != null ? new ArrayList<>() : null;
+
+ for (int i = 0; i < weights.size(); i++) {
+ ScorerSupplier ss = weights.get(i).scorerSupplier(context);
if (ss != null) {
scorerSuppliers.add(ss);
+ if (activeWeightsList != null) {
+ activeWeightsList.add(signalWeights[i]);
+ }
+ if (activeLogitMinList != null) {
+ activeLogitMinList.add(logitMin[i]);
+ activeLogitMaxList.add(logitMax[i]);
+ }
}
}
@@ -136,6 +233,10 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
return scorerSuppliers.get(0);
} else {
final int totalClauses = clauses.size();
+ final float[] activeWeights = toFloatArray(activeWeightsList);
+ final float[] activeMin = toFloatArray(activeLogitMinList);
+ final float[] activeMax = toFloatArray(activeLogitMaxList);
+
return new ScorerSupplier() {
private long cost = -1;
@@ -146,7 +247,15 @@ public Scorer get(long leadCost) throws IOException {
for (ScorerSupplier ss : scorerSuppliers) {
scorers.add(ss.get(leadCost));
}
- return new LogOddsFusionScorer(scorers, totalClauses, alpha, scoreMode, leadCost);
+ return new LogOddsFusionScorer(
+ scorers,
+ totalClauses,
+ alpha,
+ activeWeights,
+ activeMin,
+ activeMax,
+ scoreMode,
+ leadCost);
}
@Override
@@ -191,27 +300,44 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
double logitSum = 0;
int totalClauses = weights.size();
- for (Weight wt : weights) {
- Explanation e = wt.explain(context, doc);
+ for (int i = 0; i < weights.size(); i++) {
+ Explanation e = weights.get(i).explain(context, doc);
if (e.isMatch()) {
match = true;
subsOnMatch.add(e);
float subScore = e.getValue().floatValue();
- logitSum += LogOddsFusionScorer.softplus(LogOddsFusionScorer.logit(subScore));
+ float rawLogit = LogOddsFusionScorer.logit(subScore);
+ float gated;
+ if (logitMin != null) {
+ float range = logitMax[i] - logitMin[i];
+ gated = range > 0 ? Math.clamp((rawLogit - logitMin[i]) / range, 0f, 1f) : 0.5f;
+ } else {
+ gated = LogOddsFusionScorer.softplus(rawLogit);
+ }
+ if (signalWeights != null) {
+ logitSum += signalWeights[i] * gated;
+ } else {
+ logitSum += gated;
+ }
} else if (match == false) {
subsOnNoMatch.add(e);
}
}
if (match) {
- // Non-matching contribute logit(0.5) = 0
- float meanLogit = (float) (logitSum / totalClauses);
float scalingFactor = (float) Math.pow(totalClauses, alpha);
- float scaledLogit = meanLogit * scalingFactor;
+ float scaledLogit = 0f;
+ String description;
+ if (signalWeights != null) {
+ scaledLogit = (float) logitSum * scalingFactor;
+ description =
+ "weighted log-odds fusion, computed as sigmoid(weightedLogit * n^alpha) from:";
+ } else {
+ scaledLogit = (float) (logitSum / totalClauses) * scalingFactor;
+ description = "log-odds fusion, computed as sigmoid(meanLogit * n^alpha) from:";
+ }
float score = LogOddsFusionScorer.sigmoid(scaledLogit);
-
- return Explanation.match(
- score, "log-odds fusion, computed as sigmoid(meanLogit * n^alpha) from:", subsOnMatch);
+ return Explanation.match(score, description, subsOnMatch);
} else {
return Explanation.noMatch("No matching clause", subsOnNoMatch);
}
@@ -231,19 +357,19 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
}
if (clauses.size() == 1) {
- return clauses.iterator().next();
+ return orderedClauses.get(0);
}
boolean actuallyRewritten = false;
List<Query> rewrittenClauses = new ArrayList<>();
- for (Query sub : clauses) {
+ for (Query sub : orderedClauses) {
Query rewrittenSub = sub.rewrite(indexSearcher);
actuallyRewritten |= rewrittenSub != sub;
rewrittenClauses.add(rewrittenSub);
}
if (actuallyRewritten) {
- return new LogOddsFusionQuery(rewrittenClauses, alpha);
+ return new LogOddsFusionQuery(rewrittenClauses, alpha, signalWeights, logitMin, logitMax);
}
return super.rewrite(indexSearcher);
@@ -259,15 +385,20 @@ public void visit(QueryVisitor visitor) {
@Override
public String toString(String field) {
- return this.orderedClauses.stream()
- .map(
- subquery -> {
- if (subquery instanceof BooleanQuery) {
- return "(" + subquery.toString(field) + ")";
- }
- return subquery.toString(field);
- })
- .collect(Collectors.joining(" & ", "LogOdds(", ")^" + alpha));
+ String base =
+ this.orderedClauses.stream()
+ .map(
+ subquery -> {
+ if (subquery instanceof BooleanQuery) {
+ return "(" + subquery.toString(field) + ")";
+ }
+ return subquery.toString(field);
+ })
+ .collect(Collectors.joining(" & ", "LogOdds(", ")^" + alpha));
+ if (signalWeights != null) {
+ return base + " w=" + Arrays.toString(signalWeights);
+ }
+ return base;
}
@Override
@@ -276,7 +407,11 @@ public boolean equals(Object other) {
}
private boolean equalsTo(LogOddsFusionQuery other) {
- return alpha == other.alpha && Objects.equals(clauses, other.clauses);
+ return alpha == other.alpha
+ && Objects.equals(clauses, other.clauses)
+ && Arrays.equals(signalWeights, other.signalWeights)
+ && Arrays.equals(logitMin, other.logitMin)
+ && Arrays.equals(logitMax, other.logitMax);
}
@Override
@@ -284,6 +419,20 @@ public int hashCode() {
int h = classHash();
h = 31 * h + Float.floatToIntBits(alpha);
h = 31 * h + Objects.hashCode(clauses);
+ h = 31 * h + Arrays.hashCode(signalWeights);
+ h = 31 * h + Arrays.hashCode(logitMin);
+ h = 31 * h + Arrays.hashCode(logitMax);
return h;
}
+
+ private static float[] toFloatArray(List<Float> list) {
+ if (list == null) {
+ return null;
+ }
+ float[] result = new float[list.size()];
+ for (int i = 0; i < list.size(); i++) {
+ result[i] = list.get(i);
+ }
+ return result;
+ }
}
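A minimal usage sketch of the weighted constructor, mirroring the tests below (the IndexSearcher is assumed; the sub-queries should already emit calibrated probabilities, e.g. via BayesianScoreQuery):

    // Sketch: fuse two calibrated signals with reliability weights 0.7 / 0.3.
    Query q1 = new BayesianScoreQuery(new TermQuery(new Term("body", "alpha")), 0.5f, 3.0f);
    Query q2 = new BayesianScoreQuery(new TermQuery(new Term("body", "beta")), 0.5f, 3.0f);
    Query fused = new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f});
    // Matching docs score as sigmoid(n^0.5 * (0.7 * softplus(logit(p1)) + 0.3 * softplus(logit(p2)))).
    TopDocs top = searcher.search(fused, 10);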
diff --git a/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionScorer.java b/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionScorer.java
index cb8ee8aa46de..bfbc401bc42f 100644
--- a/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/LogOddsFusionScorer.java
@@ -17,6 +17,7 @@
package org.apache.lucene.search;
import java.io.IOException;
+import java.util.IdentityHashMap;
import java.util.List;
/**
@@ -50,6 +51,10 @@ final class LogOddsFusionScorer extends DisjunctionScorer {
private final List<Scorer> subScorers;
private final int totalClauses;
private final float scalingFactor;
+ private final float[] signalWeights;
+ private final float[] logitMin;
+ private final float[] logitMax;
+ private final IdentityHashMap<Scorer, Integer> scorerIndexMap;
private final DisjunctionScoreBlockBoundaryPropagator disjunctionBlockPropagator;
@@ -59,16 +64,42 @@ final class LogOddsFusionScorer extends DisjunctionScorer {
* @param subScorers the sub scorers to combine
* @param totalClauses the total number of clauses (including non-matching)
* @param alpha confidence scaling exponent (0.5 = sqrt(n) law)
+ * @param signalWeights per-signal weights parallel to subScorers (null for uniform weighting).
+ * When provided, the scoring formula uses weighted sum instead of mean. Weights must be
+ * non-negative and should sum to 1.
+ * @param logitMin per-signal logit lower bounds for normalization (null to use softplus gating).
+ * When provided together with logitMax, logit values are normalized to [0, 1] instead of
+ * applying softplus. This ensures non-negative contributions while preserving learned signal
+ * scale calibration.
+ * @param logitMax per-signal logit upper bounds for normalization (null to use softplus gating)
* @param scoreMode the score mode
* @param leadCost the lead cost for iteration
*/
LogOddsFusionScorer(
- List<Scorer> subScorers, int totalClauses, float alpha, ScoreMode scoreMode, long leadCost)
+ List<Scorer> subScorers,
+ int totalClauses,
+ float alpha,
+ float[] signalWeights,
+ float[] logitMin,
+ float[] logitMax,
+ ScoreMode scoreMode,
+ long leadCost)
throws IOException {
super(subScorers, scoreMode, leadCost);
this.subScorers = subScorers;
this.totalClauses = totalClauses;
this.scalingFactor = (float) Math.pow(totalClauses, alpha);
+ this.signalWeights = signalWeights;
+ this.logitMin = logitMin;
+ this.logitMax = logitMax;
+ if (signalWeights != null || logitMin != null) {
+ this.scorerIndexMap = new IdentityHashMap<>(subScorers.size());
+ for (int i = 0; i < subScorers.size(); i++) {
+ this.scorerIndexMap.put(subScorers.get(i), i);
+ }
+ } else {
+ this.scorerIndexMap = null;
+ }
if (scoreMode == ScoreMode.TOP_SCORES) {
this.disjunctionBlockPropagator = new DisjunctionScoreBlockBoundaryPropagator(subScorers);
} else {
@@ -108,19 +139,40 @@ static float softplus(float x) {
return (float) Math.log1p(Math.exp(x));
}
+ /** Applies gating to a logit value: normalization if bounds are set, softplus otherwise. */
+ private float gateLogit(float rawLogit, int signalIndex) {
+ if (logitMin != null) {
+ float range = logitMax[signalIndex] - logitMin[signalIndex];
+ if (range > 0) {
+ return Math.clamp((rawLogit - logitMin[signalIndex]) / range, 0f, 1f);
+ }
+ return 0.5f;
+ }
+ return softplus(rawLogit);
+ }
+
@Override
protected float score(DisiWrapper topList) throws IOException {
double logitSum = 0;
for (DisiWrapper w = topList; w != null; w = w.next) {
float subScore = w.scorable.score();
- logitSum += softplus(logit(subScore));
+ int idx = scorerIndexMap != null ? scorerIndexMap.get(w.scorer) : 0;
+ float gated = gateLogit(logit(subScore), idx);
+ if (signalWeights != null) {
+ logitSum += signalWeights[idx] * gated;
+ } else {
+ logitSum += gated;
+ }
+ }
+ // Non-matching sub-scorers contribute 0.
+ // With weights: sum(w_i * gated_i) already accounts for the 1/n factor.
+ // Without weights: divide by totalClauses to compute the mean.
+ float scaledLogit = 0f;
+ if (signalWeights != null) {
+ scaledLogit = (float) logitSum * scalingFactor;
+ } else {
+ scaledLogit = (float) (logitSum / totalClauses) * scalingFactor;
}
- // Non-matching sub-scorers contribute logit(0.5) = 0, softplus(0) = log(2) ~ 0.693.
- // But we do NOT add this for non-matching scorers: they contribute 0, not softplus(0).
- // This is the key distinction: a match always contributes softplus(logit) > 0,
- // while a non-match contributes exactly 0.
- float meanLogit = (float) (logitSum / totalClauses);
- float scaledLogit = meanLogit * scalingFactor;
return sigmoid(scaledLogit);
}
@@ -134,17 +186,27 @@ public int advanceShallow(int target) throws IOException {
@Override
public float getMaxScore(int upTo) throws IOException {
+ // Safe upper bound: gateLogit is monotone in p (both softplus and normalize are monotone),
+ // weights are non-negative, sum of upper bounds >= sum of actuals, and sigmoid is monotone.
double maxLogitSum = 0;
- for (Scorer scorer : subScorers) {
+ for (int i = 0; i < subScorers.size(); i++) {
+ Scorer scorer = subScorers.get(i);
if (scorer.docID() <= upTo) {
float maxSubScore = scorer.getMaxScore(upTo);
- maxLogitSum += softplus(logit(maxSubScore));
+ float gated = gateLogit(logit(maxSubScore), i);
+ if (signalWeights != null) {
+ maxLogitSum += signalWeights[i] * gated;
+ } else {
+ maxLogitSum += gated;
+ }
}
}
- // Safe upper bound: softplus(logit) is monotone in p, sum of upper bounds >= sum of actuals,
- // sigmoid is monotone
- float meanLogit = (float) (maxLogitSum / totalClauses);
- float scaledLogit = meanLogit * scalingFactor;
+ float scaledLogit = 0f;
+ if (signalWeights != null) {
+ scaledLogit = (float) maxLogitSum * scalingFactor;
+ } else {
+ scaledLogit = (float) (maxLogitSum / totalClauses) * scalingFactor;
+ }
return sigmoid(scaledLogit);
}
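To sanity-check the normalized path, a self-contained sketch of the weighted, bounds-normalized formula (values are illustrative and independent of the scorer; Math.clamp requires Java 21, which the patched code already relies on):

    // Weighted Log-OP with per-signal logit bounds, as in gateLogit + score above.
    public final class WeightedFusionSketch {
      static float logit(float p) { return (float) Math.log(p / (1f - p)); }
      static float sigmoid(float x) { return (float) (1.0 / (1.0 + Math.exp(-x))); }

      public static void main(String[] args) {
        float[] p = {0.8f, 0.6f};               // calibrated sub-scores in (0, 1)
        float[] w = {0.7f, 0.3f};               // reliability weights, sum to 1
        float[] lo = {-3f, -1f}, hi = {3f, 1f}; // per-signal logit bounds
        double sum = 0;
        for (int i = 0; i < p.length; i++) {
          float gated = Math.clamp((logit(p[i]) - lo[i]) / (hi[i] - lo[i]), 0f, 1f);
          sum += w[i] * gated;
        }
        System.out.println(sigmoid((float) sum * (float) Math.pow(p.length, 0.5))); // alpha = 0.5
      }
    }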
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBayesianScoreQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestBayesianScoreQuery.java
index 6cc086ceb300..de77c732f8f6 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestBayesianScoreQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestBayesianScoreQuery.java
@@ -257,4 +257,130 @@ public void testDifferentAlphaBeta() throws Exception {
// Same ranking order (sigmoid is monotone regardless of alpha/beta)
assertEquals("same top doc", gentleHits[0].doc, steepHits[0].doc);
}
+
+ // ---- Base rate tests ----
+
+ public void testBaseRateLowersScores() throws Exception {
+ Query inner = new TermQuery(new Term("body", "alpha"));
+
+ BayesianScoreQuery noBaseRate = new BayesianScoreQuery(inner, 0.5f, 3.0f);
+ BayesianScoreQuery withBaseRate = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f);
+
+ ScoreDoc[] hitsNo = searcher.search(noBaseRate, 10).scoreDocs;
+ ScoreDoc[] hitsBR = searcher.search(withBaseRate, 10).scoreDocs;
+
+ assertTrue("both should have hits", hitsNo.length > 0 && hitsBR.length > 0);
+
+ // Same ranking order (baseRate is a constant shift in log-odds, preserves monotonicity)
+ assertEquals("same top doc", hitsNo[0].doc, hitsBR[0].doc);
+
+ // Base rate < 0.5 adds negative logit, so scores should be lower
+ for (int i = 0; i < Math.min(hitsNo.length, hitsBR.length); i++) {
+ if (hitsNo[i].doc == hitsBR[i].doc) {
+ assertTrue(
+ "base rate 0.01 should lower score: " + hitsBR[i].score + " < " + hitsNo[i].score,
+ hitsBR[i].score < hitsNo[i].score);
+ }
+ }
+ }
+
+ public void testBaseRateScoresInRange() throws Exception {
+ Query inner = new TermQuery(new Term("body", "alpha"));
+ BayesianScoreQuery bsq = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.001f);
+
+ ScoreDoc[] hits = searcher.search(bsq, 10).scoreDocs;
+ for (ScoreDoc hit : hits) {
+ assertTrue("score in (0,1): " + hit.score, hit.score > 0 && hit.score < 1);
+ }
+ }
+
+ public void testBaseRateMaxScoreCorrectness() throws Exception {
+ Query inner = new TermQuery(new Term("body", "alpha"));
+ BayesianScoreQuery bsq = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f);
+ CheckHits.checkTopScores(random(), bsq, searcher);
+ }
+
+ public void testBaseRateExplanation() throws Exception {
+ Query inner = new TermQuery(new Term("body", "alpha"));
+ BayesianScoreQuery bsq = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.05f);
+
+ Weight w = searcher.createWeight(searcher.rewrite(bsq), ScoreMode.COMPLETE, 1);
+ LeafReaderContext ctx = searcher.getIndexReader().leaves().get(0);
+ Explanation expl = w.explain(ctx, 0);
+ assertTrue("should match", expl.isMatch());
+ assertTrue("should mention base rate", expl.getDescription().contains("base rate"));
+ }
+
+ public void testBaseRateQueryUtils() throws Exception {
+ Query inner = new TermQuery(new Term("body", "alpha"));
+ BayesianScoreQuery bsq = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f);
+ QueryUtils.check(random(), bsq, searcher);
+ }
+
+ public void testBaseRateEqualsAndHashCode() {
+ Query inner = new TermQuery(new Term("body", "alpha"));
+ BayesianScoreQuery a = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f);
+ BayesianScoreQuery b = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.01f);
+ BayesianScoreQuery c = new BayesianScoreQuery(inner, 0.5f, 3.0f, 0.05f);
+ BayesianScoreQuery d = new BayesianScoreQuery(inner, 0.5f, 3.0f);
+
+ assertEquals(a, b);
+ assertEquals(a.hashCode(), b.hashCode());
+ assertNotEquals(a, c);
+ assertNotEquals(a, d);
+ }
+
+ public void testIllegalBaseRate() {
+ Query inner = new TermQuery(new Term("body", "alpha"));
+ expectThrows(
+ IllegalArgumentException.class, () -> new BayesianScoreQuery(inner, 0.5f, 3.0f, -0.1f));
+ expectThrows(
+ IllegalArgumentException.class, () -> new BayesianScoreQuery(inner, 0.5f, 3.0f, 1.0f));
+ }
+
+ // ---- Auto-estimation tests ----
+
+ public void testEstimatorReturnsFiniteValues() throws Exception {
+ BayesianScoreEstimator.Parameters params = BayesianScoreEstimator.estimate(searcher, "body");
+
+ assertTrue(
+ "alpha should be positive and finite",
+ params.alpha() > 0 && Float.isFinite(params.alpha()));
+ assertTrue("beta should be finite", Float.isFinite(params.beta()));
+ assertTrue("baseRate in (0, 0.5]", params.baseRate() > 0 && params.baseRate() <= 0.5f);
+ }
+
+ public void testEstimatedParametersProduceValidScores() throws Exception {
+ BayesianScoreEstimator.Parameters params = BayesianScoreEstimator.estimate(searcher, "body");
+
+ Query inner = new TermQuery(new Term("body", "alpha"));
+ BayesianScoreQuery bsq =
+ new BayesianScoreQuery(inner, params.alpha(), params.beta(), params.baseRate());
+
+ ScoreDoc[] hits = searcher.search(bsq, 10).scoreDocs;
+ assertTrue("should have hits", hits.length > 0);
+ for (ScoreDoc hit : hits) {
+ assertTrue("score in (0,1): " + hit.score, hit.score > 0 && hit.score < 1);
+ }
+ }
+
+ public void testEstimatedMaxScoreCorrectness() throws Exception {
+ BayesianScoreEstimator.Parameters params = BayesianScoreEstimator.estimate(searcher, "body");
+
+ Query inner = new TermQuery(new Term("body", "alpha"));
+ BayesianScoreQuery bsq =
+ new BayesianScoreQuery(inner, params.alpha(), params.beta(), params.baseRate());
+ CheckHits.checkTopScores(random(), bsq, searcher);
+ }
+
+ public void testEstimatorReproducibleWithSeed() throws Exception {
+ BayesianScoreEstimator.Parameters p1 =
+ BayesianScoreEstimator.estimate(searcher, "body", 20, 3, 123);
+ BayesianScoreEstimator.Parameters p2 =
+ BayesianScoreEstimator.estimate(searcher, "body", 20, 3, 123);
+
+ assertEquals("same seed should produce same alpha", p1.alpha(), p2.alpha(), 0f);
+ assertEquals("same seed should produce same beta", p1.beta(), p2.beta(), 0f);
+ assertEquals("same seed should produce same baseRate", p1.baseRate(), p2.baseRate(), 0f);
+ }
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLogOddsFusionQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestLogOddsFusionQuery.java
index d3f1f9c653e2..e5d1b59eff90 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestLogOddsFusionQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestLogOddsFusionQuery.java
@@ -18,6 +18,7 @@
import java.util.Arrays;
import java.util.Collections;
+import java.util.List;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnFloatVectorField;
@@ -668,6 +669,210 @@ public void testVectorNot() throws Exception {
}
}
+ // ---- Weighted fusion tests ----
+
+ public void testWeightedFusion() throws Exception {
+ Query q1 = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query q2 = bayesian(new TermQuery(new Term("body", "beta")));
+
+ LogOddsFusionQuery loq =
+ new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f});
+ QueryUtils.check(random(), loq, searcher);
+
+ ScoreDoc[] hits = searcher.search(loq, 10).scoreDocs;
+ assertTrue("should have at least 1 hit", hits.length >= 1);
+ for (ScoreDoc hit : hits) {
+ assertTrue("score should be > 0", hit.score > 0);
+ assertTrue("score should be < 1", hit.score < 1);
+ }
+ }
+
+ public void testWeightedFusionAffectsRanking() throws Exception {
+ Query qAlpha = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query qBeta = bayesian(new TermQuery(new Term("body", "beta")));
+
+ // doc0: alpha beta gamma -> matches both
+ // doc1: alpha gamma delta -> matches alpha only
+ // doc2: beta gamma delta -> matches beta only
+
+ // Heavy weight on alpha: doc1 (alpha-only) should rank above doc2 (beta-only)
+ LogOddsFusionQuery alphaHeavy =
+ new LogOddsFusionQuery(Arrays.asList(qAlpha, qBeta), 0.5f, new float[] {0.9f, 0.1f});
+ ScoreDoc[] hitsAlpha = searcher.search(alphaHeavy, 10).scoreDocs;
+
+ // Heavy weight on beta: doc2 (beta-only) should rank above doc1 (alpha-only)
+ LogOddsFusionQuery betaHeavy =
+ new LogOddsFusionQuery(Arrays.asList(qAlpha, qBeta), 0.5f, new float[] {0.1f, 0.9f});
+ ScoreDoc[] hitsBeta = searcher.search(betaHeavy, 10).scoreDocs;
+
+ // Find scores of doc1 and doc2 in each case
+ float doc1AlphaHeavy = 0, doc2AlphaHeavy = 0;
+ float doc1BetaHeavy = 0, doc2BetaHeavy = 0;
+ for (ScoreDoc hit : hitsAlpha) {
+ if (hit.doc == 1) doc1AlphaHeavy = hit.score;
+ if (hit.doc == 2) doc2AlphaHeavy = hit.score;
+ }
+ for (ScoreDoc hit : hitsBeta) {
+ if (hit.doc == 1) doc1BetaHeavy = hit.score;
+ if (hit.doc == 2) doc2BetaHeavy = hit.score;
+ }
+
+ assertTrue("alpha-heavy: doc1 (alpha) > doc2 (beta)", doc1AlphaHeavy > doc2AlphaHeavy);
+ assertTrue("beta-heavy: doc2 (beta) > doc1 (alpha)", doc2BetaHeavy > doc1BetaHeavy);
+ }
+
+ public void testWeightedExplanation() throws Exception {
+ Query q1 = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query q2 = bayesian(new TermQuery(new Term("body", "beta")));
+ LogOddsFusionQuery loq =
+ new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.6f, 0.4f});
+
+ Weight w = searcher.createWeight(searcher.rewrite(loq), ScoreMode.COMPLETE, 1);
+ LeafReaderContext context = searcher.getIndexReader().leaves().get(0);
+
+ Explanation expl = w.explain(context, 0);
+ assertTrue("doc0 should match", expl.isMatch());
+ assertTrue("should mention weighted", expl.getDescription().contains("weighted"));
+ }
+
+ public void testWeightedEqualsAndHashCode() {
+ Query q1 = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query q2 = bayesian(new TermQuery(new Term("body", "beta")));
+
+ LogOddsFusionQuery a =
+ new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f});
+ LogOddsFusionQuery b =
+ new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f});
+ LogOddsFusionQuery c =
+ new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.3f, 0.7f});
+ LogOddsFusionQuery d = new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f);
+
+ assertEquals(a, b);
+ assertEquals(a.hashCode(), b.hashCode());
+ assertNotEquals(a, c);
+ assertNotEquals(a, d);
+ }
+
+ public void testWeightedMaxScoreCorrectness() throws Exception {
+ Query q1 = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query q2 = bayesian(new TermQuery(new Term("body", "beta")));
+ LogOddsFusionQuery loq =
+ new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f});
+ CheckHits.checkTopScores(random(), loq, searcher);
+ }
+
+ public void testWeightedToString() {
+ Query q1 = new TermQuery(new Term("body", "alpha"));
+ Query q2 = new TermQuery(new Term("body", "beta"));
+ LogOddsFusionQuery loq =
+ new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f});
+ String str = loq.toString("body");
+ assertTrue("should contain LogOdds", str.contains("LogOdds"));
+ assertTrue("should contain w=", str.contains("w="));
+ assertTrue("should contain 0.7", str.contains("0.7"));
+ }
+
+ public void testWeightedRewrite() throws Exception {
+ Query sub1 = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query sub2 = bayesian(new TermQuery(new Term("body", "beta")));
+ LogOddsFusionQuery loq =
+ new LogOddsFusionQuery(Arrays.asList(sub1, sub2), 0.5f, new float[] {0.6f, 0.4f});
+
+ Query rewritten = searcher.rewrite(loq);
+ assertTrue(
+ "weighted rewrite should produce LogOddsFusionQuery",
+ rewritten instanceof LogOddsFusionQuery);
+ LogOddsFusionQuery rewrittenLoq = (LogOddsFusionQuery) rewritten;
+ assertNotNull("weights should be preserved", rewrittenLoq.getWeights());
+ assertEquals(0.6f, rewrittenLoq.getWeights()[0], 1e-6f);
+ assertEquals(0.4f, rewrittenLoq.getWeights()[1], 1e-6f);
+ }
+
+ public void testIllegalWeights() {
+ Query q1 = new TermQuery(new Term("body", "alpha"));
+ Query q2 = new TermQuery(new Term("body", "beta"));
+ List<Query> clauses = Arrays.asList(q1, q2);
+
+ // Wrong length
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new LogOddsFusionQuery(clauses, 0.5f, new float[] {1.0f}));
+
+ // Negative weight
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new LogOddsFusionQuery(clauses, 0.5f, new float[] {-0.1f, 1.1f}));
+
+ // NaN weight
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new LogOddsFusionQuery(clauses, 0.5f, new float[] {Float.NaN, 0.5f}));
+
+ // Sum != 1.0
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new LogOddsFusionQuery(clauses, 0.5f, new float[] {0.3f, 0.3f}));
+ }
+
+ public void testWeightedQueryUtils() throws Exception {
+ Query q1 = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query q2 = bayesian(new TermQuery(new Term("body", "beta")));
+ LogOddsFusionQuery loq =
+ new LogOddsFusionQuery(Arrays.asList(q1, q2), 0.5f, new float[] {0.7f, 0.3f});
+ QueryUtils.check(random(), loq, searcher);
+ }
+
+ public void testWeightedThreeWayCombination() throws Exception {
+ Query q1 = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query q2 = bayesian(new TermQuery(new Term("body", "beta")));
+ Query q3 = bayesian(new TermQuery(new Term("body", "gamma")));
+ LogOddsFusionQuery loq =
+ new LogOddsFusionQuery(Arrays.asList(q1, q2, q3), 0.5f, new float[] {0.5f, 0.3f, 0.2f});
+ QueryUtils.check(random(), loq, searcher);
+
+ ScoreDoc[] hits = searcher.search(loq, 10).scoreDocs;
+ assertTrue("should have hits", hits.length > 0);
+ for (ScoreDoc hit : hits) {
+ assertTrue("score in (0,1)", hit.score > 0 && hit.score < 1);
+ }
+ }
+
+ // ---- Logit normalization tests ----
+
+ public void testNormalizedFusionProducesValidScores() throws Exception {
+ Query q1 = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query q2 = bayesian(new TermQuery(new Term("body", "beta")));
+
+ // logit bounds: signal 0 range [-3, 3], signal 1 range [-1, 1]
+ LogOddsFusionQuery loq =
+ new LogOddsFusionQuery(
+ Arrays.asList(q1, q2),
+ 0.5f,
+ new float[] {0.6f, 0.4f},
+ new float[] {-3f, -1f},
+ new float[] {3f, 1f});
+
+ ScoreDoc[] hits = searcher.search(loq, 10).scoreDocs;
+ assertTrue("should have hits", hits.length > 0);
+ for (ScoreDoc hit : hits) {
+ assertTrue("score in (0,1): " + hit.score, hit.score > 0 && hit.score < 1);
+ }
+ }
+
+ public void testNormalizedMaxScoreCorrectness() throws Exception {
+ Query q1 = bayesian(new TermQuery(new Term("body", "alpha")));
+ Query q2 = bayesian(new TermQuery(new Term("body", "beta")));
+
+ LogOddsFusionQuery loq =
+ new LogOddsFusionQuery(
+ Arrays.asList(q1, q2),
+ 0.5f,
+ new float[] {0.7f, 0.3f},
+ new float[] {-3f, -1f},
+ new float[] {3f, 1f});
+ CheckHits.checkTopScores(random(), loq, searcher);
+ }
+
/** L2-normalize a float vector. */
private static float[] normalize(float[] v) {
double norm = 0;