Skip to content

Commit 591e196

Browse files
committed
Speed up regression tests
## Summary - Default `benchmark=False` in `compare_with_golden()` — benchmark mode ran the pipeline 3x for timing statistics, unnecessary for correctness checks. The `regression_comparer.py` script already had `--benchmark` as opt-in, so this aligns the default. - Add `skip_intermediate_stages` parameter to `compute_all_stages()` — `test_conversation_regression` now skips stages 1-4 (empty, load-only, PCA-only, PCA+clustering) since it only checks `overall_match`. `test_conversation_stages_individually` still runs all stages for granular failure detection. ### Measured speedup on one of the large private test conversations | Test | Before | After | Speedup | |------|--------|-------|---------| | `test_conversation_regression` | 317s | 23s | **13.9x** | | `test_conversation_stages_individually` | 60s | 32s | **1.9x** | The regression test's ~14x speedup comes from two combined effects: no longer running the pipeline 3x (benchmark), and skipping 4 redundant intermediate stages. ## Test plan - [x] All 9 public regression tests pass (vw + biodiversity) - [x] Private dataset tests pass (`--include-local`) - [x] Timing verified on large private dataset 🤖 Generated with [Claude Code](https://claude.com/claude-code) commit-id:f39f3218
1 parent b437500 commit 591e196

3 files changed

Lines changed: 102 additions & 74 deletions

File tree

