-
Notifications
You must be signed in to change notification settings - Fork 96
Expand file tree
/
Copy pathindex.py
More file actions
1264 lines (1063 loc) · 50.2 KB
/
Copy pathindex.py
File metadata and controls
1264 lines (1063 loc) · 50.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
import json
import logging
import os
import re
import time
from datetime import datetime, timedelta
from decimal import Decimal
import boto3
# Module-level boto3 clients, created once at import time and shared by the
# functions below (sqs: cache-update queue messages; lambda_client: MLflow
# logger and aggregation invocations; athena: not used in the visible part of
# this module).
sqs = boto3.client("sqs")
athena = boto3.client("athena")
lambda_client = boto3.client("lambda")
def _invoke_mlflow_logger(test_run_id, metrics, config=None):
"""Asynchronously invoke the MLflow logger function if configured."""
mlflow_logger_arn = os.environ.get("MLFLOW_LOGGER_FUNCTION_ARN")
if not mlflow_logger_arn:
return
try:
payload = {
"experiment_name": test_run_id,
"metrics": metrics,
"params": {
"test_run_id": test_run_id,
},
"tags": {
"source": "test_results_resolver",
},
}
if config:
payload["config"] = config
lambda_client.invoke(
FunctionName=mlflow_logger_arn,
InvocationType="Event", # async, fire-and-forget
Payload=json.dumps(payload, cls=DecimalEncoder),
)
logger.info(f"Invoked MLflow logger for test run: {test_run_id}")
except Exception as e:
logger.warning(f"Failed to invoke MLflow logger for {test_run_id}: {e}")
# Custom JSON encoder to handle Decimal objects from DynamoDB
class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``Decimal`` values as floats.

    Every other unsupported type is deferred to ``json.JSONEncoder.default``,
    which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, Decimal):
            return super().default(obj)
        return float(obj)
# Root logger for this module; the level is taken from the LOG_LEVEL
# environment variable and defaults to "INFO" when it is not set.
logger = logging.getLogger()
logger.setLevel(os.environ.get("LOG_LEVEL", "INFO"))
# SQL-injection defense for Athena queries
# -----------------------------------------
# The Athena f-string queries in this file (marked `# nosec B608`) interpolate
# identifier-like values (test_run_id, database name) that Athena does NOT
# support as bind parameters. To prevent SQL injection we enforce a strict
# allow-list on every such value BEFORE it is ever placed in a query:
#
# - Pattern: ^[a-zA-Z0-9_\-./]+$ (identifiers, UUID fragments, S3-style paths)
# - No quotes (' "), no semicolons, no whitespace, no SQL metacharacters,
# no comment markers (--, /* */), no parentheses — nothing that can
# escape an identifier context or terminate a statement.
# - Called at the top of every resolver that builds a query; on failure
# raises ValueError and the query is never executed.
#
# Bandit's B608 flags string-built SQL generally; it cannot see that the
# interpolated values are constrained to this grammar. Each `# nosec B608`
# annotation in this file is therefore justified by a preceding
# `_validate_sql_input()` call on every interpolated value.
_SAFE_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_\-./]+$")
def _validate_sql_input(value, name):
"""Validate that a value is safe for use in Athena SQL identifier context.
See the module-level comment above for the rationale and threat model.
"""
if not value or not _SAFE_ID_PATTERN.match(value):
raise ValueError(f"{name} contains invalid characters: {value}")
return value
# Shared DynamoDB resource; functions below call dynamodb.Table(...) with the
# table name from the TRACKING_TABLE environment variable.
dynamodb = boto3.resource("dynamodb")
def handler(event, context):
    """Handle both GraphQL resolver and SQS events.

    Events containing a ``"Records"`` key are treated as SQS batches and passed
    to ``handle_cache_update_request``. Everything else is treated as a GraphQL
    resolver invocation and dispatched on ``event["info"]["fieldName"]``:
    ``getTestRuns``, ``getTestRun``, ``getTestRunStatus`` or ``compareTestRuns``.

    Args:
        event: Invocation payload (SQS batch or GraphQL resolver event).
        context: Lambda context object; only forwarded to the SQS handler.

    Returns:
        Whatever the selected handler/resolver returns.

    Raises:
        ValueError: If the GraphQL field name is not one of the supported ones.
    """
    # Check if this is an SQS event
    if "Records" in event:
        return handle_cache_update_request(event, context)

    # Otherwise handle as GraphQL resolver
    field_name = event["info"]["fieldName"]
    # An explicit null "arguments" must behave like a missing one.
    args = event.get("arguments") or {}

    if field_name == "getTestRuns":
        start_date_time = args.get("startDateTime")
        end_date_time = args.get("endDateTime")

        if start_date_time and end_date_time:
            start_iso, end_iso = start_date_time, end_date_time
        else:
            # A client may send timePeriodHours as an explicit null; dict.get's
            # default only covers a *missing* key, so handle None separately
            # (0 stays a valid, if empty, window).
            time_period_hours = args.get("timePeriodHours")
            if time_period_hours is None:
                time_period_hours = 2

            # Sample the clock once so both ends of the window are derived
            # from the same instant.
            now = datetime.utcnow()
            end_iso = now.isoformat() + "Z"
            start_iso = (now - timedelta(hours=time_period_hours)).isoformat() + "Z"

        logger.info(f"Processing getTestRuns request: {start_iso} → {end_iso}")
        return get_test_runs(start_iso, end_iso)

    elif field_name == "getTestRun":
        test_run_id = args["testRunId"]
        logger.info(f"Processing getTestRun request for test run: {test_run_id}")
        return get_test_results(test_run_id)

    elif field_name == "getTestRunStatus":
        test_run_id = args["testRunId"]
        logger.info(f"Processing getTestRunStatus request for test run: {test_run_id}")
        return get_test_run_status(test_run_id)

    elif field_name == "compareTestRuns":
        test_run_ids = args["testRunIds"]
        logger.info(f"Processing compareTestRuns request for test runs: {test_run_ids}")
        return compare_test_runs(test_run_ids)

    raise ValueError(f"Unknown field: {field_name}")
