@@ -219,67 +219,63 @@ def parse_test_output(output: str) -> list[TestCount]:
219219 # The test output format is:
220220 # FAIL: test_name (step='import1', metric='queries')
221221 # AssertionError: 118 != 120 : 118 queries executed, 120 expected
222- # OR for async tasks:
222+ #
223+ # For async tasks we may see:
223224 # FAIL: test_name (step='import1', metric='async_tasks')
224- # AssertionError: 7 != 8 : 7 async tasks executed, 8 expected
225-
226- # Pattern to match the full failure block:
227- # FAIL: test_name (full.path.to.test) (step='...', metric='...')
228- # AssertionError: actual != expected : actual ... executed, expected expected
229- # The test name may include the full path in parentheses, so we extract just the method name
230- failure_pattern = re .compile (
231- r"FAIL:\s+(test_\w+)\s+\([^)]+\)\s+\(step=['\"](\w+)['\"],\s*metric=['\"](\w+)['\"]\)\s*\n"
232- r".*?AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+(?:queries|async tasks?)\s+executed,\s+\d+\s+expected" ,
233- re .MULTILINE | re .DOTALL ,
225+ # AssertionError: Expected 7 celery tasks, but 6 were created.
226+
227+ # Parse failures by splitting into individual FAIL blocks, to avoid accidentally
228+ # associating an assertion from a different FAIL with the wrong metric.
229+ fail_header = re .compile (
230+ r"^FAIL:\s+(test_\w+)\s+\([^)]+\)\s+\(step=['\"](\w+)['\"],\s*metric=['\"](\w+)['\"]\)\s*$" ,
231+ re .MULTILINE ,
234232 )
235233
236- for match in failure_pattern .finditer (output ):
234+ headers = list (fail_header .finditer (output ))
235+ for idx , match in enumerate (headers ):
237236 test_name = match .group (1 )
238237 step = match .group (2 )
239238 metric = match .group (3 )
240- actual = int (match .group (4 ))
241- expected = int (match .group (5 ))
239+
240+ block_start = match .end ()
241+ block_end = headers [idx + 1 ].start () if idx + 1 < len (headers ) else len (output )
242+ block = output [block_start :block_end ]
243+
244+ actual : int | None = None
245+ expected : int | None = None
246+
247+ if metric == "queries" :
248+ m = re .search (
249+ r"AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+queries\s+executed,\s+\d+\s+expected" ,
250+ block ,
251+ )
252+ if m :
253+ actual = int (m .group (1 ))
254+ expected = int (m .group (2 ))
255+ elif metric == "async_tasks" :
256+ # Celery task count assertions can be in a different format.
257+ m = re .search (r"AssertionError:\s+Expected\s+(\d+)\s+celery tasks?,\s+but\s+(\d+)\s+were created\." , block )
258+ if m :
259+ expected = int (m .group (1 ))
260+ actual = int (m .group (2 ))
261+ else :
262+ m = re .search (
263+ r"AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+async tasks?\s+executed,\s+\d+\s+expected" ,
264+ block ,
265+ )
266+ if m :
267+ actual = int (m .group (1 ))
268+ expected = int (m .group (2 ))
269+
270+ if actual is None or expected is None :
271+ continue
242272
243273 count = TestCount (test_name , step , metric )
244274 count .actual = actual
245275 count .expected = expected
246276 count .difference = expected - actual
247277 counts .append (count )
248278
249- # Also try a simpler pattern in case the format is slightly different
250- if not counts :
251- # Look for lines with step/metric followed by AssertionError on nearby lines
252- lines = output .split ("\n " )
253- i = 0
254- while i < len (lines ):
255- line = lines [i ]
256-
257- # Look for FAIL: test_name (may include full path in parentheses)
258- # Format: FAIL: test_name (full.path) (step='...', metric='...')
259- fail_match = re .search (r"FAIL:\s+(test_\w+)\s+\([^)]+\)\s+\(step=['\"](\w+)['\"],\s*metric=['\"](\w+)['\"]\)" , line )
260- if fail_match :
261- test_name = fail_match .group (1 )
262- step = fail_match .group (2 )
263- metric = fail_match .group (3 )
264- # Look ahead for AssertionError
265- for j in range (i , min (i + 15 , len (lines ))):
266- assertion_match = re .search (
267- r"AssertionError:\s+(\d+)\s+!=\s+(\d+)\s+:\s+\d+\s+(?:queries|async tasks?)\s+executed,\s+\d+\s+expected" ,
268- lines [j ],
269- )
270-
271- if assertion_match :
272- actual = int (assertion_match .group (1 ))
273- expected = int (assertion_match .group (2 ))
274-
275- count = TestCount (test_name , step , metric )
276- count .actual = actual
277- count .expected = expected
278- count .difference = expected - actual
279- counts .append (count )
280- break
281- i += 1
282-
283279 if counts :
284280 print (f"\n 📊 Parsed { len (counts )} count mismatch(es) from test output:" )
285281 for count in counts :
@@ -378,6 +374,27 @@ def update_test_file(counts: list[TestCount]):
378374
379375 content = TEST_FILE .read_text ()
380376
377+ def _extract_call_span (method_content : str , call_name : str ) -> tuple [int , int ] | None :
378+ """Return (start, end) indices of the first call to `call_name(...)` within method_content."""
379+ start = method_content .find (call_name )
380+ if start == - 1 :
381+ return None
382+
383+ open_paren = method_content .find ("(" , start )
384+ if open_paren == - 1 :
385+ return None
386+
387+ depth = 0
388+ for idx in range (open_paren , len (method_content )):
389+ ch = method_content [idx ]
390+ if ch == "(" :
391+ depth += 1
392+ elif ch == ")" :
393+ depth -= 1
394+ if depth == 0 :
395+ return start , idx + 1
396+ return None
397+
381398 # Create a mapping of test_name -> step_metric -> new_value
382399 updates = {}
383400 for count in counts :
@@ -419,100 +436,49 @@ def update_test_file(counts: list[TestCount]):
419436 test_method_start = test_match .start ()
420437 test_method_end = test_match .end ()
421438
422- # Try to find _import_reimport_performance call first
423- perf_call_pattern_import_reimport = re .compile (
424- r"(self\._import_reimport_performance\s*\(\s*)"
425- r"expected_num_queries1\s*=\s*(\d+)\s*,\s*"
426- r"expected_num_async_tasks1\s*=\s*(\d+)\s*,\s*"
427- r"expected_num_queries2\s*=\s*(\d+)\s*,\s*"
428- r"expected_num_async_tasks2\s*=\s*(\d+)\s*,\s*"
429- r"expected_num_queries3\s*=\s*(\d+)\s*,\s*"
430- r"expected_num_async_tasks3\s*=\s*(\d+)\s*,"
431- r"(\s*\))" ,
432- re .DOTALL ,
433- )
434-
435- # Try to find _deduplication_performance call
436- perf_call_pattern_deduplication = re .compile (
437- r"(self\._deduplication_performance\s*\(\s*)"
438- r"expected_num_queries1\s*=\s*(\d+)\s*,\s*"
439- r"expected_num_async_tasks1\s*=\s*(\d+)\s*,\s*"
440- r"expected_num_queries2\s*=\s*(\d+)\s*,\s*"
441- r"expected_num_async_tasks2\s*=\s*(\d+)\s*,"
442- r"(\s*\))" ,
443- re .DOTALL ,
444- )
445-
446- perf_match = perf_call_pattern_import_reimport .search (test_method_content )
447- method_type = "import_reimport"
439+ call_span = _extract_call_span (test_method_content , "self._import_reimport_performance" )
448440 param_map = param_map_import_reimport
449- param_order = [
450- "import1_queries" ,
451- "import1_async_tasks" ,
452- "reimport1_queries" ,
453- "reimport1_async_tasks" ,
454- "reimport2_queries" ,
455- "reimport2_async_tasks" ,
456- ]
457-
458- if not perf_match :
459- perf_match = perf_call_pattern_deduplication .search (test_method_content )
460- if perf_match :
461- method_type = "deduplication"
441+ if call_span is None :
442+ call_span = _extract_call_span (test_method_content , "self._deduplication_performance" )
443+ if call_span is not None :
462444 param_map = param_map_deduplication
463- param_order = [
464- "first_import_queries" ,
465- "first_import_async_tasks" ,
466- "second_import_queries" ,
467- "second_import_async_tasks" ,
468- ]
469445 else :
470- print (f"⚠️ Warning: Could not find _import_reimport_performance or _deduplication_performance call in { test_name } " )
446+ print (
447+ f"⚠️ Warning: Could not find _import_reimport_performance or _deduplication_performance call in { test_name } " ,
448+ )
471449 continue
472450
473- # Get the indentation from the original call (first line after opening paren)
474- call_lines = test_method_content [perf_match .start ():perf_match .end ()].split ("\n " )
475- indent = ""
476- for line in call_lines :
477- if "expected_num_queries1" in line :
478- # Extract indentation (spaces before the parameter)
479- indent_match = re .match (r"(\s*)expected_num_queries1" , line )
480- if indent_match :
481- indent = indent_match .group (1 )
482- break
483-
484- # If we couldn't find indentation, use a default
485- if not indent :
486- indent = " " # 12 spaces default
487-
488- replacement_parts = [perf_match .group (1 )] # Opening: "self._import_reimport_performance("
489- updated_params = []
490- for i , step_metric in enumerate (param_order ):
491- param_name = param_map [step_metric ]
492- old_value = int (perf_match .group (i + 2 )) # +2 because group 1 is the opening
493- if step_metric in test_updates :
494- new_value = test_updates [step_metric ]
495- if old_value != new_value :
496- updated_params .append (f"{ param_name } : { old_value } → { new_value } " )
497- else :
498- # Keep the existing value
499- new_value = old_value
451+ call_start , call_end = call_span
452+ original_call = test_method_content [call_start :call_end ]
453+ updated_call = original_call
500454
501- replacement_parts .append (f"{ indent } { param_name } ={ new_value } ," )
502-
503- # Closing parenthesis - group number depends on method type
504- closing_group = 8 if method_type == "import_reimport" else 6
505- replacement_parts .append (perf_match .group (closing_group )) # Closing parenthesis
506- replacement = "\n " .join (replacement_parts )
455+ updated_params = []
456+ for step_metric , param_name in param_map .items ():
457+ if step_metric not in test_updates :
458+ continue
459+ new_value = test_updates [step_metric ]
460+ m = re .search (rf"({ re .escape (param_name )} \s*=\s*)(\d+)" , updated_call )
461+ if not m :
462+ continue
463+ old_value = int (m .group (2 ))
464+ if old_value == new_value :
465+ continue
466+ updated_params .append (f"{ param_name } : { old_value } → { new_value } " )
467+ updated_call = re .sub (
468+ rf"({ re .escape (param_name )} \s*=\s*)\d+" ,
469+ rf"\g<1>{ new_value } " ,
470+ updated_call ,
471+ count = 1 ,
472+ )
507473
508474 if updated_params :
509475 print (f" Updated: { ', ' .join (updated_params )} " )
510476
511- # Replace the method call within the test method content
477+ # Replace the method call within the test method content (in-place; do not reformat)
512478 updated_method_content = (
513- test_method_content [: perf_match . start () ]
514- + replacement
515- + test_method_content [perf_match . end () :]
479+ test_method_content [:call_start ]
480+ + updated_call
481+ + test_method_content [call_end :]
516482 )
517483
518484 # Replace the entire test method in the original content
@@ -547,6 +513,30 @@ def verify_tests(test_class: str) -> bool:
547513 return True
548514
549515
def verify_and_get_mismatches(test_class: str) -> tuple[bool, list[TestCount]]:
    """Run the full test class and return (success, parsed mismatches)."""
    print(f"Verifying tests for {test_class}...")
    output, return_code = run_tests(test_class)

    # Bail out early if the run itself broke (import error, crash, etc.):
    # there is nothing meaningful to parse from a failed execution.
    ok, error_msg = check_test_execution_success(output, return_code)
    if not ok:
        print(f"\n❌ Test execution failed: {error_msg}")
        return False, []

    mismatches = parse_test_output(output)
    if not mismatches:
        print("\n✅ All tests pass!")
        return True, []

    # Report every remaining count mismatch so the caller can re-apply them.
    print("\n❌ Some tests still have count mismatches:")
    for mismatch in mismatches:
        print(
            f"  {mismatch.test_name} - {mismatch.step} {mismatch.metric}: "
            f"expected {mismatch.expected}, got {mismatch.actual}",
        )
    return False, mismatches
539+
550540def main ():
551541 parser = argparse .ArgumentParser (
552542 description = "Update performance test query counts" ,
@@ -657,7 +647,17 @@ def main():
657647 if all_counts :
658648 print (f"\n { '=' * 80 } " )
659649 print (f"✅ Updated { len (all_counts )} count(s) across { len ({c .test_name for c in all_counts })} test(s)" )
660- print ("\n Next step: Run --verify to ensure all tests pass" )
650+ # Some performance counts can vary depending on test ordering / keepdb state.
651+ # Do a final full-suite pass and apply any remaining mismatches so the suite passes as run in CI.
652+ print ("\n Running a final verify pass for stability..." )
653+ success , suite_mismatches = verify_and_get_mismatches (args .test_class )
654+ if not success and suite_mismatches :
655+ print ("\n Applying remaining mismatches from full-suite run..." )
656+ update_test_file (suite_mismatches )
657+ print ("\n Re-running verify..." )
658+ success , _ = verify_and_get_mismatches (args .test_class )
659+ sys .exit (0 if success else 1 )
660+ sys .exit (0 if success else 1 )
661661 else :
662662 print (f"\n { '=' * 80 } " )
663663 print ("\n ✅ No differences found. All tests are already up to date." )
0 commit comments