# One row of the JMH summary table: fully qualified benchmark name, zero or
# more parameter columns, then mode / sample count / score / "±" / error / unit.
# The mode token is matched literally so that it delimits the parameter columns.
_JMH_TEXT_ROW = re.compile(
    r'\S+\.(?P<method>\S+)\s+'
    r'(?P<params>(?:\S+\s+)*?)'
    r'(?P<mode>thrpt|avgt|sample|ss)\s+'
    r'(?P<cnt>\d+)\s+'
    r'(?P<score>\S+)\s+'
    r'\S+\s+'
    r'(?P<error>\S+)\s+'
    r'(?P<unit>\S+)'
)

# Parameter names appear parenthesised in the table header, e.g. "(size)".
_JMH_TEXT_PARAM_NAME = re.compile(r'\(([^()\s]+)\)')


def parse_jmh_text(text):
    """Parse the summary table of plain-text JMH output.

    Every entry uses the same keys as the entries built by the JSON parser
    ('method', 'size', 'dist', 'score', 'error', 'unit', 'raw'), because the
    HTML builder indexes entries by 'size' and 'dist'. The legacy 'param' key
    (the parameter column(s), space-joined) is kept as well.

    Parameter columns are mapped by the names found in the table header
    ("Benchmark (distribution) (size) Mode Cnt Score Error Units"). When no
    usable header was seen, a single parameter column is taken to be the size.
    'size' defaults to '' and 'dist' to 'random', mirroring the JSON parser.

    Args:
        text: Complete plain-text output of a JMH run.

    Returns:
        A tuple (entries, config). config is always an empty dict because the
        summary table carries no run configuration; 'raw' is always empty
        because the table carries no per-iteration samples.
    """
    entries = []
    param_names = []
    for line in text.splitlines():
        stripped = line.strip()
        if stripped.startswith('Benchmark') and stripped.endswith('Units'):
            param_names = _JMH_TEXT_PARAM_NAME.findall(stripped)
            continue

        m = _JMH_TEXT_ROW.match(line)
        if not m:
            continue
        try:
            score = float(m.group('score'))
            error = float(m.group('error'))
        except ValueError:
            # Looks like a table row but carries no numeric score; skip it
            # rather than abort the whole report.
            continue

        values = m.group('params').split()
        if values and len(values) == len(param_names):
            params = dict(zip(param_names, values))
        elif len(values) == 1:
            params = {'size': values[0]}
        else:
            params = {}

        entries.append({
            'method': m.group('method'),
            'param': ' '.join(values),
            'size': params.get('size', ''),
            'dist': params.get('distribution', 'random'),
            'score': score,
            'error': error,
            'unit': m.group('unit'),
            'raw': [],
        })
    return entries, {}