def handle_cache_update_request(event, context):
    """Compute and cache aggregate metrics for every SQS record in the batch.

    Each record body is JSON containing ``testRunId``. The aggregated metrics
    are stored as ``testRunResult`` on the test run's metadata item. Errors are
    logged per record and never re-raised, so one bad message does not stop
    the rest of the batch.

    Args:
        event: SQS event with a ``"Records"`` list.
        context: Lambda context (unused).
    """
    for record in event["Records"]:
        try:
            test_run_id = json.loads(record["body"])["testRunId"]
            logger.info(f"Processing cache update for test run: {test_run_id}")

            # Calculate metrics
            aggregated = _aggregate_test_run_metrics(test_run_id)
            pick = aggregated.get

            # Cache the metrics (including new confidence_metrics from Stickler v0.4.0+)
            cache_entry = {
                "overallAccuracy": pick("overall_accuracy"),
                "weightedOverallScores": pick("weighted_overall_scores", {}),
                "averageConfidence": pick("average_confidence"),
                "confidenceMetrics": pick("confidence_metrics"),
                "accuracyBreakdown": pick("accuracy_breakdown", {}),
                "confusionMatrix": pick("confusion_matrix", {}),
                "fieldMetrics": pick("field_metrics", {}),
                "splitClassificationMetrics": pick("split_classification_metrics", {}),
                "totalCost": pick("total_cost", 0),
                "costBreakdown": pick("cost_breakdown", {}),
            }

            tracking_table = dynamodb.Table(os.environ["TRACKING_TABLE"])  # type: ignore[attr-defined]
            tracking_table.update_item(
                Key={"PK": f"testrun#{test_run_id}", "SK": "metadata"},
                UpdateExpression="SET testRunResult = :metrics",
                ExpressionAttributeValues={":metrics": float_to_decimal(cache_entry)},
            )
            logger.info(f"Successfully cached metrics for test run: {test_run_id}")
        except Exception as e:
            # Don't raise - let other messages in batch continue processing
            logger.error(
                f"Failed to process cache update for {record.get('body', 'unknown')}: {e}"
            )
