Skip to content

Commit ec5bac5

Browse files
committed
various fixes to VSA, disable graph pipeline for now
1 parent 0b3342d commit ec5bac5

8 files changed

Lines changed: 72 additions & 51 deletions

File tree

lib/pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ dependencies = [
2424
"pystemmer",
2525
"lxml",
2626
"namedlist",
27-
"sentence-transformers~=5.1",
27+
"sentence-transformers~=5.3",
28+
"transformers~=4.51.0",
2829
"lz4",
2930
"orjson",
3031
"text_preprocessing @ git+https://github.com/ARTFL-Project/text-preprocessing@v1.1.1.3#egg=text_preprocessing",
@@ -39,7 +40,6 @@ dependencies = [
3940
"ahocorasick-rs",
4041
"msgspec",
4142
"faiss-cpu",
42-
"spacy-transformers",
4343
"networkx~=3.5",
4444
"torch-geometric~=2.7.0",
4545
"umap-learn~=0.5.9",
@@ -54,7 +54,7 @@ Documentation = "https://github.com/ARTFL-Project/text-pair#readme"
5454
"Bug Tracker" = "https://github.com/ARTFL-Project/text-pair/issues"
5555

5656
[project.scripts]
57-
textpair = "textpair.__main__:main"
57+
textpair = "textpair.__main__:run"
5858

5959
[project.optional-dependencies]
6060
cpu = [
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[console_scripts]
2-
textpair = textpair.__main__:main
2+
textpair = textpair.__main__:run

lib/textpair/__main__.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import configparser
55
import os
6+
import shutil
67
import subprocess
78
import sys
89

@@ -302,10 +303,10 @@ async def run_alignment(params):
302303
groups_file = merge_alignments(results_file, count)
303304

304305
if params.web_app_config["skip_web_app"] is False:
305-
# Build graph model and generate cluster labels
306-
print(f"\n### Building Thematic Identity Graph model ###")
307-
embedding_model = params.preprocessing_params["source"]["embedding_model"]
308-
build_graph_and_labels(results_file, embedding_model, params.llm_params)
306+
# Graph pipeline disabled for now
307+
# print(f"\n### Building Thematic Identity Graph model ###")
308+
# embedding_model = params.preprocessing_params["source"]["embedding_model"]
309+
# build_graph_and_labels(results_file, embedding_model, params.llm_params)
309310

310311
create_web_app(
311312
results_file,
@@ -383,9 +384,9 @@ async def run_vsa_similarity(params) -> None:
383384
output_file = os.path.join(params.output_path, "results/alignments.jsonl.lz4")
384385
count = get_count(os.path.join(params.output_path, "results/counts.txt"))
385386

386-
# Build graph model and generate cluster labels
387-
embedding_model = params.preprocessing_params["source"]["embedding_model"]
388-
build_graph_and_labels(output_file, embedding_model, params.llm_params)
387+
# Graph pipeline disabled for now
388+
# embedding_model = params.preprocessing_params["source"]["embedding_model"]
389+
# build_graph_and_labels(output_file, embedding_model, params.llm_params)
389390

390391
create_web_app(
391392
output_file,
@@ -407,6 +408,13 @@ async def run_vsa_similarity(params) -> None:
407408
async def main():
408409
"""Main entry point for the textpair CLI."""
409410
params = get_config()
411+
412+
# Save a copy of the config file to the output directory for reproducibility
413+
config_file = params.config
414+
if config_file and os.path.exists(config_file):
415+
os.makedirs(params.output_path, exist_ok=True)
416+
shutil.copy2(config_file, os.path.join(params.output_path, f"{params.dbname}_config.ini"))
417+
410418
if params.delete is True:
411419
delete_database(params.dbname)
412420
elif params.update_db is True:
@@ -462,7 +470,12 @@ async def main():
462470
await run_vsa_similarity(params)
463471

464472

465-
if __name__ == "__main__":
473+
def run():
474+
"""Sync entry point for console_scripts."""
466475
import asyncio
467476

468477
asyncio.run(main())
478+
479+
480+
if __name__ == "__main__":
481+
run()

lib/textpair/vector_space_alignment/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def generate_merged_groups():
240240
return count
241241

242242
# Perform iterative merging with streaming
243-
while last_count / current_count <= 1.0:
243+
while last_count / current_count < 1.0:
244244
last_count = current_count
245245

246246
# Alternate between temp databases

lib/textpair/vector_space_alignment/corpus.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def sim_function(x, y):
453453
)
454454

455455
if model is None:
456-
self.model = SentenceTransformer(model_name, trust_remote_code=False)
456+
self.model = SentenceTransformer(model_name)
457457
else:
458458
self.model = model
459459

@@ -484,7 +484,7 @@ def create_embeddings(self, text_chunks) -> torch.Tensor:
484484
tensor = self.model.encode(
485485
list(text_chunks),
486486
convert_to_tensor=True,
487-
batch_size=512,
487+
batch_size=32,
488488
show_progress_bar=False,
489489
normalize_embeddings=True,
490490
)

lib/textpair/vector_space_alignment/expansion.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,6 @@ async def expand_validated_matches(
328328
expansion_candidates = []
329329
final_matches = []
330330

331-
print("Identifying expansion candidates...", flush=True)
332-
333331
for match in matches:
334332
source_sents = count_sentences_from_tokens(
335333
match.source.metadata["parsed_filename"],
@@ -354,31 +352,20 @@ async def expand_validated_matches(
354352
total_candidates = len(expansion_candidates)
355353

356354
if expansion_candidates:
357-
print(
358-
f"Processing {total_candidates} expansion candidates in chunks of {chunk_size}...",
359-
flush=True,
360-
)
361-
362-
with tqdm(
363-
total=total_candidates,
364-
desc="Looking for potential passage expansions",
365-
leave=False,
366-
) as pbar:
355+
with tqdm(total=total_candidates, desc="Expanding short passages", unit="passage") as pbar:
367356
for i in range(0, total_candidates, chunk_size):
368357
chunk = expansion_candidates[i : i + chunk_size]
369358
chunk_expansion_count = await _process_expansion_chunk(chunk, evaluator)
370359
expansion_count += chunk_expansion_count
360+
pbar.update(len(chunk))
371361

372-
# Add processed matches to final results
373362
for match, _, _ in chunk:
374363
final_matches.append(match)
375364

376-
pbar.update(len(chunk))
377-
378-
print(
379-
f"Looking for potential passage expansions: expanded {expansion_count} passages.",
380-
flush=True,
381-
)
365+
print(
366+
f"Expansion complete: {expansion_count}/{total_candidates} passages expanded.",
367+
flush=True,
368+
)
382369

383370
return final_matches
384371

@@ -408,8 +395,8 @@ async def _process_expansion_chunk(chunk: list[tuple[MergedGroup, int, int]], ev
408395
prev_pairs = [(exp["source_text"], exp["target_text"]) for exp in step1_prev_expansions]
409396
next_pairs = [(exp["source_text"], exp["target_text"]) for exp in step1_next_expansions]
410397

411-
prev_results = await evaluator.evaluate_batch(prev_pairs, batch_size=8)
412-
next_results = await evaluator.evaluate_batch(next_pairs, batch_size=8)
398+
prev_results = await evaluator.evaluate_batch(prev_pairs, batch_size=8, show_progress=False)
399+
next_results = await evaluator.evaluate_batch(next_pairs, batch_size=8, show_progress=False)
413400

414401
# --- Step 2: Determine winners and prepare next step ---
415402
step2_candidates = []
@@ -459,8 +446,8 @@ async def _process_expansion_chunk(chunk: list[tuple[MergedGroup, int, int]], ev
459446
)
460447
expansion_count += 1
461448

462-
# Prepare step 2 expansion
463-
step2_candidates.append(_prepare_expansion_step(original_match, step=2, direction=step2_direction))
449+
# Prepare step 2 expansion (+1 more sentence beyond what step 1 already added)
450+
step2_candidates.append(_prepare_expansion_step(original_match, step=1, direction=step2_direction))
464451
step2_directions.append(step2_direction)
465452
step2_match_map.append(original_match)
466453

@@ -469,7 +456,7 @@ async def _process_expansion_chunk(chunk: list[tuple[MergedGroup, int, int]], ev
469456

470457
# --- Step 3: Evaluate step 2 expansions ---
471458
step2_pairs = [(exp["source_text"], exp["target_text"]) for exp in step2_candidates]
472-
step2_results = await evaluator.evaluate_batch(step2_pairs, batch_size=8)
459+
step2_results = await evaluator.evaluate_batch(step2_pairs, batch_size=8, show_progress=False)
473460

474461
for i, (step2_score, _, _) in enumerate(step2_results):
475462
original_match = step2_match_map[i]

lib/textpair/vector_space_alignment/structures.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,13 +304,22 @@ def __init__(self, matches: Iterable[MergedGroup]):
304304
def match_generator(self, new_matches):
305305
for match in new_matches:
306306
dump = ENCODER.encode(match)
307-
yield (self.count, dump)
307+
yield (
308+
self.count,
309+
dump,
310+
match.source.filename,
311+
match.target.filename,
312+
match.source.start_byte,
313+
match.source.start_byte - match.source.end_byte,
314+
match.target.start_byte,
315+
match.target.start_byte - match.target.end_byte,
316+
)
308317
self.count += 1
309318

310319
def extend(self, new_matches: Iterable[MergedGroup]):
311320
"""Add new matches to existing matches"""
312321
encoded_matches = self.match_generator(new_matches)
313-
self.cursor.executemany("INSERT INTO matches VALUES (?, ?)", encoded_matches)
322+
self.cursor.executemany("INSERT INTO matches VALUES (?, ?, ?, ?, ?, ?, ?, ?)", encoded_matches)
314323

315324
def __save(self, matches):
316325
count = 0

lib/textpair_llm/textpair_llm/llm_evaluation.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def stop_server(self):
156156
self.server_process = None
157157

158158
async def evaluate_batch(
159-
self, passage_pairs: list[tuple[str, str]], batch_size: int = 8
159+
self, passage_pairs: list[tuple[str, str]], batch_size: int = 8, show_progress: bool = True
160160
) -> list[tuple[float, str, str]]:
161161
"""
162162
Evaluate multiple passage pairs concurrently
@@ -203,7 +203,7 @@ async def evaluate_single(session, source_text, target_text):
203203
results = []
204204
total_pairs = len(passage_pairs)
205205

206-
with tqdm(total=total_pairs, desc="LLM Evaluation", unit="pairs", leave=False) as pbar:
206+
with tqdm(total=total_pairs, desc="LLM Evaluation", unit="pairs", leave=False, disable=not show_progress) as pbar:
207207
for i in range(0, len(passage_pairs), batch_size):
208208
batch = passage_pairs[i : i + batch_size]
209209

@@ -425,21 +425,30 @@ def _parse_llm_response(self, response: str) -> tuple[float, str, str]:
425425
stance = stance_match.group(1).strip().capitalize()
426426
break
427427

428-
# Try multiple score patterns
428+
# Try multiple score patterns — ordered from most to least specific
429429
score_patterns = [
430430
r"Score:\s*([0-9]*\.?[0-9]+)", # "Score: 0.8"
431431
r"score:\s*([0-9]*\.?[0-9]+)", # "score: 0.8" (lowercase)
432-
r"([0-9]*\.?[0-9]+)", # Just a number anywhere
433432
]
434433

435434
score = 0.0
435+
score_found = False
436436
for pattern in score_patterns:
437437
score_match = re.search(pattern, response, re.IGNORECASE)
438438
if score_match:
439439
score = float(score_match.group(1))
440440
score = max(0.0, min(1.0, score)) # Clamp to valid range
441+
score_found = True
441442
break
442443

444+
# Last resort: find a decimal number (0.XX) that looks like a score,
445+
# but only match numbers with a decimal point to avoid grabbing years or counts
446+
if not score_found:
447+
fallback_match = re.search(r"\b(0\.\d+|1\.0+)\b", response)
448+
if fallback_match:
449+
score = float(fallback_match.group(1))
450+
score = max(0.0, min(1.0, score))
451+
443452
return score, reasoning, stance
444453

445454
except Exception as e:
@@ -521,13 +530,14 @@ def create_similarity_evaluation_prompt(source_text: str, target_text: str, cont
521530
First, determine if the passages address the same specific argument. Then, use the score guide below.
522531
523532
IMPORTANT: Direct agreement and direct disagreement on the exact same point are both forms of HIGH similarity.
524-
IMPORTANT: Avoid defaulting to the boundary scores of a category (like 0.40, 0.70, or 0.90). Use the full range to show nuance.
525533
526534
Score Guide:
527-
• 0.0 - 0.4: Different Subjects. The passages are about completely different topics.
528-
• > 0.4 to < 0.7: Shared Subject, Different Focus. The passages are about the same broad subject (e.g., the Roman Empire) but focus on different specific arguments or aspects (e.g., one is about military tactics, the other about trade policy).
529-
• 0.7 - 0.9: Shared Subject, Shared Focus. The passages address the exact same specific argument, question, or thesis. They are in direct conversation, whether they agree, disagree, or analyze it in parallel.
530-
• > 0.9 - 1.0: Paraphrase. The passages make the exact same point and have nearly identical meaning.
535+
• 0.00 - 0.40: Different Subjects. The passages are about completely different topics.
536+
• 0.41 - 0.69: Shared Subject, Different Focus. Same broad subject (e.g., the Roman Empire) but different specific arguments (e.g., military tactics vs. trade policy).
537+
• 0.70 - 0.79: Same Argument, Loose Connection. The passages address the same question or thesis but from meaningfully different angles, evidence bases, or time periods. The intellectual link is real but indirect.
538+
• 0.80 - 0.89: Same Argument, Clear Engagement. The passages directly address the same specific point with overlapping evidence, reasoning, or rhetorical framing. A reader would immediately see they are in conversation.
539+
• 0.90 - 0.95: Same Argument, Near-Paraphrase Framing. Very close in both content and rhetorical approach — similar structure, similar evidence, similar conclusions — but not word-for-word identical.
540+
• 0.96 - 1.00: True Paraphrase. The passages make the exact same point with nearly identical meaning. Only surface wording differs.
531541
532542
Your thought process:
533543
1. What is the broad subject of each passage?
@@ -536,7 +546,9 @@ def create_similarity_evaluation_prompt(source_text: str, target_text: str, cont
536546
- If they share the specific argument, do they Agree or Disagree?
537547
- If they only share the broad subject, mark as Neutral.
538548
- Otherwise, mark as Unrelated.
539-
4. Based on that, which score category do they fall into?
549+
4. Within the matching score category, calibrate precisely:
550+
- Low end of range: weaker fit for that category's description.
551+
- High end of range: strong fit, almost belongs in the next category up.
540552
541553
Provide your answer in this exact format:
542554
Reasoning: [Your step-by-step analysis - keep concise, 2-3 sentences]

0 commit comments

Comments
 (0)