Returns (entries, config_dict).""" + entries = [] + config = {} + total_sec = 0 + for i, result in enumerate(data): + bench = result['benchmark'].rsplit('.', 1)[-1] + params = result.get('params', {}) + # Handle multiple params: ScoreDocSortBenchmark uses 'size' and 'distribution' + size = params.get('size', '') + dist = params.get('distribution', 'random') + pm = result['primaryMetric'] + raw = [] + for fork_data in pm.get('rawData', []): + raw.extend(fork_data) + entries.append({ + 'method': bench, + 'size': size, + 'dist': dist, + 'score': pm['score'], + 'error': pm['scoreError'], + 'unit': pm['scoreUnit'], + 'raw': raw, + }) + + # Estimate total time for this benchmark + forks = result.get('forks', 0) + wi = result.get('warmupIterations', 0) + wt = result.get('warmupTime', '0 s') + mi = result.get('measurementIterations', 0) + mt = result.get('measurementTime', '0 s') + + def to_sec(t_str): + try: + val, unit = t_str.split() + val = float(val) + if unit == 'ms': return val / 1000 + if unit == 's': return val + if unit == 'min': return val * 60 + return 0 + except: return 0 + + total_sec += forks * (wi * to_sec(wt) + mi * to_sec(mt)) + + if i == 0: + mode_map = {'avgt': 'Average Time', 'thrpt': 'Throughput', + 'sample': 'Sampling', 'ss': 'Single Shot'} + # split jvmArgs into harness args (module-path, module-main) + # vs benchmark args (user/annotation provided like -Xmx, -XX:) + all_jvm_args = result.get('jvmArgs', []) + harness_prefixes = ('--module-path', '-Djdk.module.main', '-Djmh.') + harness_args = [a for a in all_jvm_args + if any(a.startswith(p) for p in harness_prefixes)] + benchmark_args = [a for a in all_jvm_args + if not any(a.startswith(p) for p in harness_prefixes)] + config = { + 'mode': mode_map.get(result.get('mode', ''), result.get('mode', '?')), + 'forks': result.get('forks', '?'), + 'threads': result.get('threads', '?'), + 'warmupIterations': result.get('warmupIterations', '?'), + 'warmupTime': result.get('warmupTime', '?'), + 
# Crude Java method-header matcher: modifiers/whitespace, a return type, the
# method name (captured), an argument list without a nested ')', an optional
# throws clause, and the opening brace of the body.
_JAVA_METHOD_HEADER = re.compile(
    r'(?:public|private|protected|static|\s)+\s+[\w<>\[\]]+\s+(\w+)'
    r'\s*\([^)]*\)\s*(?:throws\s+[\w, \t]+)?\s*\{'
)


def _closing_brace(content, open_idx):
    """Return the index of the '}' that balances the '{' at open_idx, or -1.

    Braces are counted textually; braces inside string literals or comments
    are not recognised as such.
    """
    depth = 0
    for idx in range(open_idx, len(content)):
        ch = content[idx]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return idx
    return -1


def _include_leading_docs(content, method_start):
    """Move method_start back over annotation and comment lines above it.

    Blank lines are skipped; the walk stops at the first other line.
    """
    start = method_start
    for raw in reversed(content[:method_start].splitlines()):
        text = raw.strip()
        if not text:
            continue
        if not text.startswith(('@', '//', '*', '/*')):
            break
        start = content.rfind(raw, 0, start)
    return start


def _dedent(body):
    """Strip the smallest indentation shared by the non-blank lines of body."""
    lines = body.splitlines()
    indents = [len(l) - len(l.lstrip()) for l in lines if l.strip()]
    if not indents:
        return body
    cut = min(indents)
    return '\n'.join(l[cut:] if len(l) > cut else l for l in lines)


def extract_methods(source_path):
    """Extract @Benchmark method bodies and their runXXX helpers from a Java source file.

    Each method whose text (including the annotations and comments directly
    above it) contains '@Benchmark' is returned. If a method named
    'run' + <benchmark name with its first letter upper-cased> exists in the
    same file, its source is appended after a blank line.

    The file is read as UTF-8 with undecodable bytes replaced, so the result
    does not depend on the locale and a stray byte cannot abort the report.

    Args:
        source_path: Path of the Java file; a falsy value yields {}.

    Returns:
        dict of method_name -> source_code_string. Empty if the path is falsy
        or the file cannot be opened.
    """
    methods = {}
    if not source_path:
        return methods
    try:
        with open(source_path, 'r', encoding='utf-8', errors='replace') as f:
            content = f.read()
    except OSError:
        return methods

    # 1. Collect every method. Searching resumes after each method's closing
    #    brace, so statements inside a body are never mistaken for headers.
    all_methods = {}
    pos = 0
    while True:
        m = _JAVA_METHOD_HEADER.search(content, pos)
        if not m:
            break
        end_brace = _closing_brace(content, m.end() - 1)
        if end_brace == -1:
            # Unbalanced braces: skip this header and keep scanning.
            pos = m.end()
            continue
        start = _include_leading_docs(content, m.start())
        all_methods[m.group(1)] = _dedent(content[start:end_brace + 1])
        pos = end_brace + 1

    # 2. Keep the @Benchmark methods and attach their runXXX helper, if any.
    for name, body in all_methods.items():
        if '@Benchmark' not in body:
            continue
        helper = all_methods.get('run' + name[0].upper() + name[1:])
        methods[name] = body if helper is None else body + '\n\n' + helper

    return methods
sorted(seen_dists.keys()) + unit = entries[0]['unit'] + + # grid[dist][method][size] = entry + grid = {} + for e in entries: + grid.setdefault(e['dist'], {}).setdefault(e['method'], {})[e['size']] = e + + # Precalculate mins/maxs per (dist, size) for heatmap + # stats[dist][size] = {min, max} + stats = {} + for d in dists: + stats[d] = {} + for s in sizes: + scores = [grid[d][m][s]['score'] for m in methods if s in grid[d].get(m, {})] + if scores: + stats[d][s] = {'min': min(scores), 'max': max(scores)} + + h = html.escape + + # JSON data for JS + # data_js[dist][method][size] = {score, error, rel, color, spark} + data_js = {} + for d in dists: + data_js[d] = {} + for m in methods: + data_js[d][m] = {} + for s in sizes: + if s in grid[d].get(m, {}): + e = grid[d][m][s] + score = e['score'] + lo, hi = stats[d][s]['min'], stats[d][s]['max'] + if lo > 0 and hi > lo: + t = math.log(score / lo) / math.log(hi / lo) + else: + span = hi - lo + t = (score - lo) / span if span > 0 else 0 + r, g, b = lerp_color(t) + rel = score / lo if lo > 0 else 1.0 + data_js[d][m][s] = { + 'score': f"{score:.3f}", + 'error': f"{e['error']:.3f}", + 'rel': f"{rel:.2f}×", + 'color': f"rgb({r},{g},{b})", + 'spark': sparkline_svg(e['raw']) if e['raw'] else '', + 'raw': e['raw'] + } + + sources_js = {name: src for name, src in method_sources.items()} + + out = [] + out.append(f""" +JMH Results + + +

