Commit bc30aa4

Vectorize participant info computation (3-15x speedup)
## Summary

- Replaces the O(N×G×C) per-participant Python loop in `_compute_participant_info_optimized` with bulk NumPy operations: matrix-wide vote counting (`np.sum` over axis) and per-group Pearson correlation via `P @ g` matrix multiply
- Adds 31 unit tests covering vote counts, group correlations, edge cases (small groups, zero-std, NaN handling, missing members), and golden snapshot regression
- Correlations now return Python `float` instead of `numpy.float64`
- Includes a benchmark script (`scripts/benchmark_participant_info.py`) that runs old vs new on the same data

## Benchmark results

Measured on real datasets (5 runs, median), old per-participant loop vs new vectorized:

| Dataset | Size | Old | New | Speedup |
|---------|------|-----|-----|---------|
| vw | 69p × 125c × 4g | 0.011s | 0.001s | **14.6x** |
| biodiversity | 536p × 314c × 2g | 0.047s | 0.006s | **8.1x** |
| _(larger private datasets)_ | | | | **3–6x** |

Speedup is higher on smaller datasets (loop overhead dominates) and lower on very large ones (matrix materialization dominates). Overall **3–15x** depending on size.

## Test plan

- [x] 31 unit tests pass (pre-vectorization baseline established first, then re-run post)
- [x] Golden snapshot regression passes for biodiversity + vw
- [x] Full regression test suite passes (40/40)
- [x] Benchmark run on all datasets including private (results above)
- [x] Max correlation diff across all datasets: < 2e-15

🤖 Generated with [Claude Code](https://claude.com/claude-code)

commit-id:ea747196
1 parent 1f86ab3 commit bc30aa4
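The "5 runs, median" measurement described in the commit message could be reproduced along the following lines. This is only a rough sketch, not the contents of `scripts/benchmark_participant_info.py` (which is not shown on this page): the helper `median_runtime`, the two toy vote-counting functions, and the synthetic matrix are all hypothetical stand-ins chosen for illustration.

```python
import statistics
import time

import numpy as np


def median_runtime(fn, *args, runs: int = 5) -> float:
    """Call fn(*args) `runs` times and return the median wall-clock time in seconds."""
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def count_votes_loop(votes: np.ndarray):
    """Hypothetical 'old' style: one pass of NumPy calls per participant row."""
    return [(int(np.sum(row > 0)), int(np.sum(row < 0))) for row in votes]


def count_votes_vectorized(votes: np.ndarray):
    """Hypothetical 'new' style: count agrees/disagrees for all rows at once."""
    agree = np.sum(votes > 0, axis=1)
    disagree = np.sum(votes < 0, axis=1)
    return list(zip(agree.tolist(), disagree.tolist()))


# Synthetic data only (made-up votes; the shape mirrors the biodiversity row above).
rng = np.random.default_rng(42)
votes = rng.choice([-1.0, 0.0, 1.0], size=(536, 314))

# Both variants must agree before their timings are worth comparing.
same_output = count_votes_loop(votes) == count_votes_vectorized(votes)

t_loop = median_runtime(count_votes_loop, votes)
t_vec = median_runtime(count_votes_vectorized, votes)
speedup = t_loop / t_vec if t_vec > 0 else float("inf")
print(f"loop {t_loop:.4f}s, vectorized {t_vec:.4f}s, speedup {speedup:.1f}x")
```

Taking the median rather than the mean keeps a single slow run (for example, a cold cache) from skewing the comparison. The actual speedup printed will depend on the machine and is not expected to match the table above.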

12 files changed

Lines changed: 1022 additions & 277 deletions

delphi/docs/PLAN_DISCREPANCY_FIXES.md

Lines changed: 43 additions & 25 deletions
@@ -4,13 +4,24 @@

 The Delphi Python math pipeline has 15 documented discrepancies with the Clojure reference implementation (see `deep-analysis-for-julien/07-discrepancies.md` and `deep-analysis-for-julien/09-fix-plan.md`). We need to fix them one-by-one with a TDD approach: **first extend the regression test to verify the discrepancy exists, then fix it, then verify the test passes**.

-Each fix will be a separate PR to keep reviews manageable. PRs will be **stacked** (each builds on the previous), since fixes are ordered by pipeline execution order — fixing upstream affects downstream. PRs should be clearly labeled as stacked with their dependency chain, so reviewers know the order. We can use `git rebase -i` to clean up commit history before merging to main.
+Each fix will be a separate PR to keep reviews manageable. PRs are **stacked** (each builds on the previous), since fixes are ordered by pipeline execution order — fixing upstream affects downstream. The stack is managed via `.claude/STACK` and the `/pr-stack` skill.

-**PR naming convention**: Clojure parity fix PRs use the title prefix `[Clj parity PR N]` (e.g., `[Clj parity PR 0] Per-discrepancy test infrastructure`, `[Clj parity PR 1] Fix D2: in-conv participant threshold`).
+**PR naming**: Titles use `[Stack N/M]` prefix (auto-managed by `.claude/skills/pr-stack/update-stack-titles.sh`). The descriptive part of the title should be self-explanatory.

-**Prerequisite**: The current `kmeans_work` branch has changes from `edge`. These should be separated into their own PR(s) first, before we start the discrepancy fix PRs. The discrepancy fix PRs should be based on the cleaned-up branch.
+### Stack ↔ Plan Cross-Reference

-**Action**: Before starting any fix, analyze the diff of `kmeans_work` vs `upstream/edge` to understand what's changed. Group those changes into logical PR(s) — e.g., test infrastructure improvements, doc updates, minor bug fixes. Each should be reviewable independently. The discrepancy fix PRs (PR 1+) then stack on top of this clean base.
+The full PR stack includes infrastructure PRs (Stack 1-7) followed by discrepancy fixes.
+This plan's "PR N" labels map to actual GitHub PRs as follows:
+
+| Plan label | GitHub PR | Stack | Title |
+|-----------|-----------|-------|-------|
+| PR 0 (infra) | #2417#2420 | Stack 1-7 | Test cleanup, clustering, cold-start tooling, analysis docs |
+| PR 1 (D2) | #2421 | Stack 8/10 | Fix D2: in-conv participant threshold + D2c vote count source |
+| PR 2 (D4) | #2435 | Stack 9/10 | Fix D4: pseudocount formula |
+| (perf) | #2436 | Stack 10/10 | Speed up regression tests |
+| PR 3 (D9) ||| *Next: Fix D9 z-score thresholds* |
+
+Future fix PRs will be appended to the stack as they're created.

 ### Session Continuity

@@ -38,6 +49,7 @@ Because this work will span multiple Claude Code sessions, we maintain:
 - **All datasets, not just biodiversity**: Every fix must pass on ALL datasets. biodiversity is just one reference among many.
 - **Synthetic edge-case tests**: Every time we discover an edge case specific to one conversation, extract it into a synthetic unit test with made-up data (never real data from private datasets). These run fast and document the intent clearly.
 - **E2E awareness**: GitHub Actions has Cypress E2E tests (`cypress-tests.yml`) testing UI workflows, and `python-ci.yml` running pytest regression. The Cypress tests don't test math output values directly, but `python-ci.yml` will break if clustering/repness changes. Formula-level fixes (D4, D5, D6, D7, D8, D9) are pure computation — no E2E risk. Selection logic changes (D10, D11) and priority computation (D12) could affect what the TypeScript server returns. We decide case-by-case which PRs need E2E verification.
+- **Remove dead code after replacement**: When a function is replaced by a new implementation (e.g. vectorized version), the old function must be deleted and all callers updated — not left as dead code. Do this in the same PR or a follow-up, after benchmarks and tests confirm the replacement works.

 ### Datasets Available (sorted by size, smallest first)

@@ -399,27 +411,33 @@ By this point, we should have good test coverage from all the per-discrepancy te

 ## Discrepancy Coverage Checklist

-| ID | Discrepancy | PR | Status |
-|----|-------------|-----|--------|
-| D1 | PCA sign flips | PR 13 | Fix (sign consistency) |
-| D1b | Projection input | PR 13 | Fix with D1 |
-| D2 | In-conv threshold | **PR 1** | **DONE**|
-| D2b | Base-cluster sort order | **PR 1** | **DONE**|
-| D2c | Vote count source (raw vs filtered matrix) | **PR 1** | **DONE**|
-| D2d | In-conv monotonicity (once in, always in) | **PR 1** | **DONE** ✓ (5 guard tests, T1-T5) |
-| D3 | K-smoother buffer | PR 10 | Fix |
-| D4 | Pseudocount formula | **PR 2** | **DONE**|
-| D5 | Proportion test | PR 4 | Fix |
-| D6 | Two-proportion test | PR 5 | Fix |
-| D7 | Repness metric | PR 6 | Fix (with flag for old formula) |
-| D8 | Finalize cmt stats | PR 7 | Fix |
-| D9 | Z-score thresholds | **PR 3** | Fix |
-| D10 | Rep comment selection | PR 8 | Fix (with legacy env var) |
-| D11 | Consensus selection | PR 9 | Fix (with legacy env var) |
-| D12 | Comment priorities | PR 11 | Fix (implement from scratch) |
-| D13 | Subgroup clustering || **Deferred** (unused) |
-| D14 | Large conv optimization || **Deferred** (Python fast enough) |
-| D15 | Moderation handling | PR 12 | Fix |
+| ID | Discrepancy | Plan PR | GitHub PR | Status |
+|----|-------------|---------|-----------|--------|
+| D1 | PCA sign flips | PR 13 || Fix (sign consistency) |
+| D1b | Projection input | PR 13 || Fix with D1 |
+| D2 | In-conv threshold | **PR 1** | **#2421** | **DONE**|
+| D2b | Base-cluster sort order | **PR 1** | **#2421** | **DONE**|
+| D2c | Vote count source (raw vs filtered matrix) | **PR 1** | **#2421** | **DONE**|
+| D2d | In-conv monotonicity (once in, always in) | **PR 1** | **#2421** | **DONE** ✓ (5 guard tests, T1-T5) |
+| D3 | K-smoother buffer | PR 10 || Fix |
+| D4 | Pseudocount formula | **PR 2** | **#2435** | **DONE**|
+| D5 | Proportion test | PR 4 || Fix |
+| D6 | Two-proportion test | PR 5 || Fix |
+| D7 | Repness metric | PR 6 || Fix (with flag for old formula) |
+| D8 | Finalize cmt stats | PR 7 || Fix |
+| D9 | Z-score thresholds | PR 3 || Fix (next) |
+| D10 | Rep comment selection | PR 8 || Fix (with legacy env var) |
+| D11 | Consensus selection | PR 9 || Fix (with legacy env var) |
+| D12 | Comment priorities | PR 11 || Fix (implement from scratch) |
+| D13 | Subgroup clustering ||| **Deferred** (unused) |
+| D14 | Large conv optimization ||| **Deferred** (Python fast enough) |
+| D15 | Moderation handling | PR 12 || Fix |
+
+### Non-discrepancy PRs in the stack
+
+| GitHub PR | Stack | Description |
+|-----------|-------|-------------|
+| #2436 | 10/10 | Speed up regression tests (benchmark off, skip intermediate stages) |

 ---

delphi/notebooks/run_analysis.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def check_environment():

 # Import polismath modules
 from polismath.conversation.conversation import Conversation
-from polismath.pca_kmeans_rep.repness import conv_repness, participant_stats
+from polismath.pca_kmeans_rep.repness import conv_repness
 from polismath.pca_kmeans_rep.corr import compute_correlation

 def load_votes(votes_path):

delphi/polismath/conversation/conversation.py

Lines changed: 75 additions & 88 deletions
@@ -20,7 +20,7 @@
     kmeans_sklearn,
     calculate_silhouette_sklearn
 )
-from polismath.pca_kmeans_rep.repness import conv_repness, participant_stats
+from polismath.pca_kmeans_rep.repness import conv_repness
 from polismath.pca_kmeans_rep.corr import compute_correlation


@@ -792,114 +792,101 @@ def _compute_participant_info_optimized(self, vote_matrix: pd.DataFrame, group_c

         # OPTIMIZATION 3: Precompute group vote matrices and average votes

-        # Precompute group vote matrices and their valid comment masks
-        group_vote_matrices = {}
+        # Precompute group average votes and valid comment masks
         group_avg_votes = {}
         group_valid_masks = {}
-
+
         for group_id, member_indices in group_member_indices.items():
             if len(member_indices) >= 3:  # Only calculate for groups with enough members
                 # Extract the group vote matrix
                 group_vote_matrix = matrix_values[member_indices, :]
-                group_vote_matrices[group_id] = group_vote_matrix
-
+
                 # Calculate average votes per comment for this group
                 group_avg_votes[group_id] = np.mean(group_vote_matrix, axis=0)

                 # Precompute which comments have at least 3 votes from this group
                 group_valid_masks[group_id] = np.sum(group_vote_matrix != 0, axis=0) >= 3

-        # OPTIMIZATION 4: Use vectorized operations for participant stats
-
+        # VECTORIZED: Compute vote counts for ALL participants at once
+
         process_start = time.time()
-        batch_start = time.time()
-
-        for p_idx, participant_id in enumerate(vote_matrix.index):
-            if p_idx >= matrix_values.shape[0]:
+
+        n_agree_all = np.sum(matrix_values > 0, axis=1)  # (N,)
+        n_disagree_all = np.sum(matrix_values < 0, axis=1)  # (N,)
+        n_pass_all = np.sum(matrix_values == 0, axis=1)  # (N,)
+        n_votes_all = n_agree_all + n_disagree_all  # (N,)
+
+        # Mask: participants with at least one real vote
+        has_votes = n_votes_all > 0  # (N,) bool
+
+        # VECTORIZED: Compute per-group correlations for ALL participants at once
+        # Store as {group_id: corr_array} where corr_array is (N,)
+        group_corr_arrays = {}
+
+        for group_id, member_indices in group_member_indices.items():
+            if len(member_indices) < 3 or group_id not in group_avg_votes:
+                # All correlations default to 0.0
+                group_corr_arrays[group_id] = np.zeros(participant_count)
                 continue
-
-            # Print progress for large participant sets
-            if participant_count > 100 and p_idx % 100 == 0:
-                now = time.time()
-                elapsed = now - process_start
-                batch_time = now - batch_start
-                batch_start = now
-                percent = (p_idx / participant_count) * 100
-                logger.info(f"Processed {p_idx}/{participant_count} participants ({percent:.1f}%) - " +
-                            f"Elapsed: {elapsed:.2f}s, Batch: {batch_time:.4f}s")
-
-            # Get participant votes
-            participant_votes = matrix_values[p_idx, :]
-
-            # Count votes using vectorized operations
-            n_agree = np.sum(participant_votes > 0)
-            n_disagree = np.sum(participant_votes < 0)
-            n_pass = np.sum(participant_votes == 0)
-            n_votes = n_agree + n_disagree
-
-            # Skip participants with no votes
-            if n_votes == 0:
+
+            valid_mask = group_valid_masks[group_id]
+            n_valid = int(np.sum(valid_mask))
+
+            if n_valid < 3:
+                group_corr_arrays[group_id] = np.zeros(participant_count)
                 continue
-
-            # Find participant's group using precomputed mapping
-            participant_group = ptpt_group_map.get(participant_id)
-
-            # OPTIMIZATION 5: Efficient group correlation calculation
-
-            # Calculate agreement with each group - optimized version
-            group_agreements = {}
-
-            for group_id, member_indices in group_member_indices.items():
-                if len(member_indices) < 3:
-                    # Skip groups with too few members
-                    group_agreements[group_id] = 0.0
-                    continue
-
-                if group_id not in group_avg_votes or group_id not in group_valid_masks:
-                    group_agreements[group_id] = 0.0
-                    continue
-
-                # Use precomputed data
-                g_votes = group_avg_votes[group_id]
-                valid_mask = group_valid_masks[group_id]
-
-                if np.sum(valid_mask) >= 3:  # At least 3 valid comments
-                    # Extract only valid comment votes
-                    p_votes = participant_votes[valid_mask]
-                    g_votes_valid = g_votes[valid_mask]
-
-                    # Fast correlation calculation
-                    p_std = np.std(p_votes)
-                    g_std = np.std(g_votes_valid)
-
-                    if p_std > 0 and g_std > 0:
-                        # Use numpy's built-in correlation (faster and more numerically stable)
-                        correlation = np.corrcoef(p_votes, g_votes_valid)[0, 1]
-
-                        if not np.isnan(correlation):
-                            group_agreements[group_id] = correlation
-                        else:
-                            group_agreements[group_id] = 0.0
-                    else:
-                        group_agreements[group_id] = 0.0
-                else:
-                    group_agreements[group_id] = 0.0
-
-            # Store participant stats
+
+            # P: all participants' votes on valid comments — (N, n_valid)
+            P = matrix_values[:, valid_mask]
+            # g: group average on valid comments — (n_valid,)
+            g = group_avg_votes[group_id][valid_mask]
+
+            p_mean = P.mean(axis=1)  # (N,)
+            g_mean = g.mean()  # scalar
+            p_std = P.std(axis=1)  # (N,)
+            g_std = g.std()  # scalar
+
+            if g_std == 0:
+                group_corr_arrays[group_id] = np.zeros(participant_count)
+                continue
+
+            # Pearson correlation: (mean(P*g) - mean(P)*mean(g)) / (std(P)*std(g))
+            cross_mean = (P @ g) / n_valid  # (N,)
+
+            # np.where evaluates both branches; suppress divide-by-zero for p_std==0
+            with np.errstate(invalid='ignore', divide='ignore'):
+                corr = np.where(
+                    p_std > 0,
+                    (cross_mean - p_mean * g_mean) / (p_std * g_std),
+                    0.0,
+                )
+            corr = np.nan_to_num(corr, nan=0.0)
+            group_corr_arrays[group_id] = corr
+
+        # Assemble result dicts (zero computation — just indexing)
+        group_ids = list(group_member_indices.keys())
+
+        for p_idx, participant_id in enumerate(vote_matrix.index):
+            if not has_votes[p_idx]:
+                continue
+
             result['stats'][participant_id] = {
-                'n_agree': int(n_agree),
-                'n_disagree': int(n_disagree),
-                'n_pass': int(n_pass),
-                'n_votes': int(n_votes),
-                'group': participant_group,
-                'group_correlations': group_agreements
+                'n_agree': int(n_agree_all[p_idx]),
+                'n_disagree': int(n_disagree_all[p_idx]),
+                'n_pass': int(n_pass_all[p_idx]),
+                'n_votes': int(n_votes_all[p_idx]),
+                'group': ptpt_group_map.get(participant_id),
+                'group_correlations': {
+                    gid: float(group_corr_arrays[gid][p_idx])
+                    for gid in group_ids
+                }
             }
-
+
         total_time = time.time() - start_time
         process_time = time.time() - process_start
         logger.info(f"Participant stats completed in {total_time:.2f}s (preparation: {prep_time:.2f}s, processing: {process_time:.2f}s)")
         logger.info(f"Processed {len(result['stats'])} participants with {len(group_clusters)} groups")
-
+
         return result

     def _compute_participant_info(self) -> None:
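
The row-wise Pearson identity used in the hunk above, corr = (mean(P·g) − mean(P)·mean(g)) / (std(P)·std(g)), can be checked in isolation against `np.corrcoef`. The following is a minimal sketch on made-up data, not the project's own test code: the function name `rowwise_pearson` and the synthetic matrix are assumptions introduced only for this illustration.

```python
import numpy as np


def rowwise_pearson(P: np.ndarray, g: np.ndarray) -> np.ndarray:
    """Correlate every row of P (N, C) with the vector g (C,) in one pass.

    Rows with zero variance, or a zero-variance g, yield 0.0 instead of NaN.
    """
    n = g.shape[0]
    p_mean = P.mean(axis=1)   # (N,)
    p_std = P.std(axis=1)     # (N,), population std (ddof=0)
    g_mean = g.mean()         # scalar
    g_std = g.std()           # scalar

    if g_std == 0:
        return np.zeros(P.shape[0])

    # mean(P * g) for every row, computed as one matrix-vector product
    cross_mean = (P @ g) / n  # (N,)

    # np.where evaluates both branches, so silence the 0/0 warnings
    with np.errstate(invalid="ignore", divide="ignore"):
        corr = np.where(
            p_std > 0,
            (cross_mean - p_mean * g_mean) / (p_std * g_std),
            0.0,
        )
    return np.nan_to_num(corr, nan=0.0)


# Made-up votes: 6 participants x 8 comments, values in {-1, 0, +1}
rng = np.random.default_rng(0)
P = rng.choice([-1.0, 0.0, 1.0], size=(6, 8))
P[2, :] = 1.0                      # constant row: correlation is undefined
g = rng.uniform(-1.0, 1.0, size=8)

fast = rowwise_pearson(P, g)

# Reference: one np.corrcoef call per row, with the same zero-variance guard
slow = np.array([
    np.corrcoef(row, g)[0, 1] if row.std() > 0 else 0.0
    for row in P
])

max_diff = float(np.max(np.abs(fast - slow)))
```

Because the same `ddof` appears in both the numerator and the denominator, the population-std form used here gives the same coefficient as `np.corrcoef`, up to floating-point rounding. The constant row shows why the `p_std > 0` guard is needed: without it the division would produce NaN.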

delphi/polismath/pca_kmeans_rep/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -10,14 +10,13 @@

 from polismath.pca_kmeans_rep.pca import pca_project_dataframe
 from polismath.pca_kmeans_rep.clusters import cluster_dataframe, Cluster
-from polismath.pca_kmeans_rep.repness import conv_repness, participant_stats
+from polismath.pca_kmeans_rep.repness import conv_repness
 from polismath.pca_kmeans_rep.corr import compute_correlation

 __all__ = [
     'pca_project_dataframe',
     'cluster_dataframe',
     'Cluster',
     'conv_repness',
-    'participant_stats',
     'compute_correlation',
 ]