def float_to_decimal(obj):
    """Return a copy of ``obj`` with every float replaced by a ``Decimal``.

    Dicts and lists are walked recursively (new containers are returned);
    floats are converted via their ``str`` form; anything else is returned
    unchanged.
    """
    if isinstance(obj, dict):
        return {key: float_to_decimal(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return list(map(float_to_decimal, obj))
    if isinstance(obj, float):
        return Decimal(str(obj))
    return obj
def compare_test_runs(test_run_ids):
    """Collect results and configs for several test runs for comparison.

    Args:
        test_run_ids: Identifiers of the test runs to compare.

    Returns:
        ``{"metrics": ..., "configs": ...}``. When fewer than two ids are
        given, or fewer than two runs yield results, both values are empty
        lists. Otherwise ``metrics`` maps each test run id to its result and
        ``configs`` is the output of ``_build_config_comparison``.
    """
    logger.info(f"Comparing test runs: {test_run_ids}")

    if not test_run_ids or len(test_run_ids) < 2:
        requested = len(test_run_ids) if test_run_ids else 0
        logger.warning(f"Insufficient test runs for comparison: {requested}")
        return {"metrics": [], "configs": []}

    # Get results for each test run
    found_results = []
    run_configs = []
    for run_id in test_run_ids:
        logger.info(f"Getting results for test run: {run_id}")
        run_result = get_test_results(run_id)
        if not run_result:
            logger.warning(f"No results found for test run: {run_id}")
            continue

        logger.info(f"Found results for {run_id}: {run_result.keys()}")
        found_results.append(run_result)
        run_configs.append(
            {"testRunId": run_id, "config": _get_test_run_config(run_id)}
        )

    logger.info(f"Total results found: {len(found_results)}")

    if len(found_results) < 2:
        logger.warning(f"Insufficient results for comparison: {len(found_results)}")
        return {"metrics": [], "configs": []}

    metrics_by_run = {entry["testRunId"]: entry for entry in found_results}
    config_diff = _build_config_comparison(run_configs)

    logger.info(f"Configs data: {run_configs}")
    logger.info(f"Config comparison result: {config_diff}")

    comparison = {"metrics": metrics_by_run, "configs": config_diff}
    logger.info(f"Final comparison result: {comparison}")
    return comparison
def _format_datetime(dt_str):
"""Format datetime string for GraphQL AWSDateTime type"""
if not dt_str:
return None
# Add Z suffix if not present
return dt_str + "Z" if not dt_str.endswith("Z") else dt_str
def _count_completed_documents(table, test_run_id, files):
"""
Count how many documents completed evaluation successfully.
Uses batch_get_item for efficiency instead of sequential get_item calls.
Args:
table: DynamoDB table resource
test_run_id: Test run identifier
files: List of file names in the test run
Returns:
int: Number of documents with EvaluationStatus='COMPLETED'
"""
if not files:
return 0
completed_count = 0
table_name = table.table_name
dynamodb_client = boto3.client('dynamodb')
# Build object keys
object_keys = [f"{test_run_id}/{file_name}" for file_name in files]
# DynamoDB batch_get_item supports up to 100 keys per batch
batch_size = 100
for i in range(0, len(object_keys), batch_size):
batch = object_keys[i:i + batch_size]
keys = [{'PK': {'S': f"doc#{key}"}, 'SK': {'S': 'none'}} for key in batch]
try:
response = dynamodb_client.batch_get_item(
RequestItems={
table_name: {
'Keys': keys
}
}
)
# Count completed evaluations
for item in response.get('Responses', {}).get(table_name, []):
eval_status = item.get('EvaluationStatus', {}).get('S', '').upper()
if eval_status == 'COMPLETED':
completed_count += 1
# Handle unprocessed keys (throttling, etc.)
unprocessed = response.get('UnprocessedKeys', {}).get(table_name, {})
if unprocessed:
logger.warning(f"Batch get had {len(unprocessed.get('Keys', []))} unprocessed keys")
except Exception as e:
logger.error(f"Batch get failed for batch starting at index {i}: {str(e)}")
return completed_count
def _cached_metrics_need_refresh(cached_metrics):
    """Return True when cached metrics lack required keys or use the old list format."""
    return (
        "splitClassificationMetrics" not in cached_metrics
        or "confusionMatrix" not in cached_metrics
        or "fieldMetrics" not in cached_metrics
        or isinstance(cached_metrics.get("weightedOverallScores"), list)
    )


def _completed_files_for_cached_result(table, test_run_id, metadata, current_status):
    """Return completedFiles for a cached result, counting documents once for aborted runs."""
    completed_files_count = metadata.get("CompletedFiles", 0)

    # Only re-count if we haven't counted before (tracked by CompletedFilesCounted flag)
    if current_status != "ABORTED" or metadata.get("CompletedFilesCounted", False):
        return completed_files_count

    files = metadata.get("Files", [])
    if not files:
        return completed_files_count

    completed_files_count = _count_completed_documents(table, test_run_id, files)
    logger.info(f"Counted {completed_files_count} completed documents for aborted test run {test_run_id}")

    # Persist the count and flag to database (best effort)
    try:
        table.update_item(
            Key={"PK": f"testrun#{test_run_id}", "SK": "metadata"},
            UpdateExpression="SET CompletedFiles = :completed_files, CompletedFilesCounted = :counted",
            ExpressionAttributeValues={
                ":completed_files": completed_files_count,
                ":counted": True,
            },
        )
        logger.info(f"Updated CompletedFiles to {completed_files_count} and set CompletedFilesCounted=True for test run {test_run_id}")
    except Exception as e:
        logger.warning(f"Failed to update CompletedFiles for {test_run_id}: {str(e)}")

    return completed_files_count


def _test_run_record(test_run_id, metadata, status, completed_files, metrics=None):
    """Build the TestRun dict from metadata, optionally with metric fields in the middle."""
    record = {
        "testRunId": test_run_id,
        "testSetId": metadata.get("TestSetId"),
        "testSetName": metadata.get("TestSetName"),
        "status": status,
        "filesCount": metadata.get("FilesCount", 0),
        "completedFiles": completed_files,
        "failedFiles": metadata.get("FailedFiles", 0),
    }
    if metrics:
        record.update(metrics)
    record.update(
        {
            "createdAt": _format_datetime(metadata.get("CreatedAt")),
            "completedAt": _format_datetime(metadata.get("CompletedAt")),
            "context": metadata.get("Context"),
            "configVersion": metadata.get("ConfigVersion"),
        }
    )
    return record


def get_test_results(test_run_id):
    """Get detailed test results for a specific test run.

    Reads the test run's metadata item. If its status is not yet terminal
    (``COMPLETE``, ``PARTIAL_COMPLETE`` or ``ABORTED``) the status is
    recomputed through ``get_test_run_status`` and the metadata re-read.

    Args:
        test_run_id: Test run identifier.

    Returns:
        dict: When complete cached metrics (``testRunResult``) exist, the full
        result including metrics and ``config``. When metrics are missing, or
        cached in an incomplete/outdated shape, a partial result containing
        only the metadata-derived fields.

    Raises:
        ValueError: If the test run does not exist or its status is still not
            terminal after the refresh.
    """
    terminal_statuses = ["COMPLETE", "PARTIAL_COMPLETE", "ABORTED"]
    table = dynamodb.Table(os.environ["TRACKING_TABLE"])  # type: ignore[attr-defined]
    metadata_key = {"PK": f"testrun#{test_run_id}", "SK": "metadata"}

    # Get test run metadata
    response = table.get_item(Key=metadata_key)
    if "Item" not in response:
        raise ValueError(f"Test run {test_run_id} not found")

    metadata = response["Item"]
    current_status = metadata.get("Status")

    # Update status if not completed
    if current_status not in terminal_statuses:
        status_result = get_test_run_status(test_run_id)
        if status_result:
            current_status = status_result["status"]
            # Refresh metadata after status update
            response = table.get_item(Key=metadata_key)
            if "Item" in response:
                metadata = response["Item"]

        # Raise error if status is still not complete
        if current_status not in terminal_statuses:
            raise ValueError(
                f"Test run {test_run_id} is not complete. Current status: {current_status}"
            )

    # Check if cached results exist and are complete
    cached_metrics = metadata.get("testRunResult")
    if cached_metrics is not None:
        logger.info(f"Retrieved cached metrics for test run: {test_run_id}")

        if not _cached_metrics_need_refresh(cached_metrics):
            # For ABORTED status, count completed files on first call and persist to DB
            completed_files_count = _completed_files_for_cached_result(
                table, test_run_id, metadata, current_status
            )

            # Use cached metrics but get dynamic fields from current metadata
            result = _test_run_record(
                test_run_id,
                metadata,
                current_status,
                completed_files_count,
                metrics={
                    "overallAccuracy": cached_metrics.get("overallAccuracy"),
                    "weightedOverallScores": cached_metrics.get(
                        "weightedOverallScores", {}
                    ),
                    "averageConfidence": cached_metrics.get("averageConfidence"),
                    "confidenceMetrics": cached_metrics.get("confidenceMetrics"),
                    "accuracyBreakdown": cached_metrics.get("accuracyBreakdown", {}),
                    "confusionMatrix": cached_metrics.get("confusionMatrix", {}),
                    "fieldMetrics": cached_metrics.get("fieldMetrics", {}),
                    "splitClassificationMetrics": cached_metrics.get(
                        "splitClassificationMetrics", {}
                    ),
                    "totalCost": cached_metrics.get("totalCost", 0),
                    "costBreakdown": cached_metrics.get("costBreakdown", {}),
                },
            )
            result["config"] = _get_test_run_config(test_run_id)
            return result

        # The cached entry is missing required keys or uses the old list format.
        # This function has no aggregation step of its own, and previously fell
        # off the end here and returned None. Return the same structured
        # partial result as the "no cache yet" case instead.
        # NOTE(review): nothing in this function re-queues aggregation for a
        # stale cache entry — confirm whether a refresh should be triggered.
        logger.info(
            f"Cached metrics incomplete or outdated for test run: {test_run_id}; "
            "returning result without metrics"
        )
    elif current_status == "ABORTED":
        logger.info(
            f"Test run {test_run_id} aborted; aggregate metrics not yet available"
        )
    else:
        # No aggregate metrics have been cached yet. Don't raise — return a
        # structured partial TestRun so the caller still gets the run's status.
        logger.info(
            f"Test run {test_run_id} processing complete; "
            "aggregate metrics not yet available (evaluation in progress)"
        )

    return _test_run_record(
        test_run_id, metadata, current_status, metadata.get("CompletedFiles", 0)
    )
def _query_test_runs_from_gsi(table, start_iso, end_iso):
    """Query test runs from TypeDateIndex GSI instead of scanning the full table.

    Uses GSI to find testrun keys efficiently, then BatchGetItem for full records
    (GSI projection doesn't include all fields like Context, ConfigVersion, etc.).
    Keys that BatchGetItem reports as unprocessed are re-requested with a
    short, bounded backoff rather than silently dropped.
    Falls back to a filtered scan if the GSI query returns no results (backfill
    may not be complete) or if the GSI/BatchGetItem path raises.

    Args:
        table: DynamoDB table resource for the tracking table.
        start_iso: Inclusive lower bound of the time range (ISO string).
        end_iso: Inclusive upper bound of the time range (ISO string).

    Returns:
        list: Full test run metadata items, in no guaranteed order.
    """
    from boto3.dynamodb.conditions import Key

    gsi_items = []
    query_kwargs = {
        "IndexName": "TypeDateIndex",
        "KeyConditionExpression": Key("ItemType").eq("testrun")
        & Key("InitialEventTime").between(start_iso, end_iso),
        "ScanIndexForward": False,  # Newest first
        "ProjectionExpression": "PK, SK",
    }

    try:
        while True:
            response = table.query(**query_kwargs)
            gsi_items.extend(response.get("Items", []))
            if "LastEvaluatedKey" not in response:
                break
            query_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]

        logger.info(f"GSI query returned {len(gsi_items)} test run keys")

        # If GSI returned results, fetch full records via BatchGetItem
        if gsi_items:
            items = []
            keys = [{"PK": item["PK"], "SK": item["SK"]} for item in gsi_items]
            table_name = table.table_name
            # Build the resource once instead of once per batch.
            dynamodb_resource = boto3.resource("dynamodb")
            # Upper bound on requests per batch (initial call + retries)
            max_attempts = 5

            # DynamoDB BatchGetItem supports max 100 keys per call
            for i in range(0, len(keys), 100):
                pending = {table_name: {"Keys": keys[i : i + 100]}}
                for attempt in range(1, max_attempts + 1):
                    batch_response = dynamodb_resource.batch_get_item(
                        RequestItems=pending
                    )
                    items.extend(
                        batch_response.get("Responses", {}).get(table_name, [])
                    )
                    # UnprocessedKeys has the same shape as RequestItems.
                    pending = batch_response.get("UnprocessedKeys") or {}
                    if not pending:
                        break
                    if attempt < max_attempts:
                        time.sleep(min(0.05 * 2**attempt, 1.0))
                else:
                    # Loop ended without a break: some keys never came back.
                    leftover = len(pending.get(table_name, {}).get("Keys", []))
                    logger.warning(
                        f"BatchGetItem still had {leftover} unprocessed keys after {max_attempts} attempts"
                    )

            logger.info(f"BatchGetItem returned {len(items)} full test run records")
            return items

        # Fallback: GSI may not have ItemType yet (backfill pending).
        # Try scan with CreatedAt filter as fallback.
        logger.info(
            "GSI returned 0 results, falling back to scan (backfill may be pending)"
        )
    except Exception as e:
        logger.warning(f"GSI query failed, falling back to scan: {e}")

    # Fallback scan
    items = []
    scan_kwargs = {
        "FilterExpression": "begins_with(PK, :pk) AND SK = :sk AND CreatedAt >= :start AND CreatedAt <= :end",
        "ExpressionAttributeValues": {
            ":pk": "testrun#",
            ":sk": "metadata",
            ":start": start_iso,
            ":end": end_iso,
        },
    }
    while True:
        response = table.scan(**scan_kwargs)
        items.extend(response.get("Items", []))
        if "LastEvaluatedKey" not in response:
            break
        scan_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]

    logger.info(f"Fallback scan returned {len(items)} test runs")
    return items