JMH Results

""") + + # Config banner + if config: + out.append('') + items = [ + ('Mode', str(config.get('mode', '?'))), + ('Forks', str(config.get('forks', '?'))), + ('Threads', str(config.get('threads', '?'))), + ('Warmup', f"{config.get('warmupIterations','?')} iter \u00d7 {config.get('warmupTime','?')}"), + ('Measurement', f"{config.get('measurementIterations','?')} iter \u00d7 {config.get('measurementTime','?')}"), + ] + if config.get('totalTime'): + items.append(('Total time (approx)', config.get('totalTime'))) + jvm_desc = ' '.join(s for s in [config.get('vmName', ''), config.get('vmVersion', '')] if s) + if config.get('jdkVersion'): jvm_desc = f"JDK {config.get('jdkVersion')}, {jvm_desc}" + if jvm_desc: items.append(('JVM', jvm_desc)) + if config.get('jmhVersion'): items.append(('JMH version', config.get('jmhVersion'))) + if config.get('benchmarkJvmArgs'): items.append(('Fork JVM args', ' '.join(config.get('benchmarkJvmArgs')))) + for label, val in items: + out.append(f'') + out.append('
{h(label)}{h(val)}
') + + click_hint = '' + if has_raw or has_source: + click_hint = ' Click a data cell to see' + parts = [] + if has_raw: parts.append('its iteration histogram') + if has_source: parts.append('the method source code') + click_hint += ' ' + ' and '.join(parts) + '.' + + out.append(f'

Click column headers to sort.{click_hint}

') + + out.append('
') + out.append('
') + out.append('
') + out.append('
') + + out.append('
') + out.append('') + out.append('') + for i, s in enumerate(sizes): + out.append(f'') + out.append('') + + for method in methods: + out.append('') + out.append(f'') + for s in sizes: + out.append(f'') + out.append('') + + out.append('
Algorithmsize={h(s)}
{h(unit)}
{h(method)}
') + out.append('
') + out.append('

') + out.append('
') + out.append('
#!/usr/bin/env bash
# Compiles (if needed) and runs JMH benchmarks, passing all arguments through.
#
# Usage:
#   ./lucene/benchmark-jmh/run-benchmark.sh ScoreDocSortBenchmark -rf json -rff results.json
#
# Results are written as JSON to results.json unless -rf / -rff are given.
# JAVA_HOME may be set in the environment to choose the JDK.
#
# This ensures you never accidentally run stale bytecode.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
ROOT_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)"
BENCHMARKS_DIR="$ROOT_DIR/lucene/benchmark-jmh/build/benchmarks"

: "${JAVA_HOME:=/usr/lib/jvm/java-25-openjdk}"
export JAVA_HOME
export JAVA25_HOME="$JAVA_HOME"
export RUNTIME_JAVA_HOME="$JAVA_HOME"

# Fail before the (slow) build if the JDK location is wrong.
if [[ ! -x "$JAVA_HOME/bin/java" ]]; then
  echo "error: no java executable at $JAVA_HOME/bin/java; set JAVA_HOME" >&2
  exit 1
fi

echo "=== Compiling benchmarks ===" >&2
JAVA_HOME="$JAVA_HOME" "$ROOT_DIR/gradlew" -p "$ROOT_DIR" :lucene:benchmark-jmh:assemble --quiet

# Supply the default result format / file only when the caller did not, so
# that neither option is passed to JMH twice.
want_rf=true
want_rff=true
for arg in "$@"; do
  case "$arg" in
    -rf) want_rf=false ;;
    -rff) want_rff=false ;;
  esac
done
result_args=()
if $want_rf; then result_args+=(-rf json); fi
if $want_rff; then result_args+=(-rff results.json); fi

echo "=== Running JMH ===" >&2
# The forked benchmark JVMs get the same module path as the harness, derived
# from this checkout rather than from a machine-specific absolute path.
# NOTE(review): -jvmArgs is a single whitespace-separated string, so a checkout
# path containing spaces presumably will not work here - confirm.
exec "$JAVA_HOME/bin/java" \
  --module-path "$BENCHMARKS_DIR" \
  --module org.apache.lucene.benchmark.jmh \
  -jvmArgs "--sun-misc-unsafe-memory-access=allow --module-path=$BENCHMARKS_DIR -Djdk.module.main=org.apache.lucene.benchmark.jmh" \
  ${result_args[@]+"${result_args[@]}"} \
  "$@"
+ */ +package org.apache.lucene.benchmark.jmh; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.IdentityHashMap; +import java.util.SplittableRandom; +import java.util.concurrent.TimeUnit; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.InPlaceMergeSorter; +import org.apache.lucene.util.IntroSorter; +import org.apache.lucene.util.LSBRadixSorter; +import org.apache.lucene.util.TimSorter; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * Benchmark comparing different sort implementations for sorting {@link ScoreDoc}[] by ascending + * doc ID. Simulates realistic ScoreDoc arrays with random doc IDs drawn from a 5M-doc index and + * random scores. + * + *

Running

+ * + * Use {@code run-benchmark.sh} which automatically recompiles if sources changed, then runs JMH: + * + *
{@code
+ * ./lucene/benchmark-jmh/run-benchmark.sh ScoreDocSortBenchmark \
+ *   -rf json -rff results.json
+ * }
+ * + *

Or build and run manually: + * + *

{@code
+ * ./gradlew :lucene:benchmark-jmh:assemble
+ * java --module-path lucene/benchmark-jmh/build/benchmarks \
+ *   --module org.apache.lucene.benchmark.jmh \
+ *   ScoreDocSortBenchmark \
+ *   -rf json -rff results.json
+ * }
+ * + *

Visualizing results

+ * + * The companion {@code jmh-table.py} script (in the same directory as this source file) converts + * JMH JSON output into an interactive HTML report: + * + *
{@code
+ * python3 lucene/benchmark-jmh/jmh-table.py \
+ *   lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/ScoreDocSortBenchmark.java \
+ *   < results.json > results.html
+ * }
+ * + *

The HTML report provides: + * + *

+ */ +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@State(Scope.Thread) +@Warmup(iterations = 5, time = 1) +@Measurement(iterations = 5, time = 1) +@Fork( + value = 10, + jvmArgsAppend = {"-Xmx1g", "-Xms1g", "-XX:+AlwaysPreTouch"}) +public class ScoreDocSortBenchmark { + + private static final Comparator BY_DOC_ASC = (a, b) -> Integer.compare(a.doc, b.doc); + + private static final int MAX_DOC = 5_000_000; + + @Param({"10", "50", "100", "500", "1000", "10000"}) + int size; + + // add "nearly_sorted", "reversed" to test other distributions + @Param({"random"}) + String distribution; + + /** Template array; copied before each invocation so every sort sees the same random order. */ + private ScoreDoc[] template; + + /** Working copy that each benchmark method sorts in place. */ + private ScoreDoc[] work; + + @Setup(Level.Trial) + public void setupTrial() { + SplittableRandom rng = new SplittableRandom(0xCAFEBABE); + template = new ScoreDoc[size]; + for (int i = 0; i < size; i++) { + int doc = rng.nextInt(MAX_DOC); + float score = (float) rng.nextDouble(0.0, 10.0); + template[i] = new ScoreDoc(doc, score); + } + + if (distribution.equals("nearly_sorted")) { + Arrays.sort(template, BY_DOC_ASC); + // swap ~5% of adjacent pairs to introduce mild disorder + int numSwaps = (int) (size * 0.05); + for (int i = 0; i < numSwaps; i++) { + int idx = rng.nextInt(size - 1); + ScoreDoc tmp = template[idx]; + template[idx] = template[idx + 1]; + template[idx + 1] = tmp; + } + } else if (distribution.equals("reversed")) { + Arrays.sort(template, BY_DOC_ASC); + for (int i = 0; i < size / 2; i++) { + ScoreDoc tmp = template[i]; + template[i] = template[size - 1 - i]; + template[size - 1 - i] = tmp; + } + } + + // verification - runs once up front per trial (per parameter set) + ScoreDoc[] reference = Arrays.copyOf(template, size); + Arrays.sort(reference, BY_DOC_ASC); + + safeVerify("jdkSortLambda", reference, runJdkSortLambda(Arrays.copyOf(template, 
size))); + safeVerify("jdkSortComparator", reference, runJdkSortComparator(Arrays.copyOf(template, size))); + safeVerify( + "arrayUtilIntroSort", reference, runArrayUtilIntroSort(Arrays.copyOf(template, size))); + safeVerify("arrayUtilTimSort", reference, runArrayUtilTimSort(Arrays.copyOf(template, size))); + safeVerify( + "introSorterAnonymous", reference, runIntroSorterAnonymous(Arrays.copyOf(template, size))); + safeVerify( + "timSorterAnonymous", reference, runTimSorterAnonymous(Arrays.copyOf(template, size))); + safeVerify( + "inPlaceMergeSorterAnonymous", + reference, + runInPlaceMergeSorterAnonymous(Arrays.copyOf(template, size))); + safeVerify("jdkParallelSort", reference, runJdkParallelSort(Arrays.copyOf(template, size))); + safeVerify( + "jdkSortPrimitiveExtractLong", + reference, + runJdkSortPrimitiveExtractLong(Arrays.copyOf(template, size))); + safeVerify( + "jdkSortPrimitiveExtractAdaptive", + reference, + runJdkSortPrimitiveExtractAdaptive(Arrays.copyOf(template, size))); + safeVerify( + "lsbRadixSortExtract", reference, runLsbRadixSortExtract(Arrays.copyOf(template, size))); + safeVerify("radixSort2Pass", reference, runRadixSort2Pass(Arrays.copyOf(template, size))); + } + + private void safeVerify(String name, ScoreDoc[] reference, ScoreDoc[] result) { + try { + verify(name, reference, result); + } catch (IllegalStateException e) { + System.err.println("WARNING: " + e.getMessage()); + } + } + + private void verify(String name, ScoreDoc[] reference, ScoreDoc[] result) { + if (result.length != reference.length) { + throw new IllegalStateException( + name + + " failed: length mismatch. expected " + + reference.length + + " but got " + + result.length); + } + for (int i = 0; i < result.length; i++) { + if (i > 0 && result[i].doc < result[i - 1].doc) { + throw new IllegalStateException( + name + + " failed: not sorted at index " + + i + + ". 
" + + result[i - 1].doc + + " > " + + result[i].doc); + } + // check if doc matches reference (handles duplicates correctly since both are doc-sorted) + if (result[i].doc != reference[i].doc) { + throw new IllegalStateException( + name + + " failed: doc mismatch at index " + + i + + ". expected " + + reference[i].doc + + " but got " + + result[i].doc); + } + } + // integrity check: ensure we didn't lose or duplicate objects + IdentityHashMap counts = new IdentityHashMap<>(); + for (ScoreDoc sd : template) { + counts.merge(sd, 1, Integer::sum); + } + for (ScoreDoc sd : result) { + Integer c = counts.get(sd); + if (c == null) { + throw new IllegalStateException( + name + " failed: result contains unknown ScoreDoc instance"); + } + if (c == 1) { + counts.remove(sd); + } else { + counts.put(sd, c - 1); + } + } + if (counts.isEmpty() == false) { + throw new IllegalStateException(name + " failed: result missing ScoreDoc instances"); + } + } + + /** + * setupInvocation performs a shallow copy of the template. + * + *

Note: using Level.Invocation introduces overhead that JMH cannot easily subtract. For very + * small sizes (e.g. size=10), this overhead might be comparable to the benchmarked sort itself. + * We accept this because each invocation must start with the same unsorted array to ensure + * reproducibility across different sorting algorithms. + */ + @Setup(Level.Invocation) + public void setupInvocation() { + work = new ScoreDoc[size]; + System.arraycopy(template, 0, work, 0, size); + } + + // ---- 1. JDK Arrays.sort with lambda ---- + + private ScoreDoc[] runJdkSortLambda(ScoreDoc[] work) { + Arrays.sort(work, (a, b) -> Integer.compare(a.doc, b.doc)); + return work; + } + + @Benchmark + public void jdkSortLambda(Blackhole bh) { + // intentionally inline — tests whether JIT handles inline lambda differently than static + // comparator + bh.consume(runJdkSortLambda(work)); + } + + // ---- 2. JDK Arrays.sort with static comparator ---- + + private ScoreDoc[] runJdkSortComparator(ScoreDoc[] work) { + Arrays.sort(work, BY_DOC_ASC); + return work; + } + + @Benchmark + public void jdkSortComparator(Blackhole bh) { + bh.consume(runJdkSortComparator(work)); + } + + // ---- 3. ArrayUtil.introSort (wraps ArrayIntroSorter) ---- + + private ScoreDoc[] runArrayUtilIntroSort(ScoreDoc[] work) { + ArrayUtil.introSort(work, BY_DOC_ASC); + return work; + } + + @Benchmark + public void arrayUtilIntroSort(Blackhole bh) { + bh.consume(runArrayUtilIntroSort(work)); + } + + // ---- 4. ArrayUtil.timSort (wraps ArrayTimSorter) ---- + + private ScoreDoc[] runArrayUtilTimSort(ScoreDoc[] work) { + ArrayUtil.timSort(work, BY_DOC_ASC); + return work; + } + + @Benchmark + public void arrayUtilTimSort(Blackhole bh) { + bh.consume(runArrayUtilTimSort(work)); + } + + // ---- 5. 
Anonymous IntroSorter ---- + + private ScoreDoc[] runIntroSorterAnonymous(ScoreDoc[] work) { + final ScoreDoc[] arr = work; + new IntroSorter() { + ScoreDoc pivot; + + @Override + protected void swap(int i, int j) { + ScoreDoc tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + } + + @Override + protected void setPivot(int i) { + pivot = arr[i]; + } + + @Override + protected int comparePivot(int j) { + return Integer.compare(pivot.doc, arr[j].doc); + } + + @Override + protected int compare(int i, int j) { + return Integer.compare(arr[i].doc, arr[j].doc); + } + }.sort(0, arr.length); + return arr; + } + + @Benchmark + public void introSorterAnonymous(Blackhole bh) { + bh.consume(runIntroSorterAnonymous(work)); + } + + // ---- 6. Anonymous TimSorter ---- + + private ScoreDoc[] runTimSorterAnonymous(ScoreDoc[] work) { + final ScoreDoc[] arr = work; + final int len = arr.length; + new TimSorter(len / 2) { + ScoreDoc[] tmp = new ScoreDoc[len / 2]; + + @Override + protected void swap(int i, int j) { + ScoreDoc t = arr[i]; + arr[i] = arr[j]; + arr[j] = t; + } + + @Override + protected int compare(int i, int j) { + return Integer.compare(arr[i].doc, arr[j].doc); + } + + @Override + protected void copy(int src, int dest) { + arr[dest] = arr[src]; + } + + @Override + protected void save(int start, int l) { + System.arraycopy(arr, start, tmp, 0, l); + } + + @Override + protected void restore(int src, int dest) { + arr[dest] = tmp[src]; + } + + @Override + protected int compareSaved(int i, int j) { + return Integer.compare(tmp[i].doc, arr[j].doc); + } + }.sort(0, len); + return arr; + } + + @Benchmark + public void timSorterAnonymous(Blackhole bh) { + bh.consume(runTimSorterAnonymous(work)); + } + + // ---- 7. 
Anonymous InPlaceMergeSorter ---- + + private ScoreDoc[] runInPlaceMergeSorterAnonymous(ScoreDoc[] work) { + final ScoreDoc[] arr = work; + new InPlaceMergeSorter() { + @Override + protected void swap(int i, int j) { + ScoreDoc tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + } + + @Override + protected int compare(int i, int j) { + return Integer.compare(arr[i].doc, arr[j].doc); + } + }.sort(0, arr.length); + return arr; + } + + @Benchmark + public void inPlaceMergeSorterAnonymous(Blackhole bh) { + bh.consume(runInPlaceMergeSorterAnonymous(work)); + } + + // ---- 8. JDK Arrays.parallelSort with static comparator ---- + + private ScoreDoc[] runJdkParallelSort(ScoreDoc[] work) { + Arrays.parallelSort(work, BY_DOC_ASC); + return work; + } + + @Benchmark + public void jdkParallelSort(Blackhole bh) { + bh.consume(runJdkParallelSort(work)); + } + + // ---- 9. Extract doc IDs, sort with JDK Arrays.sort (primitive long[]), reorder ---- + + private ScoreDoc[] runJdkSortPrimitiveExtractLong(ScoreDoc[] work) { + int len = work.length; + // pack (doc, originalIndex) into a long: doc in upper 32, index in lower 32 + long[] packed = new long[len]; + for (int i = 0; i < len; i++) { + packed[i] = ((long) work[i].doc << 32) | (i & 0xFFFFFFFFL); + } + Arrays.sort(packed); + ScoreDoc[] sorted = new ScoreDoc[len]; + for (int i = 0; i < len; i++) { + sorted[i] = work[(int) packed[i]]; + } + return sorted; + } + + @Benchmark + public void jdkSortPrimitiveExtractLong(Blackhole bh) { + bh.consume(runJdkSortPrimitiveExtractLong(work)); + } + + // ---- 10. 
Extract doc IDs, sort with int[] when bits fit, else long[] ---- + + // bits needed to represent values in [0, max) + private static int bitsNeeded(int max) { + return 32 - Integer.numberOfLeadingZeros(max - 1); + } + + private ScoreDoc[] runJdkSortPrimitiveExtractAdaptive(ScoreDoc[] work) { + int len = work.length; + int docBits = bitsNeeded(MAX_DOC); + int indexBits = bitsNeeded(len); + if (docBits + indexBits <= 31) { + // pack into int[]: doc in upper bits, index in lower bits + // <= 31 (not 32) because Arrays.sort uses signed comparison, + // so bit 31 must stay clear to avoid sign-bit corruption + int[] packed = new int[len]; + for (int i = 0; i < len; i++) { + packed[i] = (work[i].doc << indexBits) | i; + } + Arrays.sort(packed); + int indexMask = (1 << indexBits) - 1; + ScoreDoc[] sorted = new ScoreDoc[len]; + for (int i = 0; i < len; i++) { + sorted[i] = work[packed[i] & indexMask]; + } + return sorted; + } else { + // fall back to long[] + long[] packed = new long[len]; + for (int i = 0; i < len; i++) { + packed[i] = ((long) work[i].doc << 32) | (i & 0xFFFFFFFFL); + } + Arrays.sort(packed); + ScoreDoc[] sorted = new ScoreDoc[len]; + for (int i = 0; i < len; i++) { + sorted[i] = work[(int) packed[i]]; + } + return sorted; + } + } + + @Benchmark + public void jdkSortPrimitiveExtractAdaptive(Blackhole bh) { + /** + * Documentation of int vs long paths given MAX_DOC = 5,000,000: + * + *

    + *
  • sizes 10, 50, 100 take the int[] path (23 + 7 = 30 <= 31 bits) + *
  • sizes 500, 1,000, 10,000 take the long[] path (23 + 9 = 32 > 31 bits) + *
+ */ + bh.consume(runJdkSortPrimitiveExtractAdaptive(work)); + } + + // ---- 11. Extract doc IDs, sort with LSBRadixSorter when bits fit, else JDK long[] ---- + + private ScoreDoc[] runLsbRadixSortExtract(ScoreDoc[] work) { + int len = work.length; + int docBits = bitsNeeded(MAX_DOC); + int indexBits = bitsNeeded(len); + if (docBits + indexBits <= 32) { + int[] packed = new int[len]; + for (int i = 0; i < len; i++) { + packed[i] = (work[i].doc << indexBits) | i; + } + new LSBRadixSorter().sort(docBits + indexBits, packed, len); + int indexMask = (1 << indexBits) - 1; + ScoreDoc[] sorted = new ScoreDoc[len]; + for (int i = 0; i < len; i++) { + sorted[i] = work[packed[i] & indexMask]; + } + return sorted; + } else { + // fallback to long[] + Arrays.sort + long[] packed = new long[len]; + for (int i = 0; i < len; i++) { + packed[i] = ((long) work[i].doc << 32) | (i & 0xFFFFFFFFL); + } + Arrays.sort(packed); + ScoreDoc[] sorted = new ScoreDoc[len]; + for (int i = 0; i < len; i++) { + sorted[i] = work[(int) packed[i]]; + } + return sorted; + } + } + + @Benchmark + public void lsbRadixSortExtract(Blackhole bh) { + bh.consume(runLsbRadixSortExtract(work)); + } + + // ---- 12. Extract doc IDs, manual 4-pass radix sort (8-bit) ---- + + private ScoreDoc[] runRadixSort2Pass(ScoreDoc[] work) { + int len = work.length; + int docBits = bitsNeeded(MAX_DOC); + int indexBits = bitsNeeded(len); + if (docBits + indexBits <= 32) { + int[] packed = new int[len]; + for (int i = 0; i < len; i++) { + packed[i] = (work[i].doc << indexBits) | i; + } + + int totalBits = docBits + indexBits; + int[] bucket = new int[256]; + int[] workArray = new int[len]; + + // up to 4 passes over 8-bit radix, skip unnecessary high passes + int passes = (totalBits + 7) >>> 3; // ceil(totalBits / 8) + for (int pass = 0; pass < passes; pass++) { + int shift = pass * 8; + int[] src = (pass % 2 == 0) ? packed : workArray; + int[] dst = (pass % 2 == 0) ? 
workArray : packed; + + // histogram + for (int i = 0; i < len; i++) { + bucket[(src[i] >>> shift) & 0xFF]++; + } + // prefix sum + for (int i = 1; i < 256; i++) { + bucket[i] += bucket[i - 1]; + } + // scatter + for (int i = len - 1; i >= 0; i--) { + dst[--bucket[(src[i] >>> shift) & 0xFF]] = src[i]; + } + + Arrays.fill(bucket, 0); + } + + // if odd number of passes, result is in workArray + int[] sorted_packed = (passes % 2 == 0) ? packed : workArray; + + int indexMask = (1 << indexBits) - 1; + ScoreDoc[] sorted = new ScoreDoc[len]; + for (int i = 0; i < len; i++) { + sorted[i] = work[sorted_packed[i] & indexMask]; + } + return sorted; + } else { + // long fallback + long[] packed = new long[len]; + for (int i = 0; i < len; i++) { + packed[i] = ((long) work[i].doc << 32) | (i & 0xFFFFFFFFL); + } + Arrays.sort(packed); + ScoreDoc[] sorted = new ScoreDoc[len]; + for (int i = 0; i < len; i++) { + sorted[i] = work[(int) packed[i]]; + } + return sorted; + } + } + + @Benchmark + public void radixSort2Pass(Blackhole bh) { + bh.consume(runRadixSort2Pass(work)); + } +}