delphi/polismath/regression/comparer.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,20 @@ def __init__(
5959
# E.g., {0: 1, 1: -1} means PC1 unchanged, PC2 flipped
6060
self._pca_sign_flips: Dict[str, Dict[int, int]] = {}
6161

62-
def compare_with_golden(self, dataset_name: str, benchmark: bool = True) -> Dict:
62+
def compare_with_golden(
63+
self,
64+
dataset_name: str,
65+
benchmark: bool = False,
66+
skip_intermediate_stages: bool = False,
67+
) -> Dict:
6368
"""
6469
Compare current implementation with golden snapshot.
6570
6671
Args:
6772
dataset_name: Name of the dataset ('biodiversity' or 'vw')
68-
benchmark: If True, compare timing information (default: True)
73+
benchmark: If True, compare timing information (default: False)
74+
skip_intermediate_stages: If True, skip stages 1-4 and only compute
75+
full recompute + data export. Saves time for large datasets.
6976
7077
Returns:
7178
Dictionary containing comparison results
@@ -139,13 +146,17 @@ def compare_with_golden(self, dataset_name: str, benchmark: bool = True) -> Dict
139146
if benchmark:
140147
logger.info("Computing all stages with benchmarking...")
141148
current_results = compute_all_stages_with_benchmark(
142-
dataset_name, votes_dict, metadata["fixed_timestamp"]
149+
dataset_name, votes_dict, metadata["fixed_timestamp"],
150+
skip_intermediate_stages=skip_intermediate_stages,
143151
)
144152
current_stages = current_results["stages"]
145153
current_timing_stats = current_results["timing_stats"]
146154
else:
147155
logger.info("Computing all stages...")
148-
current_results = compute_all_stages(dataset_name, votes_dict, metadata["fixed_timestamp"])
156+
current_results = compute_all_stages(
157+
dataset_name, votes_dict, metadata["fixed_timestamp"],
158+
skip_intermediate_stages=skip_intermediate_stages,
159+
)
149160
current_stages = current_results["stages"]
150161
current_timing_stats = {}
151162

delphi/polismath/regression/utils.py

Lines changed: 83 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ def compute_file_md5(filepath: str) -> str:
3737
return "error_computing_md5"
3838

3939

40-
def compute_all_stages(dataset_name: str, votes_dict: Dict, fixed_timestamp: int) -> Dict[str, Dict[str, Any]]:
40+
def compute_all_stages(
41+
dataset_name: str,
42+
votes_dict: Dict,
43+
fixed_timestamp: int,
44+
skip_intermediate_stages: bool = False,
45+
) -> Dict[str, Dict[str, Any]]:
4146
"""
4247
Compute all conversation stages with timing information.
4348
@@ -50,6 +55,9 @@ def compute_all_stages(dataset_name: str, votes_dict: Dict, fixed_timestamp: int
5055
votes_dict: Dictionary containing votes data with format:
5156
{'votes': [...], 'lastVoteTimestamp': timestamp}
5257
fixed_timestamp: Fixed timestamp for reproducibility
58+
skip_intermediate_stages: If True, skip stages 1-4 (empty, load-only,
59+
PCA-only, PCA+clustering) and only compute the full recompute and
60+
data export. Saves significant time for large datasets.
5361
5462
Returns:
5563
Dictionary with two keys:
@@ -59,73 +67,74 @@ def compute_all_stages(dataset_name: str, votes_dict: Dict, fixed_timestamp: int
5967
stages = {}
6068
timings = {}
6169

62-
# Stage 1: Empty conversation (with fixed timestamp)
63-
start_time = time.perf_counter()
64-
conv_empty = Conversation(dataset_name, last_updated=fixed_timestamp)
65-
timings["empty"] = time.perf_counter() - start_time
66-
stages["empty"] = conv_empty.to_dict()
67-
68-
# Stage 2: After loading votes (no recompute)
69-
conv = Conversation(dataset_name, last_updated=fixed_timestamp)
70-
start_time = time.perf_counter()
71-
conv = conv.update_votes(votes_dict, recompute=False)
72-
timings["after_load_no_compute"] = time.perf_counter() - start_time
73-
74-
# Validation: Ensure votes were actually loaded
75-
if conv.participant_count == 0 or conv.comment_count == 0:
76-
raise ValueError(
77-
f"Failed to load votes! participant_count={conv.participant_count}, "
78-
f"comment_count={conv.comment_count}"
79-
)
70+
if not skip_intermediate_stages:
71+
# Stage 1: Empty conversation (with fixed timestamp)
72+
start_time = time.perf_counter()
73+
conv_empty = Conversation(dataset_name, last_updated=fixed_timestamp)
74+
timings["empty"] = time.perf_counter() - start_time
75+
stages["empty"] = conv_empty.to_dict()
8076

81-
stages["after_load_no_compute"] = conv.to_dict()
82-
83-
# DEBUG: Capture the matrix that goes into PCA (only when DEBUG logging is enabled)
84-
if logger.isEnabledFor(logging.DEBUG):
85-
debug_info = {}
86-
try:
87-
# Get the clean matrix that PCA will use
88-
if hasattr(conv, '_get_clean_matrix'):
89-
clean_matrix = conv._get_clean_matrix()
90-
# Save first 5x5 section of the matrix for debugging
91-
if not clean_matrix.empty:
92-
debug_info["pca_input_matrix_sample"] = {
93-
"shape": list(clean_matrix.shape),
94-
"rows_first_10": list(clean_matrix.index[:10]),
95-
"cols_first_10": list(clean_matrix.columns[:10]),
96-
"sample_5x5": clean_matrix.iloc[:5, :5].to_dict(),
97-
"dtype": str(clean_matrix.dtypes.iloc[0] if len(clean_matrix.dtypes) > 0 else "unknown")
98-
}
99-
# Check for NaN values
100-
nan_info = {
101-
"total_cells": clean_matrix.size,
102-
"nan_count": clean_matrix.isna().sum().sum(),
103-
"nan_percentage": (clean_matrix.isna().sum().sum() / clean_matrix.size * 100) if clean_matrix.size > 0 else 0
104-
}
105-
debug_info["nan_info"] = nan_info
106-
107-
# Save debug info to .test_outputs/debug directory
108-
debug_dir = Path(__file__).parent.parent / ".test_outputs" / "debug"
109-
debug_dir.mkdir(parents=True, exist_ok=True)
110-
debug_path = debug_dir / f"pca_debug_{dataset_name}.json"
111-
with open(debug_path, "w") as f:
112-
json.dump(debug_info, f, indent=2, default=str)
113-
logger.debug(f"Saved PCA debug info to {debug_path}")
114-
except Exception as e:
115-
logger.error(f"Debug capture failed: {e}")
116-
117-
# Stage 3: After PCA computation only
118-
start_time = time.perf_counter()
119-
conv._compute_pca()
120-
timings["after_pca"] = time.perf_counter() - start_time
121-
stages["after_pca"] = conv.to_dict()
77+
# Stage 2: After loading votes (no recompute)
78+
conv = Conversation(dataset_name, last_updated=fixed_timestamp)
79+
start_time = time.perf_counter()
80+
conv = conv.update_votes(votes_dict, recompute=False)
81+
timings["after_load_no_compute"] = time.perf_counter() - start_time
82+
83+
# Validation: Ensure votes were actually loaded
84+
if conv.participant_count == 0 or conv.comment_count == 0:
85+
raise ValueError(
86+
f"Failed to load votes! participant_count={conv.participant_count}, "
87+
f"comment_count={conv.comment_count}"
88+
)
89+
90+
stages["after_load_no_compute"] = conv.to_dict()
91+
92+
# DEBUG: Capture the matrix that goes into PCA (only when DEBUG logging is enabled)
93+
if logger.isEnabledFor(logging.DEBUG):
94+
debug_info = {}
95+
try:
96+
# Get the clean matrix that PCA will use
97+
if hasattr(conv, '_get_clean_matrix'):
98+
clean_matrix = conv._get_clean_matrix()
99+
# Save first 5x5 section of the matrix for debugging
100+
if not clean_matrix.empty:
101+
debug_info["pca_input_matrix_sample"] = {
102+
"shape": list(clean_matrix.shape),
103+
"rows_first_10": list(clean_matrix.index[:10]),
104+
"cols_first_10": list(clean_matrix.columns[:10]),
105+
"sample_5x5": clean_matrix.iloc[:5, :5].to_dict(),
106+
"dtype": str(clean_matrix.dtypes.iloc[0] if len(clean_matrix.dtypes) > 0 else "unknown")
107+
}
108+
# Check for NaN values
109+
nan_info = {
110+
"total_cells": clean_matrix.size,
111+
"nan_count": clean_matrix.isna().sum().sum(),
112+
"nan_percentage": (clean_matrix.isna().sum().sum() / clean_matrix.size * 100) if clean_matrix.size > 0 else 0
113+
}
114+
debug_info["nan_info"] = nan_info
115+
116+
# Save debug info to .test_outputs/debug directory
117+
debug_dir = Path(__file__).parent.parent / ".test_outputs" / "debug"
118+
debug_dir.mkdir(parents=True, exist_ok=True)
119+
debug_path = debug_dir / f"pca_debug_{dataset_name}.json"
120+
with open(debug_path, "w") as f:
121+
json.dump(debug_info, f, indent=2, default=str)
122+
logger.debug(f"Saved PCA debug info to {debug_path}")
123+
except Exception as e:
124+
logger.error(f"Debug capture failed: {e}")
125+
126+
# Stage 3: After PCA computation only
127+
start_time = time.perf_counter()
128+
conv._compute_pca()
129+
timings["after_pca"] = time.perf_counter() - start_time
130+
stages["after_pca"] = conv.to_dict()
122131

123-
# Stage 4: After PCA + clustering
124-
start_time = time.perf_counter()
125-
conv._compute_pca()
126-
conv._compute_clusters()
127-
timings["after_clustering"] = time.perf_counter() - start_time
128-
stages["after_clustering"] = conv.to_dict()
132+
# Stage 4: After PCA + clustering
133+
start_time = time.perf_counter()
134+
conv._compute_pca()
135+
conv._compute_clusters()
136+
timings["after_clustering"] = time.perf_counter() - start_time
137+
stages["after_clustering"] = conv.to_dict()
129138

130139
# Stage 5: Full recompute (includes repness and participant_info)
131140
conv_full = Conversation(dataset_name, last_updated=fixed_timestamp)
@@ -159,7 +168,8 @@ def compute_all_stages_with_benchmark(
159168
dataset_name: str,
160169
votes_dict: Dict,
161170
fixed_timestamp: int,
162-
n_runs: int = 3
171+
n_runs: int = 3,
172+
skip_intermediate_stages: bool = False,
163173
) -> Dict[str, Any]:
164174
"""
165175
Compute all conversation stages multiple times and collect timing statistics.
@@ -173,6 +183,8 @@ def compute_all_stages_with_benchmark(
173183
votes_dict: Dictionary containing votes data
174184
fixed_timestamp: Fixed timestamp for reproducibility
175185
n_runs: Number of times to run the computation (default: 3)
186+
skip_intermediate_stages: If True, skip stages 1-4 (passed through to
187+
compute_all_stages).
176188
177189
Returns:
178190
Dictionary with:
@@ -187,7 +199,10 @@ def compute_all_stages_with_benchmark(
187199

188200
logger.info(f"Running {n_runs} iterations for benchmarking...")
189201
for i in range(n_runs):
190-
result = compute_all_stages(dataset_name, votes_dict, fixed_timestamp)
202+
result = compute_all_stages(
203+
dataset_name, votes_dict, fixed_timestamp,
204+
skip_intermediate_stages=skip_intermediate_stages,
205+
)
191206
if stages is None or i == n_runs - 1:
192207
# Keep the last run's stages
193208
stages = result["stages"]

delphi/tests/test_regression.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ def test_conversation_regression(dataset_name):
7171
# and different implementations may produce equivalent results with opposite signs
7272
comparer = ConversationComparer(ignore_pca_sign_flip=True)
7373

74-
# Run comparison
75-
result = comparer.compare_with_golden(dataset_name)
74+
# Run comparison — skip intermediate stages (empty, load-only, PCA-only,
75+
# PCA+clustering) since this test only checks overall_match. The stage-level
76+
# test below exercises intermediate stages individually.
77+
result = comparer.compare_with_golden(dataset_name, skip_intermediate_stages=True)
7678

7779
# Check for errors
7880
if "error" in result:

0 commit comments

Comments
 (0)