def _build_test_run_list(items):
"""Build sorted test run list from raw DynamoDB items."""
test_runs = []
for item in items:
display_status = item.get("Status")
# Show EVALUATING for completed tests without metrics, but keep ABORTED as-is
if display_status in ["COMPLETE", "PARTIAL_COMPLETE"] and not item.get(
"testRunResult"
):
display_status = "EVALUATING"
test_runs.append(
{
"testRunId": item["TestRunId"],
"testSetId": item.get("TestSetId"),
"testSetName": item.get("TestSetName"),
"status": display_status,
"filesCount": item.get("FilesCount", 0),
"completedFiles": item.get("CompletedFiles", 0),
"failedFiles": item.get("FailedFiles", 0),
"createdAt": _format_datetime(item.get("CreatedAt")),
"completedAt": _format_datetime(item.get("CompletedAt")),
"context": item.get("Context"),
"configVersion": item.get("ConfigVersion"),
}
)
test_runs.sort(
key=lambda r: r.get("createdAt") or "1970-01-01T00:00:00Z", reverse=True
)
return test_runs
def get_test_runs(start_iso, end_iso):
    """Return the test run summaries whose timestamps fall in [start_iso, end_iso].

    Looks the runs up in the tracking table via ``_query_test_runs_from_gsi``
    and shapes them with ``_build_test_run_list``.
    """
    tracking_table = dynamodb.Table(os.environ["TRACKING_TABLE"])  # type: ignore[attr-defined]
    logger.info(f"Fetching test runs between: {start_iso} and {end_iso}")

    records = _query_test_runs_from_gsi(tracking_table, start_iso, end_iso)
    logger.info(f"Test runs found: {len(records)}")

    return _build_test_run_list(records)
