Skip to content

Commit 09187f0

Browse files
feat: add prompt relevance scoring — gate 2 for candidate ranking
New module stylometry/relevance.rs measures content-word overlap between prompt and generated text, with fuzzy prefix matching for morphological variants (dolphins→dolphin). Wired into: - Candidate ranker: combined score = style_distance + (1-relevance)*0.3 so off-topic candidates are penalized even if well-styled - Eval harness: prompt_relevance tracked in EvalRecord, CSV, and summary 5 unit tests for relevance scoring.
1 parent dfa14bd commit 09187f0

5 files changed

Lines changed: 247 additions & 10 deletions

File tree

src/commands/eval_style.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use writer_cli::stylometry::features::lengths;
2020
use writer_cli::stylometry::features::punctuation::PunctuationStats;
2121
use writer_cli::stylometry::features::readability::ReadabilityStats;
2222
use writer_cli::stylometry::fingerprint::StylometricFingerprint;
23+
use writer_cli::stylometry::relevance;
2324

2425
use crate::config;
2526
use crate::error::AppError;
@@ -59,6 +60,7 @@ struct EvalRecord {
5960
questions_per_1k: f64,
6061
exclamations_per_1k: f64,
6162
canon_leakage_score: f64,
63+
prompt_relevance: f64,
6264
// Generation config
6365
system_prompt_enabled: bool,
6466
prompt_wrapping_enabled: bool,
@@ -82,6 +84,7 @@ struct EvalSummary {
8284
mean_questions_per_1k: f64,
8385
mean_exclamations_per_1k: f64,
8486
mean_canon_leakage: f64,
87+
mean_prompt_relevance: f64,
8588
raw_mode: bool,
8689
adapter_used: bool,
8790
model: String,
@@ -245,6 +248,7 @@ pub async fn run(
245248
let punct = PunctuationStats::compute(&text);
246249
let read = ReadabilityStats::compute(&text);
247250
let canon_leakage = compute_canon_leakage(&text, &prompt_entry.text, &leakage_lexicon);
251+
let prompt_rel = relevance::score(&prompt_entry.text, &text);
248252

249253
let record = EvalRecord {
250254
prompt: prompt_entry.text.clone(),
@@ -258,6 +262,7 @@ pub async fn run(
258262
questions_per_1k: punct.questions_per_1k,
259263
exclamations_per_1k: punct.exclamations_per_1k,
260264
canon_leakage_score: canon_leakage,
265+
prompt_relevance: prompt_rel,
261266
system_prompt_enabled: system.is_some(),
262267
prompt_wrapping_enabled: !raw,
263268
raw_mode: raw,
@@ -329,6 +334,10 @@ pub async fn run(
329334
summary.mean_questions_per_1k, summary.mean_exclamations_per_1k
330335
);
331336
println!(" mean canon leakage: {:.3}", summary.mean_canon_leakage);
337+
println!(
338+
" mean prompt relevance: {:.3}",
339+
summary.mean_prompt_relevance
340+
);
332341
println!("\n results: {}", output_dir.display().to_string().dimmed());
333342
} else {
334343
crate::output::print_success_or(ctx, &summary, |_| {});
@@ -425,13 +434,13 @@ fn contains_whole_word(needle: &str, haystack: &str) -> bool {
425434

426435
fn write_csv(path: &Path, records: &[EvalRecord]) -> Result<(), AppError> {
427436
let mut out = String::new();
428-
out.push_str("prompt,category,seed,style_distance,sentence_length_mean,sentence_length_sd,fk_grade,questions_per_1k,exclamations_per_1k,canon_leakage_score,system_prompt,prompt_wrapping,raw_mode,adapter,n_candidates,model\n");
437+
out.push_str("prompt,category,seed,style_distance,sentence_length_mean,sentence_length_sd,fk_grade,questions_per_1k,exclamations_per_1k,canon_leakage_score,prompt_relevance,system_prompt,prompt_wrapping,raw_mode,adapter,n_candidates,model\n");
429438

430439
for r in records {
431440
// CSV-escape the prompt
432441
let prompt_escaped = r.prompt.replace('"', "\"\"");
433442
out.push_str(&format!(
434-
"\"{}\",\"{}\",{},{:.4},{:.2},{:.2},{:.2},{:.2},{:.2},{:.4},{},{},{},{},{},{}\n",
443+
"\"{}\",\"{}\",{},{:.4},{:.2},{:.2},{:.2},{:.2},{:.2},{:.4},{:.4},{},{},{},{},{},{}\n",
435444
prompt_escaped,
436445
r.category,
437446
r.seed,
@@ -442,6 +451,7 @@ fn write_csv(path: &Path, records: &[EvalRecord]) -> Result<(), AppError> {
442451
r.questions_per_1k,
443452
r.exclamations_per_1k,
444453
r.canon_leakage_score,
454+
r.prompt_relevance,
445455
r.system_prompt_enabled,
446456
r.prompt_wrapping_enabled,
447457
r.raw_mode,
@@ -477,6 +487,7 @@ fn compute_summary(
477487
mean_questions_per_1k: 0.0,
478488
mean_exclamations_per_1k: 0.0,
479489
mean_canon_leakage: 0.0,
490+
mean_prompt_relevance: 0.0,
480491
raw_mode: raw,
481492
adapter_used: adapter,
482493
model: model_id.to_string(),
@@ -507,6 +518,7 @@ fn compute_summary(
507518
mean_questions_per_1k: records.iter().map(|r| r.questions_per_1k).sum::<f64>() / n,
508519
mean_exclamations_per_1k: records.iter().map(|r| r.exclamations_per_1k).sum::<f64>() / n,
509520
mean_canon_leakage: records.iter().map(|r| r.canon_leakage_score).sum::<f64>() / n,
521+
mean_prompt_relevance: records.iter().map(|r| r.prompt_relevance).sum::<f64>() / n,
510522
raw_mode: raw,
511523
adapter_used: adapter,
512524
model: model_id.to_string(),

src/decoding/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ pub async fn run(
153153
return Err(DecodingError::Backend(err_detail));
154154
}
155155

156-
// Rank candidates by stylometric distance
157-
let ranked = ranker::rank(&candidates, fingerprint);
156+
// Rank candidates by style distance + prompt relevance
157+
let ranked = ranker::rank(&candidates, fingerprint, prompt);
158158

159159
// Filter best candidate
160160
let (best_vec_idx, best_distance) = ranked[0];

src/decoding/ranker.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,36 @@
1-
//! Rank generated candidates by stylometric distance to the user's fingerprint.
1+
//! Rank generated candidates by combined style fidelity and prompt relevance.
22
//!
33
//! Reference: PAN authorship verification — cosine/distance-based ranking.
44
use crate::stylometry::fingerprint::StylometricFingerprint;
5-
use crate::stylometry::scoring;
5+
use crate::stylometry::{relevance, scoring};
66

7-
/// Rank candidates by stylometric distance to the fingerprint.
8-
/// Returns vec of (candidate_index, distance), sorted lowest distance first.
7+
/// Rank candidates by combined score: style distance penalized by low relevance.
8+
///
9+
/// Scoring: `combined = style_distance + relevance_penalty`
10+
/// where `relevance_penalty = (1.0 - relevance) * 0.3`
11+
///
12+
/// This means: a perfectly relevant but stylistically distant candidate (0.6 + 0.0)
13+
/// beats an off-topic but well-styled candidate (0.3 + 0.3).
14+
///
15+
/// Returns vec of (candidate_index, combined_score), sorted lowest first.
916
pub fn rank(
1017
candidates: &[(String, u32, u64)],
1118
fingerprint: &StylometricFingerprint,
19+
prompt: &str,
1220
) -> Vec<(usize, f64)> {
1321
let mut scored: Vec<(usize, f64)> = candidates
1422
.iter()
1523
.enumerate()
1624
.map(|(i, (text, _, _))| {
1725
let report = scoring::distance(text, fingerprint);
18-
(i, report.overall)
26+
let rel = relevance::score(prompt, text);
27+
// Penalty: up to 0.3 for completely irrelevant output
28+
let relevance_penalty = (1.0 - rel) * 0.3;
29+
let combined = (report.overall + relevance_penalty).clamp(0.0, 1.0);
30+
(i, combined)
1931
})
2032
.collect();
2133

22-
// Sort by distance ascending (closest to user's voice first)
2334
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
2435
scored
2536
}

src/stylometry/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
pub mod ai_slop;
33
pub mod features;
44
pub mod fingerprint;
5+
pub mod relevance;
56
pub mod scoring;

src/stylometry/relevance.rs

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
//! Prompt relevance scoring.
2+
//!
3+
//! Measures whether generated text addresses the prompt's topic.
4+
//! Without this, the ranker could prefer a beautifully-styled off-topic
5+
//! response over a relevant one.
6+
//!
7+
//! Approach: content-word overlap — extract meaningful words from the prompt
8+
//! (excluding stop words), then measure what fraction appear in the output.
9+
//! Simple, fast, and sufficient for a generation-time gate.
10+
11+
use unicode_segmentation::UnicodeSegmentation;
12+
13+
/// Common English stop words — excluded from content-word extraction.
14+
/// Kept minimal to avoid false negatives on short prompts.
15+
const STOP_WORDS: &[&str] = &[
16+
"a",
17+
"an",
18+
"the",
19+
"and",
20+
"or",
21+
"but",
22+
"in",
23+
"on",
24+
"at",
25+
"to",
26+
"for",
27+
"of",
28+
"with",
29+
"by",
30+
"from",
31+
"is",
32+
"are",
33+
"was",
34+
"were",
35+
"be",
36+
"been",
37+
"being",
38+
"have",
39+
"has",
40+
"had",
41+
"do",
42+
"does",
43+
"did",
44+
"will",
45+
"would",
46+
"could",
47+
"should",
48+
"may",
49+
"might",
50+
"shall",
51+
"can",
52+
"this",
53+
"that",
54+
"these",
55+
"those",
56+
"it",
57+
"its",
58+
"i",
59+
"you",
60+
"he",
61+
"she",
62+
"we",
63+
"they",
64+
"me",
65+
"him",
66+
"her",
67+
"us",
68+
"them",
69+
"my",
70+
"your",
71+
"his",
72+
"our",
73+
"their",
74+
"what",
75+
"which",
76+
"who",
77+
"whom",
78+
"how",
79+
"when",
80+
"where",
81+
"why",
82+
"not",
83+
"no",
84+
"so",
85+
"if",
86+
"about",
87+
"up",
88+
"out",
89+
"just",
90+
"than",
91+
"then",
92+
"also",
93+
"very",
94+
"some",
95+
"any",
96+
"all",
97+
"each",
98+
"every",
99+
"into",
100+
"as",
101+
"write",
102+
"writing",
103+
"paragraph",
104+
"essay",
105+
"piece",
106+
"about",
107+
"describe",
108+
"explain",
109+
"tell",
110+
];
111+
112+
/// Extract content words from text — lowercase, alphabetic, non-stop-word.
113+
fn content_words(text: &str) -> Vec<String> {
114+
text.unicode_words()
115+
.map(|w| w.to_lowercase())
116+
.filter(|w| w.len() >= 3 && w.chars().all(|c| c.is_alphabetic()))
117+
.filter(|w| !STOP_WORDS.contains(&w.as_str()))
118+
.collect()
119+
}
120+
121+
/// Compute prompt relevance as the fraction of prompt content words
122+
/// that appear at least once in the output.
123+
///
124+
/// Returns a score in [0.0, 1.0] where:
125+
/// - 1.0 = every prompt content word appears in the output
126+
/// - 0.0 = no prompt content words appear in the output
127+
///
128+
/// If the prompt has no content words (e.g., "Write something"), returns 1.0
129+
/// to avoid penalizing vague prompts.
130+
pub fn score(prompt: &str, output: &str) -> f64 {
131+
let prompt_words = content_words(prompt);
132+
if prompt_words.is_empty() {
133+
return 1.0;
134+
}
135+
136+
// Deduplicate prompt words
137+
let unique_prompt: std::collections::HashSet<&str> =
138+
prompt_words.iter().map(|s| s.as_str()).collect();
139+
140+
let _output_lower = output.to_lowercase();
141+
let output_words: std::collections::HashSet<String> =
142+
output.unicode_words().map(|w| w.to_lowercase()).collect();
143+
144+
let mut found = 0;
145+
for word in &unique_prompt {
146+
if output_words.contains(*word) {
147+
found += 1;
148+
} else {
149+
// Fuzzy: check shared prefix >= 4 chars for morphological variants.
150+
// "dolphins" matches "dolphin", "swimming" matches "swims", etc.
151+
let min_prefix = word.len().min(4);
152+
let prefix = &word[..min_prefix];
153+
if output_words
154+
.iter()
155+
.any(|ow| ow.starts_with(prefix) && common_prefix_len(word, ow) >= min_prefix)
156+
{
157+
found += 1;
158+
}
159+
}
160+
}
161+
162+
found as f64 / unique_prompt.len() as f64
163+
}
164+
165+
/// Length of the common prefix between two strings.
166+
fn common_prefix_len(a: &str, b: &str) -> usize {
167+
a.chars().zip(b.chars()).take_while(|(x, y)| x == y).count()
168+
}
169+
170+
#[cfg(test)]
171+
mod tests {
172+
use super::*;
173+
174+
#[test]
175+
fn perfect_relevance() {
176+
let prompt = "Write about dolphins and ocean life";
177+
let output =
178+
"Dolphins are fascinating creatures of the ocean. Their life underwater is complex.";
179+
let s = score(prompt, output);
180+
assert!(s >= 0.5, "score {s} should be at least 0.5");
181+
}
182+
183+
#[test]
184+
fn zero_relevance() {
185+
let prompt = "Write about quantum physics and black holes";
186+
let output = "The garden was beautiful with roses and tulips blooming everywhere.";
187+
let s = score(prompt, output);
188+
assert!(s < 0.3, "score {s} should be low for off-topic output");
189+
}
190+
191+
#[test]
192+
fn vague_prompt_returns_one() {
193+
// All words are stop words or too short → no content words → 1.0
194+
let prompt = "Write about it for me";
195+
let output = "Anything at all.";
196+
assert_eq!(score(prompt, output), 1.0);
197+
}
198+
199+
#[test]
200+
fn empty_output() {
201+
let prompt = "Write about dolphins";
202+
assert_eq!(score(prompt, ""), 0.0);
203+
}
204+
205+
#[test]
206+
fn morphological_variant() {
207+
let prompt = "Write about dolphins swimming";
208+
let output = "A dolphin swims gracefully through the water.";
209+
let s = score(prompt, output);
210+
// "dolphins" should partially match "dolphin" via substring
211+
assert!(s > 0.0, "score {s} should be > 0 for morphological matches");
212+
}
213+
}

0 commit comments

Comments
 (0)