Skip to content

Commit 0571a2b

Browse files
fix test counts
1 parent 8790a63 commit 0571a2b

2 files changed

Lines changed: 141 additions & 139 deletions

File tree

scripts/update_performance_test_counts.py

Lines changed: 132 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -219,67 +219,63 @@ def parse_test_output(output: str) -> list[TestCount]:
219219
# The test output format is:
220220
# FAIL: test_name (step='import1', metric='queries')
221221
# AssertionError: 118 != 120 : 118 queries executed, 120 expected
222-
# OR for async tasks:
222+
#
223+
# For async tasks we may see:
223224
# FAIL: test_name (step='import1', metric='async_tasks')
224-
# AssertionError: 7 != 8 : 7 async tasks executed, 8 expected
225-
226-
# Pattern to match the full failure block:
227-
# FAIL: test_name (full.path.to.test) (step='...', metric='...')
228-
# AssertionError: actual != expected : actual ... executed, expected expected
229-
# The test name may include the full path in parentheses, so we extract just the method name
230-
failure_pattern = re.compile(
231-
r"FAIL:\s+(test_\w+)\s+\([^)]+\)\s+\(step=['\"](\w+)['\"],\s*metric=['\"](\w+)['\"]\)\s*\n"
232-
r".*?AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+(?:queries|async tasks?)\s+executed,\s+\d+\s+expected",
233-
re.MULTILINE | re.DOTALL,
225+
# AssertionError: Expected 7 celery tasks, but 6 were created.
226+
227+
# Parse failures by splitting into individual FAIL blocks, to avoid accidentally
228+
# associating an assertion from a different FAIL with the wrong metric.
229+
fail_header = re.compile(
230+
r"^FAIL:\s+(test_\w+)\s+\([^)]+\)\s+\(step=['\"](\w+)['\"],\s*metric=['\"](\w+)['\"]\)\s*$",
231+
re.MULTILINE,
234232
)
235233

236-
for match in failure_pattern.finditer(output):
234+
headers = list(fail_header.finditer(output))
235+
for idx, match in enumerate(headers):
237236
test_name = match.group(1)
238237
step = match.group(2)
239238
metric = match.group(3)
240-
actual = int(match.group(4))
241-
expected = int(match.group(5))
239+
240+
block_start = match.end()
241+
block_end = headers[idx + 1].start() if idx + 1 < len(headers) else len(output)
242+
block = output[block_start:block_end]
243+
244+
actual: int | None = None
245+
expected: int | None = None
246+
247+
if metric == "queries":
248+
m = re.search(
249+
r"AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+queries\s+executed,\s+\d+\s+expected",
250+
block,
251+
)
252+
if m:
253+
actual = int(m.group(1))
254+
expected = int(m.group(2))
255+
elif metric == "async_tasks":
256+
# Celery task count assertions can be in a different format.
257+
m = re.search(r"AssertionError:\s+Expected\s+(\d+)\s+celery tasks?,\s+but\s+(\d+)\s+were created\.", block)
258+
if m:
259+
expected = int(m.group(1))
260+
actual = int(m.group(2))
261+
else:
262+
m = re.search(
263+
r"AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+async tasks?\s+executed,\s+\d+\s+expected",
264+
block,
265+
)
266+
if m:
267+
actual = int(m.group(1))
268+
expected = int(m.group(2))
269+
270+
if actual is None or expected is None:
271+
continue
242272

243273
count = TestCount(test_name, step, metric)
244274
count.actual = actual
245275
count.expected = expected
246276
count.difference = expected - actual
247277
counts.append(count)
248278

249-
# Also try a simpler pattern in case the format is slightly different
250-
if not counts:
251-
# Look for lines with step/metric followed by AssertionError on nearby lines
252-
lines = output.split("\n")
253-
i = 0
254-
while i < len(lines):
255-
line = lines[i]
256-
257-
# Look for FAIL: test_name (may include full path in parentheses)
258-
# Format: FAIL: test_name (full.path) (step='...', metric='...')
259-
fail_match = re.search(r"FAIL:\s+(test_\w+)\s+\([^)]+\)\s+\(step=['\"](\w+)['\"],\s*metric=['\"](\w+)['\"]\)", line)
260-
if fail_match:
261-
test_name = fail_match.group(1)
262-
step = fail_match.group(2)
263-
metric = fail_match.group(3)
264-
# Look ahead for AssertionError
265-
for j in range(i, min(i + 15, len(lines))):
266-
assertion_match = re.search(
267-
r"AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+(?:queries|async tasks?)\s+executed,\s+\d+\s+expected",
268-
lines[j],
269-
)
270-
271-
if assertion_match:
272-
actual = int(assertion_match.group(1))
273-
expected = int(assertion_match.group(2))
274-
275-
count = TestCount(test_name, step, metric)
276-
count.actual = actual
277-
count.expected = expected
278-
count.difference = expected - actual
279-
counts.append(count)
280-
break
281-
i += 1
282-
283279
if counts:
284280
print(f"\n📊 Parsed {len(counts)} count mismatch(es) from test output:")
285281
for count in counts:
@@ -378,6 +374,27 @@ def update_test_file(counts: list[TestCount]):
378374

379375
content = TEST_FILE.read_text()
380376

377+
def _extract_call_span(method_content: str, call_name: str) -> tuple[int, int] | None:
378+
"""Return (start, end) indices of the first call to `call_name(...)` within method_content."""
379+
start = method_content.find(call_name)
380+
if start == -1:
381+
return None
382+
383+
open_paren = method_content.find("(", start)
384+
if open_paren == -1:
385+
return None
386+
387+
depth = 0
388+
for idx in range(open_paren, len(method_content)):
389+
ch = method_content[idx]
390+
if ch == "(":
391+
depth += 1
392+
elif ch == ")":
393+
depth -= 1
394+
if depth == 0:
395+
return start, idx + 1
396+
return None
397+
381398
# Create a mapping of test_name -> step_metric -> new_value
382399
updates = {}
383400
for count in counts:
@@ -419,100 +436,49 @@ def update_test_file(counts: list[TestCount]):
419436
test_method_start = test_match.start()
420437
test_method_end = test_match.end()
421438

422-
# Try to find _import_reimport_performance call first
423-
perf_call_pattern_import_reimport = re.compile(
424-
r"(self\._import_reimport_performance\s*\(\s*)"
425-
r"expected_num_queries1\s*=\s*(\d+)\s*,\s*"
426-
r"expected_num_async_tasks1\s*=\s*(\d+)\s*,\s*"
427-
r"expected_num_queries2\s*=\s*(\d+)\s*,\s*"
428-
r"expected_num_async_tasks2\s*=\s*(\d+)\s*,\s*"
429-
r"expected_num_queries3\s*=\s*(\d+)\s*,\s*"
430-
r"expected_num_async_tasks3\s*=\s*(\d+)\s*,"
431-
r"(\s*\))",
432-
re.DOTALL,
433-
)
434-
435-
# Try to find _deduplication_performance call
436-
perf_call_pattern_deduplication = re.compile(
437-
r"(self\._deduplication_performance\s*\(\s*)"
438-
r"expected_num_queries1\s*=\s*(\d+)\s*,\s*"
439-
r"expected_num_async_tasks1\s*=\s*(\d+)\s*,\s*"
440-
r"expected_num_queries2\s*=\s*(\d+)\s*,\s*"
441-
r"expected_num_async_tasks2\s*=\s*(\d+)\s*,"
442-
r"(\s*\))",
443-
re.DOTALL,
444-
)
445-
446-
perf_match = perf_call_pattern_import_reimport.search(test_method_content)
447-
method_type = "import_reimport"
439+
call_span = _extract_call_span(test_method_content, "self._import_reimport_performance")
448440
param_map = param_map_import_reimport
449-
param_order = [
450-
"import1_queries",
451-
"import1_async_tasks",
452-
"reimport1_queries",
453-
"reimport1_async_tasks",
454-
"reimport2_queries",
455-
"reimport2_async_tasks",
456-
]
457-
458-
if not perf_match:
459-
perf_match = perf_call_pattern_deduplication.search(test_method_content)
460-
if perf_match:
461-
method_type = "deduplication"
441+
if call_span is None:
442+
call_span = _extract_call_span(test_method_content, "self._deduplication_performance")
443+
if call_span is not None:
462444
param_map = param_map_deduplication
463-
param_order = [
464-
"first_import_queries",
465-
"first_import_async_tasks",
466-
"second_import_queries",
467-
"second_import_async_tasks",
468-
]
469445
else:
470-
print(f"⚠️ Warning: Could not find _import_reimport_performance or _deduplication_performance call in {test_name}")
446+
print(
447+
f"⚠️ Warning: Could not find _import_reimport_performance or _deduplication_performance call in {test_name}",
448+
)
471449
continue
472450

473-
# Get the indentation from the original call (first line after opening paren)
474-
call_lines = test_method_content[perf_match.start():perf_match.end()].split("\n")
475-
indent = ""
476-
for line in call_lines:
477-
if "expected_num_queries1" in line:
478-
# Extract indentation (spaces before the parameter)
479-
indent_match = re.match(r"(\s*)expected_num_queries1", line)
480-
if indent_match:
481-
indent = indent_match.group(1)
482-
break
483-
484-
# If we couldn't find indentation, use a default
485-
if not indent:
486-
indent = " " # 12 spaces default
487-
488-
replacement_parts = [perf_match.group(1)] # Opening: "self._import_reimport_performance("
489-
updated_params = []
490-
for i, step_metric in enumerate(param_order):
491-
param_name = param_map[step_metric]
492-
old_value = int(perf_match.group(i + 2)) # +2 because group 1 is the opening
493-
if step_metric in test_updates:
494-
new_value = test_updates[step_metric]
495-
if old_value != new_value:
496-
updated_params.append(f"{param_name}: {old_value}{new_value}")
497-
else:
498-
# Keep the existing value
499-
new_value = old_value
451+
call_start, call_end = call_span
452+
original_call = test_method_content[call_start:call_end]
453+
updated_call = original_call
500454

501-
replacement_parts.append(f"{indent}{param_name}={new_value},")
502-
503-
# Closing parenthesis - group number depends on method type
504-
closing_group = 8 if method_type == "import_reimport" else 6
505-
replacement_parts.append(perf_match.group(closing_group)) # Closing parenthesis
506-
replacement = "\n".join(replacement_parts)
455+
updated_params = []
456+
for step_metric, param_name in param_map.items():
457+
if step_metric not in test_updates:
458+
continue
459+
new_value = test_updates[step_metric]
460+
m = re.search(rf"({re.escape(param_name)}\s*=\s*)(\d+)", updated_call)
461+
if not m:
462+
continue
463+
old_value = int(m.group(2))
464+
if old_value == new_value:
465+
continue
466+
updated_params.append(f"{param_name}: {old_value}{new_value}")
467+
updated_call = re.sub(
468+
rf"({re.escape(param_name)}\s*=\s*)\d+",
469+
rf"\g<1>{new_value}",
470+
updated_call,
471+
count=1,
472+
)
507473

508474
if updated_params:
509475
print(f" Updated: {', '.join(updated_params)}")
510476

511-
# Replace the method call within the test method content
477+
# Replace the method call within the test method content (in-place; do not reformat)
512478
updated_method_content = (
513-
test_method_content[: perf_match.start()]
514-
+ replacement
515-
+ test_method_content[perf_match.end() :]
479+
test_method_content[:call_start]
480+
+ updated_call
481+
+ test_method_content[call_end:]
516482
)
517483

518484
# Replace the entire test method in the original content
@@ -547,6 +513,30 @@ def verify_tests(test_class: str) -> bool:
547513
return True
548514

549515

516+
def verify_and_get_mismatches(test_class: str) -> tuple[bool, list[TestCount]]:
    """Run the tests for `test_class` and report any remaining count mismatches.

    Returns a `(success, mismatches)` pair:
    - `(False, [])` when `check_test_execution_success` rejects the run;
    - `(False, mismatches)` when `parse_test_output` finds count mismatches;
    - `(True, [])` when the run succeeded and no mismatches were parsed.

    Progress and results are printed to stdout.
    """
    print(f"Verifying tests for {test_class}...")
    output, return_code = run_tests(test_class)

    # Bail out first if the run itself did not execute properly.
    executed_ok, error_msg = check_test_execution_success(output, return_code)
    if not executed_ok:
        print(f"\n❌ Test execution failed: {error_msg}")
        return False, []

    mismatches = parse_test_output(output)
    if not mismatches:
        print("\n✅ All tests pass!")
        return True, []

    # The run executed, but some expected counts still differ from the actual ones.
    print("\n❌ Some tests still have count mismatches:")
    for item in mismatches:
        location = f" {item.test_name} - {item.step} {item.metric}: "
        detail = f"expected {item.expected}, got {item.actual}"
        print(location + detail)
    return False, mismatches
538+
539+
550540
def main():
551541
parser = argparse.ArgumentParser(
552542
description="Update performance test query counts",
@@ -657,7 +647,17 @@ def main():
657647
if all_counts:
658648
print(f"\n{'=' * 80}")
659649
print(f"✅ Updated {len(all_counts)} count(s) across {len({c.test_name for c in all_counts})} test(s)")
660-
print("\nNext step: Run --verify to ensure all tests pass")
650+
# Some performance counts can vary depending on test ordering / keepdb state.
651+
# Do a final full-suite pass and apply any remaining mismatches so the suite passes as run in CI.
652+
print("\nRunning a final verify pass for stability...")
653+
success, suite_mismatches = verify_and_get_mismatches(args.test_class)
654+
if not success and suite_mismatches:
655+
print("\nApplying remaining mismatches from full-suite run...")
656+
update_test_file(suite_mismatches)
657+
print("\nRe-running verify...")
658+
success, _ = verify_and_get_mismatches(args.test_class)
659+
sys.exit(0 if success else 1)
660+
sys.exit(0 if success else 1)
661661
else:
662662
print(f"\n{'=' * 80}")
663663
print("\n✅ No differences found. All tests are already up to date.")

unittests/test_importers_performance.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -268,12 +268,14 @@ def test_import_reimport_reimport_performance_pghistory_async(self):
268268

269269

270270

271-
expected_num_queries1=306,
272-
expected_num_async_tasks1=305,
271+
272+
expected_num_queries1=305,
273+
expected_num_async_tasks1=6,
273274
expected_num_queries2=232,
274-
expected_num_async_tasks2=18,
275+
expected_num_async_tasks2=17,
275276
expected_num_queries3=114,
276-
expected_num_async_tasks3=17,
277+
expected_num_async_tasks3=16,
278+
277279

278280

279281

@@ -445,10 +447,10 @@ def test_deduplication_performance_pghistory_async(self):
445447
self.system_settings(enable_deduplication=True)
446448

447449
self._deduplication_performance(
448-
expected_num_queries1=275,
449-
expected_num_async_tasks1=8,
450+
expected_num_queries1=274,
451+
expected_num_async_tasks1=7,
450452
expected_num_queries2=185,
451-
expected_num_async_tasks2=8,
453+
expected_num_async_tasks2=7,
452454
check_duplicates=False, # Async mode - deduplication happens later
453455
)
454456

0 commit comments

Comments
 (0)