Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions apps/worker/services/test_analytics/ta_process_flakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,27 @@ def fetch_current_flakes(repo_id: int) -> dict[bytes, Flake]:
}


def get_testruns(upload: ReportSession) -> QuerySet[Testrun]:
upload_filter = Q(upload_id=upload.id)

# we won't process flakes for testruns older than 1 day
return Testrun.objects.filter(
Q(timestamp__gte=timezone.now() - timedelta(days=1)) & upload_filter
def get_testruns_for_uploads(
uploads: list[ReportSession],
) -> dict[int, list[Testrun]]:
"""Fetch all testruns for a list of uploads in a single query, grouped by upload_id."""
upload_ids = [upload.id for upload in uploads]
testruns = Testrun.objects.filter(
Q(timestamp__gte=timezone.now() - timedelta(days=1))
& Q(upload_id__in=upload_ids)
).order_by("timestamp")

result: dict[int, list[Testrun]] = {uid: [] for uid in upload_ids}
for testrun in testruns:
result[testrun.upload_id].append(testrun)
return result


def handle_pass(curr_flakes: dict[bytes, Flake], test_id: bytes):
def handle_pass(
curr_flakes: dict[bytes, Flake],
expired_flakes: list[Flake],
test_id: bytes,
):
# possible that we expire it and stop caring about it
if test_id not in curr_flakes:
return
Expand All @@ -51,7 +62,7 @@ def handle_pass(curr_flakes: dict[bytes, Flake], test_id: bytes):
curr_flakes[test_id].count += 1
if curr_flakes[test_id].recent_passes_count == 30:
curr_flakes[test_id].end_date = timezone.now()
curr_flakes[test_id].save()
expired_flakes.append(curr_flakes[test_id])
del curr_flakes[test_id]


Expand Down Expand Up @@ -81,18 +92,19 @@ def handle_failure(

@sentry_sdk.trace
def process_single_upload(
upload: ReportSession, curr_flakes: dict[bytes, Flake], repo_id: int
testruns: list[Testrun],
curr_flakes: dict[bytes, Flake],
expired_flakes: list[Flake],
repo_id: int,
):
testruns = get_testruns(upload)

for testrun in testruns:
test_id = bytes(testrun.test_id)
match testrun.outcome:
case "pass":
if test_id not in curr_flakes:
continue

handle_pass(curr_flakes, test_id)
handle_pass(curr_flakes, expired_flakes, test_id)
case "failure" | "flaky_fail" | "error":
handle_failure(curr_flakes, test_id, testrun, repo_id)
case _:
Expand All @@ -106,7 +118,7 @@ def process_flakes_for_commit(repo_id: int, commit_id: str):
log.info(
"process_flakes_for_commit: starting processing",
)
uploads = get_relevant_uploads(repo_id, commit_id)
uploads = list(get_relevant_uploads(repo_id, commit_id))

log.info(
"process_flakes_for_commit: fetched uploads",
Expand All @@ -120,8 +132,13 @@ def process_flakes_for_commit(repo_id: int, commit_id: str):
extra={"flakes": [flake.test_id.hex() for flake in curr_flakes.values()]},
)

testruns_by_upload = get_testruns_for_uploads(uploads)
expired_flakes: list[Flake] = []

for upload in uploads:
process_single_upload(upload, curr_flakes, repo_id)
process_single_upload(
testruns_by_upload[upload.id], curr_flakes, expired_flakes, repo_id
)
log.info(
"process_flakes_for_commit: processed upload",
extra={"upload": upload.id},
Expand All @@ -132,6 +149,11 @@ def process_flakes_for_commit(repo_id: int, commit_id: str):
extra={"flakes": [flake.test_id.hex() for flake in curr_flakes.values()]},
)

Flake.objects.bulk_update(
expired_flakes,
fields=["end_date", "count", "recent_passes_count"],
)

Flake.objects.bulk_create(
curr_flakes.values(),
update_conflicts=True,
Expand Down
Loading