def _calculate_completed_at(test_run_id, files, table):
"""Calculate completedAt timestamp from document CompletionTime"""
latest_completion_time = None
for file_key in files:
doc_response = table.get_item(
Key={"PK": f"doc#{test_run_id}/{file_key}", "SK": "none"}
)
if "Item" in doc_response:
doc_item = doc_response["Item"]
completion_time = doc_item.get("CompletionTime")
if completion_time:
completion_time = completion_time.replace("+00:00", "Z")
if (
not latest_completion_time
or completion_time > latest_completion_time
):
latest_completion_time = completion_time
return latest_completion_time
def get_test_run_status(test_run_id):
    """Compute the current status of a test run from its documents.

    Reads the test run metadata, then each document item, and tallies
    completed / failed / evaluating / queued files from ``ObjectStatus`` and
    ``EvaluationStatus``. When the derived status differs from the stored one,
    the metadata item is updated and, for a newly finished run without cached
    metrics, a cache-update message is sent to the SQS queue named by
    ``TEST_RESULT_CACHE_UPDATE_QUEUE_URL``.

    Returns:
        dict with ``testRunId``, ``status``, ``filesCount``, ``completedFiles``,
        ``failedFiles``, ``evaluatingFiles`` and ``progress``; or ``None`` when
        the metadata item is missing or any error occurs. A finished run whose
        metadata has no ``testRunResult`` is reported as ``EVALUATING``. Runs
        stored as ``ABORTED`` are returned from stored values without
        re-checking documents.
    """
    table = dynamodb.Table(os.environ["TRACKING_TABLE"])  # type: ignore[attr-defined]
    finished_statuses = ("COMPLETE", "PARTIAL_COMPLETE")

    def classify(doc_status, eval_status):
        """Map a document's statuses to (tally bucket or None, log note)."""
        if doc_status == "COMPLETED":
            # Processing finished; the evaluation state decides the bucket.
            if eval_status == "COMPLETED":
                return "completed", "counted as completed"
            if eval_status == "RUNNING":
                return "evaluating", "counted as evaluating"
            if eval_status is None:
                # Document completed but evaluation not started yet
                return "evaluating", "counted as evaluating (eval not started)"
            if eval_status == "FAILED":
                return "failed", "counted as failed (eval failed)"
            if eval_status == "NO_BASELINE":
                # No baseline data available - count as completed
                return "completed", "counted as completed (no baseline data)"
            # Unknown evaluation status - count as evaluating
            return (
                "evaluating",
                f"counted as evaluating (unknown eval status: {eval_status})",
            )
        if doc_status == "FAILED":
            return "failed", "counted as failed"
        if doc_status == "ABORTED":
            # Aborted documents are counted as processing failures
            return "failed", "counted as failed (aborted)"
        if doc_status == "QUEUED":
            return "queued", "counted as queued"
        # Any other status: still in flight, not tallied anywhere
        return None, f"still processing (status: {doc_status})"

    try:
        logger.info(f"Getting test run status for: {test_run_id}")
        run_key = {"PK": f"testrun#{test_run_id}", "SK": "metadata"}

        # Get test run metadata
        response = table.get_item(Key=run_key)
        if "Item" not in response:
            logger.warning(f"Test run metadata not found for: {test_run_id}")
            return None

        item = response["Item"]
        files = item.get("Files", [])
        files_count = item.get("FilesCount", 0)
        logger.info(f"Test run {test_run_id}: Found {files_count} files")

        # A manually aborted run keeps its stored numbers; nothing is recalculated.
        stored_status = item.get("Status", "RUNNING")
        if stored_status == "ABORTED":
            logger.info(f"Test run {test_run_id} is ABORTED, returning stored status")
            return {
                "testRunId": test_run_id,
                "status": "ABORTED",
                "progress": 100,
                "completedFiles": item.get("CompletedFiles", 0),
                "filesCount": files_count,
                "evaluatingFiles": 0,
                "failedFiles": item.get("FailedFiles", 0),
            }

        # Always check actual document status from tracking table.
        # "failed" holds only processing failures found during this pass.
        tally = {"completed": 0, "failed": 0, "evaluating": 0, "queued": 0}
        for file_key in files:
            logger.info(f"Checking file: {file_key} for test run: {test_run_id}")
            doc_response = table.get_item(
                Key={"PK": f"doc#{test_run_id}/{file_key}", "SK": "none"}
            )
            if "Item" not in doc_response:
                logger.warning(f"Document not found: doc#{test_run_id}/{file_key}")
                # Count missing documents as queued (not yet created)
                tally["queued"] += 1
                continue

            doc = doc_response["Item"]
            doc_status = doc.get("ObjectStatus", "QUEUED")
            eval_status = doc.get("EvaluationStatus")
            logger.info(
                f"File {file_key}: ObjectStatus={doc_status}, EvaluationStatus={eval_status}"
            )

            bucket, note = classify(doc_status, eval_status)
            if bucket is not None:
                tally[bucket] += 1
            logger.info(f"File {file_key}: {note}")

        completed_files = tally["completed"]
        processing_failed_files = tally["failed"]
        evaluating_files = tally["evaluating"]
        queued_files = tally["queued"]

        # Total failures = stored baseline failures + failures seen in this pass
        baseline_failed_files = item.get("BaselineFailedFiles", 0)
        total_failed_files = baseline_failed_files + processing_failed_files

        logger.info(
            f"Test run {test_run_id} counts: completed={completed_files}, processing_failed={processing_failed_files}, baseline_failed={baseline_failed_files}, total_failed={total_failed_files}, evaluating={evaluating_files}, queued={queued_files}, total={files_count}"
        )

        # Determine overall test run status based on document and evaluation states
        accounted_files = completed_files + total_failed_files + evaluating_files
        if (
            completed_files == files_count
            and files_count > 0
            and total_failed_files == 0
        ):
            overall_status = "COMPLETE"
        elif total_failed_files > 0 and accounted_files == files_count:
            overall_status = "PARTIAL_COMPLETE"
        elif evaluating_files > 0:
            overall_status = "EVALUATING"
        elif queued_files == files_count:
            overall_status = "QUEUED"  # All files are still queued
        elif accounted_files + queued_files < files_count:
            overall_status = "RUNNING"  # Some files are actively processing
        else:
            overall_status = stored_status

        run_finished = overall_status in finished_statuses
        metrics_missing = not item.get("testRunResult")

        # Auto-update database metadata if calculated status differs from stored status
        if overall_status != stored_status:
            # Derive completedAt from document completion times when finishing
            calculated_completed_at = item.get("CompletedAt")
            if run_finished and not calculated_completed_at:
                calculated_completed_at = _calculate_completed_at(
                    test_run_id, files, table
                )

            logger.info(
                f"Auto-updating test run {test_run_id} status from {stored_status} to {overall_status}"
            )
            try:
                table.update_item(
                    Key=run_key,
                    UpdateExpression="SET #status = :status, #completedAt = :completedAt, CompletedFiles = :completedFiles, FailedFiles = :failedFiles",
                    ExpressionAttributeNames={
                        "#status": "Status",
                        "#completedAt": "CompletedAt",
                    },
                    ExpressionAttributeValues={
                        ":status": overall_status,
                        ":completedAt": calculated_completed_at,
                        ":completedFiles": completed_files,
                        ":failedFiles": total_failed_files,
                    },
                )
                logger.info(
                    f"Successfully updated test run {test_run_id} status to {overall_status}"
                )

                # Queue metric calculation for completed test runs
                if run_finished and metrics_missing:
                    try:
                        queue_url = os.environ.get("TEST_RESULT_CACHE_UPDATE_QUEUE_URL")
                        if queue_url:
                            sqs.send_message(
                                QueueUrl=queue_url,
                                MessageBody=json.dumps({"testRunId": test_run_id}),
                            )
                            logger.info(
                                f"Queued cache update for test run: {test_run_id}"
                            )
                    except Exception as e:
                        logger.warning(
                            f"Failed to queue cache update for {test_run_id}: {e}"
                        )
            except Exception as e:
                logger.error(
                    f"Failed to auto-update test run {test_run_id} status: {e}"
                )

        # Report EVALUATING to caller until cached metrics are available
        display_status = (
            "EVALUATING" if run_finished and metrics_missing else overall_status
        )

        if files_count > 0:
            progress = (completed_files + total_failed_files) / files_count * 100
        else:
            progress = 0

        result = {
            "testRunId": test_run_id,
            "status": display_status,
            "filesCount": files_count,
            "completedFiles": completed_files,
            "failedFiles": total_failed_files,
            "evaluatingFiles": evaluating_files,
            "progress": progress,
        }
        logger.info(f"Test run {test_run_id} final result: {result}")
        return result

    except Exception as e:
        logger.error(f"Error getting test run status for {test_run_id}: {e}")
        return None
