Skip to content

Commit 03e83b0

Browse files
committed
various fixes and documentation
1 parent 20be927 commit 03e83b0

2 files changed

Lines changed: 187 additions & 69 deletions

File tree

lib/textpair_graph/textpair_graph/build_graph_model.py

Lines changed: 77 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import gc
44
import os
55
import sys
6+
from collections import defaultdict
67

78
import faiss
89
import lz4.frame
9-
import networkx as nx
1010
import numpy as np
1111
import orjson
1212
import torch
@@ -40,7 +40,7 @@
4040
BATCH_SIZE = 4096 # For SBERT encoding
4141

4242

43-
def build_alignment_data(alignments_file: str, alignment_counts: int, sbert_model_name: str):
43+
def build_alignment_data(alignments_file: str, alignment_counts: int, sbert_model_name: str) -> dict:
4444
"""Preprocess alignments and encode passages."""
4545

4646
print("Building author mapping...", end=' ')
@@ -147,7 +147,7 @@ def build_alignment_data(alignments_file: str, alignment_counts: int, sbert_mode
147147
'num_authors': len(author_to_id)
148148
}
149149

150-
def cluster_alignments(data, output_path, alignments_file, alignment_counts):
150+
def cluster_alignments(data: dict, output_path: str, alignment_counts: int) -> tuple[np.ndarray, np.ndarray]:
151151
"""
152152
Cluster all alignments using HDBSCAN on SBERT embeddings.
153153
Returns cluster labels for direct lookup at runtime.
@@ -156,16 +156,57 @@ def cluster_alignments(data, output_path, alignments_file, alignment_counts):
156156
sbert_dim = all_embeddings.shape[1]
157157

158158
low_dim = 32
159-
n_neighbors = min(100, max(15, alignment_counts // 10000))
159+
n_neighbors = min(30, max(15, alignment_counts // 10000))
160160
print(f"Reducing embeddings dimensionality ({sbert_dim} to {low_dim}D)...", end=' ')
161-
reducer_low_dim = UMAP(
162-
n_components=low_dim,
163-
n_neighbors=n_neighbors,
164-
min_dist=0.0,
165-
metric='cosine',
166-
random_state=42
167-
)
168-
low_dim_embeddings = reducer_low_dim.fit_transform(all_embeddings)
161+
162+
# Optimization for large datasets: Train UMAP on a subset, transform the rest
163+
umap_train_limit = 1_000_000
164+
n_total_embeddings = all_embeddings.shape[0]
165+
166+
if n_total_embeddings > umap_train_limit:
167+
print(f"\nLarge dataset ({n_total_embeddings:,} items). Training UMAP on random {umap_train_limit:,} subset...")
168+
169+
# Random sample for training
170+
rng = np.random.default_rng(42)
171+
train_indices = rng.choice(n_total_embeddings, size=umap_train_limit, replace=False)
172+
train_indices.sort()
173+
174+
# Load training data into memory
175+
X_train = all_embeddings[train_indices]
176+
177+
reducer_low_dim = UMAP(
178+
n_components=low_dim,
179+
n_neighbors=n_neighbors,
180+
min_dist=0.0,
181+
metric='cosine',
182+
random_state=42
183+
)
184+
reducer_low_dim.fit(X_train)
185+
186+
del X_train
187+
gc.collect()
188+
189+
print(f"Transforming all {n_total_embeddings:,} embeddings in batches...")
190+
low_dim_embeddings = np.zeros((n_total_embeddings, low_dim), dtype=np.float32)
191+
192+
batch_size = 1000000
193+
n_batches = (n_total_embeddings + batch_size - 1) // batch_size
194+
195+
for i in tqdm(range(0, n_total_embeddings, batch_size), desc="UMAP Transform", total=n_batches):
196+
end_idx = min(i + batch_size, n_total_embeddings)
197+
batch = all_embeddings[i:end_idx]
198+
low_dim_embeddings[i:end_idx] = reducer_low_dim.transform(batch)
199+
200+
else:
201+
reducer_low_dim = UMAP(
202+
n_components=low_dim,
203+
n_neighbors=n_neighbors,
204+
min_dist=0.0,
205+
metric='cosine',
206+
random_state=42
207+
)
208+
low_dim_embeddings = reducer_low_dim.fit_transform(all_embeddings)
209+
169210
print("done.")
170211
gc.collect()
171212

@@ -187,10 +228,10 @@ def cluster_alignments(data, output_path, alignments_file, alignment_counts):
187228
test_embeddings = low_dim_embeddings[test_indices]
188229

189230
print(f"Clustering with HDBSCAN on {len(train_embeddings):,} training passages...", end=' ')
190-
min_cluster_size = max(15, int(0.005 * len(train_embeddings)))
231+
min_cluster_size = max(15, int(0.001 * len(train_embeddings)))
191232
clusterer = HDBSCAN(
192233
min_cluster_size=min_cluster_size,
193-
min_samples=2,
234+
min_samples=15,
194235
metric='euclidean',
195236
cluster_selection_method='eom',
196237
prediction_data=True
@@ -232,13 +273,13 @@ def cluster_alignments(data, output_path, alignments_file, alignment_counts):
232273
gc.collect()
233274
else:
234275
print(f"Clustering all {n_total:,} passages with HDBSCAN...", end=' ')
235-
min_cluster_size = max(15, int(0.005 * n_total))
276+
min_cluster_size = max(15, int(0.0001 * n_total))
236277
clusterer = HDBSCAN(
237278
min_cluster_size=min_cluster_size,
238-
min_samples=2,
279+
min_samples=15,
239280
metric='euclidean',
240281
cluster_selection_method='eom',
241-
cluster_selection_epsilon=0.3,
282+
cluster_selection_epsilon=0.0,
242283
prediction_data=True
243284
)
244285
cluster_labels = clusterer.fit_predict(low_dim_embeddings)
@@ -254,9 +295,6 @@ def cluster_alignments(data, output_path, alignments_file, alignment_counts):
254295
n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
255296
print(f"Final: {n_clusters} clusters with {n_noise} total outliers ({100*n_noise/len(cluster_labels):.1f}%)")
256297

257-
258-
259-
260298
# Handle noise cluster: treat each noise point as its own singleton cluster
261299
noise_indices = np.where(cluster_labels == -1)[0]
262300

@@ -305,7 +343,7 @@ def cluster_alignments(data, output_path, alignments_file, alignment_counts):
305343
metric='euclidean',
306344
random_state=42
307345
)
308-
embeddings_2d = reducer_2d.fit_transform(low_dim_embeddings)
346+
embeddings_2d: np.ndarray = reducer_2d.fit_transform(low_dim_embeddings) # type: ignore
309347

310348
del reducer_2d
311349
gc.collect()
@@ -333,22 +371,21 @@ def cluster_alignments(data, output_path, alignments_file, alignment_counts):
333371
with open(os.path.join(output_path, 'cluster_metadata.json'), 'wb') as f:
334372
f.write(orjson.dumps(metadata))
335373

336-
return cluster_labels
374+
return modified_cluster_labels, embeddings_2d
337375

338376

339-
def build_precomputed_api_graph(alignments_file: str, output_path: str):
377+
def build_precomputed_api_graph(alignments_file: str, output_path: str, author_to_id: dict,
378+
cluster_labels_modified: np.ndarray, embeddings_2d: np.ndarray,
379+
alignment_counts: int) -> None:
340380
"""
341381
Build precomputed graph for API in the same format as get_semantic_graph_data.
342382
343383
Creates precomputed_graph_api.json with (author, cluster) pair nodes and edges.
344384
"""
345-
print("Loading data...")
346-
cluster_labels_modified = np.load(os.path.join(output_path, 'cluster_labels_modified.npy'))
347-
embeddings_umap_2d = np.load(os.path.join(output_path, 'embeddings_umap_2d.npy'))
348-
cluster_similarity = np.load(os.path.join(output_path, 'cluster_similarity_matrix.npy'))
385+
print("Building precomputed graph...")
349386

350-
with open(os.path.join(output_path, 'author_to_id.json'), 'rb') as f:
351-
author_to_id = orjson.loads(f.read())
387+
# Load only cluster similarity matrix and metadata
388+
cluster_similarity = np.load(os.path.join(output_path, 'cluster_similarity_matrix.npy'))
352389

353390
with open(os.path.join(output_path, 'cluster_metadata.json'), 'rb') as f:
354391
metadata = orjson.loads(f.read())
@@ -363,10 +400,9 @@ def build_precomputed_api_graph(alignments_file: str, output_path: str):
363400
print(f" {alignment_counts} alignments")
364401

365402
print("\nEnumerating (author, cluster) pairs from alignments...")
366-
from collections import defaultdict
367403

368404
pair_passage_counts = defaultdict(int)
369-
pair_embeddings_2d = defaultdict(list)
405+
pair_position_sums = defaultdict(lambda: np.zeros(2, dtype=np.float64))
370406

371407
alignment_idx = 0
372408
with lz4.frame.open(alignments_file, "rb") as f:
@@ -376,12 +412,12 @@ def build_precomputed_api_graph(alignments_file: str, output_path: str):
376412
target_author_id = author_to_id[alignment["target_author"]]
377413

378414
cluster_id = int(cluster_labels_modified[alignment_idx])
379-
embedding_2d = embeddings_umap_2d[alignment_idx]
415+
embedding_2d = embeddings_2d[alignment_idx]
380416

381417
for author_id in [source_author_id, target_author_id]:
382418
pair_key = (author_id, cluster_id)
383419
pair_passage_counts[pair_key] += 1
384-
pair_embeddings_2d[pair_key].append(embedding_2d)
420+
pair_position_sums[pair_key] += embedding_2d
385421

386422
alignment_idx += 1
387423

@@ -390,9 +426,11 @@ def build_precomputed_api_graph(alignments_file: str, output_path: str):
390426
# Calculate mean 2D positions for each pair
391427
print("\nComputing mean 2D positions for each pair...")
392428
pair_positions = {}
393-
for pair_key, embeddings_list in pair_embeddings_2d.items():
394-
mean_position = np.mean(embeddings_list, axis=0)
395-
pair_positions[pair_key] = mean_position
429+
for pair_key, count in pair_passage_counts.items():
430+
pair_positions[pair_key] = pair_position_sums[pair_key] / count
431+
432+
del pair_position_sums
433+
gc.collect()
396434

397435
# Build precomputed graph in API format (matches get_semantic_graph_data output)
398436
print("\nCreating precomputed graph for API...")
@@ -526,10 +564,11 @@ def main():
526564
f.write(orjson.dumps(data['author_to_id']))
527565

528566
# Cluster alignments by content similarity
529-
cluster_labels = cluster_alignments(data, output_path, alignments_file, alignment_counts)
567+
modified_cluster_labels, embeddings_2d = cluster_alignments(data, output_path, alignment_counts)
530568

531569
# Build precomputed graph for API
532-
build_precomputed_api_graph(alignments_file, output_path)
570+
build_precomputed_api_graph(alignments_file, output_path, data['author_to_id'],
571+
modified_cluster_labels, embeddings_2d, alignment_counts)
533572

534573
print("\nThematic Identify Graph done.")
535574

0 commit comments

Comments
 (0)