@@ -89,16 +89,24 @@ def validate_matrix_output(matrix_values: List[dict]) -> List[dict]:
8989def mark_eval_entries (matrix_values : List [dict ]) -> List [dict ]:
9090 """Mark entries that should run evaluation.
9191
92- For each unique (model, runner, isl, osl) combination:
92+ For each unique (model, runner, framework, precision, isl, osl) combination:
9393 - Mark highest TP with highest conc
9494 - Mark lowest TP with highest conc
9595 """
9696 from collections import defaultdict
9797
98- # Group entries by (model, runner, isl, osl)
98+ # Group entries by (model, runner, framework, precision, isl, osl)
99+ # This ensures we compare within the same configuration, not across different frameworks
99100 groups = defaultdict (list )
100101 for i , entry in enumerate (matrix_values ):
101- key = (entry [FIELD_MODEL ], entry [FIELD_RUNNER ], entry [FIELD_ISL ], entry [FIELD_OSL ])
102+ key = (
103+ entry [FIELD_MODEL ],
104+ entry [FIELD_RUNNER ],
105+ entry [FIELD_FRAMEWORK ],
106+ entry [FIELD_PRECISION ],
107+ entry [FIELD_ISL ],
108+ entry [FIELD_OSL ]
109+ )
102110 groups [key ].append ((i , entry ))
103111
104112 # For each group, find highest TP/highest conc and lowest TP/highest conc
0 commit comments