def _aggregate_test_run_metrics(test_run_id):
    """Aggregate a test run's metrics, preferring Stickler over Athena.

    When ``TEST_EXECUTION_AGGREGATION_FUNCTION_ARN`` is set, that Lambda is
    invoked synchronously. If it answers with status 200 and a positive
    ``document_count``, its metrics are merged with split-classification
    metrics and cost data from Athena. In every other case (variable unset,
    error response, empty metrics, exception) the Athena-only aggregation is
    used. The chosen result is also forwarded to the MLflow logger.

    Returns:
        dict: Aggregated metrics with snake_case keys.
    """
    # Fetch config for MLflow logging (best-effort, don't block on failure)
    try:
        run_config = _get_test_run_config(test_run_id)
    except Exception as e:
        run_config = None
        logger.warning(f"Failed to fetch config for MLflow logging: {e}")

    # Try Stickler-based aggregation via Lambda function
    aggregation_arn = os.environ.get("TEST_EXECUTION_AGGREGATION_FUNCTION_ARN")
    if not aggregation_arn:
        logger.info(
            "TEST_EXECUTION_AGGREGATION_FUNCTION_ARN not set, using Athena aggregation"
        )
    else:
        try:
            # Invoke the test execution aggregation function
            invocation = lambda_client.invoke(
                FunctionName=aggregation_arn,
                InvocationType="RequestResponse",
                Payload=json.dumps({"test_run_id": test_run_id}),
            )

            # Parse response
            payload = json.loads(invocation["Payload"].read())
            if payload.get("statusCode") != 200:
                logger.warning(f"Test execution aggregation returned error: {payload}")
            else:
                stickler_metrics = json.loads(payload["body"])

                # Only a non-empty Stickler result is usable
                if stickler_metrics.get("document_count", 0) > 0:
                    logger.info(
                        f"Using Stickler aggregation for test run {test_run_id}"
                    )

                    # Split metrics and cost still come from Athena
                    athena_metrics = _get_evaluation_metrics_from_athena(test_run_id)
                    cost_data = _get_cost_data_from_athena(test_run_id)

                    # Prefer Stickler confidence over Athena (Stickler v0.4.0+ has better calibration)
                    stickler_confidence = stickler_metrics.get("average_confidence")
                    athena_confidence = athena_metrics.get("average_confidence")
                    if stickler_confidence is not None:
                        confidence_source = "Stickler"
                        avg_confidence = stickler_confidence
                    else:
                        confidence_source = "Athena"
                        avg_confidence = athena_confidence

                    # Merge Stickler metrics with Athena split metrics
                    merged_metrics = {
                        **stickler_metrics,
                        "average_confidence": avg_confidence,
                        "split_classification_metrics": athena_metrics.get(
                            "split_classification_metrics", {}
                        ),
                        "total_cost": cost_data.get("total_cost", 0),
                        "cost_breakdown": cost_data.get("cost_breakdown", {}),
                    }

                    logger.info(
                        f"Confidence source for {test_run_id}: "
                        f"{confidence_source} "
                        f"(value: {avg_confidence})"
                    )
                    _invoke_mlflow_logger(
                        test_run_id, merged_metrics, config=run_config
                    )
                    return merged_metrics

                logger.warning(
                    f"Test execution aggregation returned empty metrics (document_count=0) for {test_run_id}, falling back to Athena"
                )
        except Exception as e:
            logger.error(
                f"Test execution aggregation Lambda failed for {test_run_id}, falling back to Athena: {e}"
            )

    # Fallback to Athena-based aggregation
    logger.info(f"Using Athena aggregation for test run {test_run_id}")
    evaluation_metrics = _get_evaluation_metrics_from_athena(test_run_id)
    cost_data = _get_cost_data_from_athena(test_run_id)

    athena_result = {
        "overall_accuracy": evaluation_metrics.get("overall_accuracy"),
        "weighted_overall_scores": evaluation_metrics.get(
            "weighted_overall_scores", {}
        ),
        "average_confidence": evaluation_metrics.get("average_confidence"),
        "accuracy_breakdown": evaluation_metrics.get("accuracy_breakdown", {}),
        "split_classification_metrics": evaluation_metrics.get(
            "split_classification_metrics", {}
        ),
        "total_cost": cost_data.get("total_cost", 0),
        "cost_breakdown": cost_data.get("cost_breakdown", {}),
    }
    _invoke_mlflow_logger(test_run_id, athena_result, config=run_config)
    return athena_result
