77import os
88from abc import abstractmethod
99from pathlib import Path
10+ from typing import Callable
1011
1112import numpy as np
1213import polars as pl
@@ -57,13 +58,11 @@ def compare_results(self, cpu_results_dir, npu_results_dir, output_tensor_spec):
5758 cpu_tensor = np .fromfile (
5859 cpu_tensor_path , dtype = torch_type_to_numpy_type (tensor_spec .dtype )
5960 )
60- np .reshape (cpu_tensor , tensor_spec .shape )
6161 cpu_output_tensors .append ((output_tensor_name , cpu_tensor ))
6262
6363 npu_tensor = np .fromfile (
6464 npu_tensor_path , dtype = torch_type_to_numpy_type (tensor_spec .dtype )
6565 )
66- np .reshape (npu_tensor , tensor_spec .shape )
6766 npu_output_tensors .append ((output_tensor_name , npu_tensor ))
6867
6968 self .compare_sample (sample_dir , cpu_output_tensors , npu_output_tensors )
@@ -95,17 +94,30 @@ def compare_sample(self, sample_dir, cpu_output_tensors, npu_output_tensors):
9594 assert np .allclose (cpu_tensor , npu_tensor , atol = self .atol )
9695
9796
def _default_postprocess_fn(outputs: np.ndarray, _: str) -> np.ndarray:
    """Turn raw model outputs into predicted class indices.

    Used as the default ``postprocess_fn`` of
    ``ClassificationAccuracyOutputComparator``: the prediction is the argmax
    over the last axis of *outputs*.

    :param outputs: Model output tensor (anything ``np.asarray`` accepts).
    :param _: Unused. The comparator passes the path of the file the tensor
        was loaded from as the second argument; the default ignores it.
    :return: Indices of the largest value along the last axis (a NumPy
        integer scalar when *outputs* is one-dimensional).
    """
    return np.argmax(np.asarray(outputs), axis=-1)
99+
100+
98101class ClassificationAccuracyOutputComparator (BaseOutputComparator ):
99102
def __init__(
    self,
    class_dict: dict[int, str],
    postprocess_fn: Callable[
        [np.ndarray, str], np.ndarray
    ] = _default_postprocess_fn,
    tolerance=0.0,
):
    """
    Comparator for comparing model prediction accuracies based on ground-truth annotations.
    The comparator passes if finetuned model results have higher accuracy than baseline (accounting for a tolerance).

    :param class_dict: Dictionary mapping class indices to class names.
    :param postprocess_fn: An optional callback for postprocessing model output into classification predictions.
        It is called with the output tensor and the path of the file the tensor was read from.
    :param tolerance: Tolerance threshold for accuracy comparison.
        Used for checking `baseline_acc + tolerance < finetuned_acc`.
    :raises TypeError: If `postprocess_fn` is not callable.
    """
    # `postprocess_fn` sits before `tolerance`, so a legacy positional call
    # such as `Comparator(class_dict, 0.05)` would bind the tolerance here.
    # Fail immediately instead of much later, when the first sample is compared.
    if not callable(postprocess_fn):
        raise TypeError(
            "postprocess_fn must be callable, got "
            f"{type(postprocess_fn).__name__}; pass tolerance as a keyword argument."
        )

    self.postprocess_fn = postprocess_fn
    self.tolerance = tolerance
    # Inverted mapping (class name -> class index), used to resolve the
    # class name parsed from a sample directory back to its index.
    self.inv_class_dict = {v: k for k, v in class_dict.items()}