def _get_test_run_config(test_run_id):
    """Get test run configuration from metadata record.

    Args:
        test_run_id: Test run identifier.

    Returns:
        The ``Config`` attribute of the test run's metadata item with every
        ``Decimal`` converted to ``int`` (whole values) or ``float``; an empty
        dict when the item or the attribute is missing.
    """
    table = dynamodb.Table(os.environ["TRACKING_TABLE"])  # type: ignore[attr-defined]
    response = table.get_item(Key={"PK": f"testrun#{test_run_id}", "SK": "metadata"})
    config = response.get("Item", {}).get("Config", {})

    # Convert DynamoDB Decimal objects to regular Python types for JSON serialization
    def convert_decimals(obj):
        if isinstance(obj, dict):
            return {k: convert_decimals(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [convert_decimals(v) for v in obj]
        # isinstance instead of comparing __class__.__name__ to "Decimal": the
        # string comparison misses Decimal subclasses and would match any
        # unrelated class that happens to share the name.
        if isinstance(obj, Decimal):
            # Whole numbers become int, everything else float
            return int(obj) if obj % 1 == 0 else float(obj)
        return obj

    return convert_decimals(config)
def _build_config_comparison(configs):
"""Build configuration differences - compare actual Config structure"""
if not configs or len(configs) < 2:
return None
def get_nested_value(dictionary, path):
"""Get nested value from dictionary using dot notation path"""
keys = path.split(".")
current = dictionary
for key in keys:
if isinstance(current, dict) and key in current:
current = current[key]
elif isinstance(current, list) and key.isdigit():
# Handle array index access
index = int(key)
if 0 <= index < len(current):
current = current[index]
else:
return None
else:
return None
return current
def get_all_paths(dictionary, prefix=""):
"""Get all nested paths from dictionary"""
paths = []
ignored_fields = {
"UpdatedAt",
"Description",
"CreatedAt",
"IsActive",
"Configuration",
"version_name",
"classes",
}
for key, value in dictionary.items():
# Skip ignored metadata fields
if key in ignored_fields:
continue
current_path = f"{prefix}.{key}" if prefix else key
if isinstance(value, dict):
paths.extend(get_all_paths(value, current_path))
elif isinstance(value, list):
# Handle arrays by creating indexed paths for each element
for i, item in enumerate(value):
item_path = f"{current_path}.{i}"
if isinstance(item, dict):
paths.extend(get_all_paths(item, item_path))
else:
paths.append(item_path)
else:
paths.append(current_path)