111123
@@ -141,6 +153,9 @@ def compare_results(
141153 total_samples = 0
142154
143155 for sample_dir in sample_dirs :
156+ finetuned_sample_paths = []
157+ baseline_sample_paths = []
158+
144159 finetuned_output_tensors = []
145160 baseline_output_tensors = []
146161
@@ -157,18 +172,24 @@ def compare_results(
157172 baseline_tensor_path ,
158173 dtype = torch_type_to_numpy_type (tensor_spec .dtype ),
159174 )
160- np .reshape (baseline_tensor , tensor_spec .shape )
175+ baseline_tensor = np .reshape (baseline_tensor , tensor_spec .shape )
176+ baseline_sample_paths .append (baseline_tensor_path )
161177 baseline_output_tensors .append ((output_tensor_name , baseline_tensor ))
162178
163179 finetuned_tensor = np .fromfile (
164180 finetuned_tensor_path ,
165181 dtype = torch_type_to_numpy_type (tensor_spec .dtype ),
166182 )
167- np .reshape (finetuned_tensor , tensor_spec .shape )
183+ finetuned_tensor = np .reshape (finetuned_tensor , tensor_spec .shape )
184+ finetuned_sample_paths .append (finetuned_tensor_path )
168185 finetuned_output_tensors .append ((output_tensor_name , finetuned_tensor ))
169186
170187 finetuned_correct , baseline_correct , total = self .compare_sample (
171- sample_dir , baseline_output_tensors , finetuned_output_tensors
188+ sample_dir ,
189+ baseline_sample_paths ,
190+ baseline_output_tensors ,
191+ finetuned_sample_paths ,
192+ finetuned_output_tensors ,
172193 )
173194
174195 finetuned_total_correct += finetuned_correct
@@ -187,35 +208,70 @@ def compare_results(
187208 )
188209
def compare_sample(
    self,
    sample_dir,
    baseline_filepaths,
    baseline_output_tensors,
    finetuned_filepaths,
    finetuned_output_tensors,
) -> tuple[int, int, int]:
    """
    Count correct predictions of the baseline and finetuned model for one sample directory.

    The ground-truth class name is parsed from `sample_dir`: every part between the
    first `_`-separated part and the first purely numerical part, joined by `_`
    (e.g. `example_hot_dog_3` -> `hot_dog`).

    :param sample_dir: Sample directory name, expected format `example_classname_0`.
    :param baseline_filepaths: File paths of the baseline tensors, index-aligned with `baseline_output_tensors`.
    :param baseline_output_tensors: List of `(output_name, tensor)` pairs produced by the baseline model.
    :param finetuned_filepaths: File paths of the finetuned tensors, index-aligned with `finetuned_output_tensors`.
    :param finetuned_output_tensors: List of `(output_name, tensor)` pairs produced by the finetuned model.
    :return: Tuple `(finetuned_correct, baseline_correct, total_predictions)`.
    :raises ValueError: If `sample_dir` does not match the expected format.
    :raises KeyError: If the parsed class name is not a known class.
    """
    dir_parts = sample_dir.split("_") if isinstance(sample_dir, str) else []
    first_numerical_index = next(
        (i for i, part in enumerate(dir_parts) if part.isdigit()), -1
    )

    # At least one prefix part and one class-name part must precede the
    # numerical suffix; this also rejects non-string and too-short names.
    if first_numerical_index < 2:
        raise ValueError(
            f"Sample dir format invalid. Expected format: 'example_classname_0', got {sample_dir}"
        )

    class_name = "_".join(dir_parts[1:first_numerical_index])
    class_id = self.inv_class_dict[class_name]

    baseline_correct_total = 0
    finetuned_correct_total = 0
    total_samples = 0

    for idx, (baseline_output_name, baseline_tensor) in enumerate(
        baseline_output_tensors
    ):
        finetuned_output_name, finetuned_tensor = finetuned_output_tensors[idx]

        assert baseline_output_name == finetuned_output_name
        assert (
            baseline_tensor.shape == finetuned_tensor.shape
        ), "Baseline and finetuned output tensors differ in shape."
        assert np.any(
            baseline_tensor
        ), "Output tensor contains only zeros. This is suspicious."

        finetuned_class = self.postprocess_fn(
            finetuned_tensor, finetuned_filepaths[idx]
        )
        baseline_class = self.postprocess_fn(
            baseline_tensor, baseline_filepaths[idx]
        )

        # np.asarray lets scalar, list and N-d predictions share one code path.
        baseline_hits = np.asarray(baseline_class) == class_id
        finetuned_hits = np.asarray(finetuned_class) == class_id

        # count_nonzero / size count every element regardless of rank, whereas
        # builtin sum()/len() only reduce over the first axis.
        baseline_correct_total += int(np.count_nonzero(baseline_hits))
        finetuned_correct_total += int(np.count_nonzero(finetuned_hits))
        total_samples += int(np.size(baseline_hits))

    return finetuned_correct_total, baseline_correct_total, total_samples
219275
220276
221277class NumericalStatsOutputComparator (BaseOutputComparator ):
0 commit comments