Skip to content

Commit ee0d898

Browse files
committed
Added stats to model scores script output
1 parent 9e7c0db commit ee0d898

1 file changed

Lines changed: 191 additions & 33 deletions

File tree

tests/model-metrics/test-all-models.py

Lines changed: 191 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -86,34 +86,55 @@ def get_track_results(model_name, track_name):
8686
return None
8787

8888

89+
def get_track_duration(track_path):
90+
"""Get the duration of a track in minutes"""
91+
try:
92+
mixture_path = os.path.join(track_path, "mixture.wav")
93+
info = sf.info(mixture_path)
94+
return info.duration / 60.0 # Convert seconds to minutes
95+
except Exception as e:
96+
logger.error(f"Error getting track duration: {str(e)}")
97+
return 0.0
98+
99+
89100
def evaluate_track(track_name, track_path, test_model, mus_db):
90101
"""Evaluate a single track using a specific model"""
91102
logger.info(f"Evaluating track: {track_name} with model: {test_model}")
92103

93-
# Set output directory for this separation
94-
output_dir = os.path.join(RESULTS_PATH, test_model, track_name)
95-
os.makedirs(output_dir, exist_ok=True)
104+
# Get track duration in minutes
105+
track_duration_minutes = get_track_duration(track_path)
106+
logger.info(f"Track duration: {track_duration_minutes:.2f} minutes")
96107

97-
# Check if evaluation results already exist
98-
results_file = os.path.join(output_dir, "museval-results.json")
99-
if os.path.exists(results_file):
100-
logger.info("Found existing evaluation results, loading from file...")
101-
with open(results_file) as f:
102-
json_data = json.load(f)
108+
# Check if evaluation results already exist in combined file
109+
museval_results = load_combined_results()
110+
if test_model in museval_results and track_name in museval_results[test_model]:
111+
logger.info("Found existing evaluation results in combined file...")
112+
track_data = museval_results[test_model][track_name]
103113
scores = museval.TrackStore(track_name)
104-
scores.scores = json_data
114+
scores.scores = track_data
105115
else:
106116
# Expanded stem mapping to include "no-stem" outputs
107117
stem_mapping = {"Vocals": "vocals", "Instrumental": "instrumental", "Drums": "drums", "Bass": "bass", "Other": "other", "No Drums": "nodrums", "No Bass": "nobass", "No Other": "noother"}
108118

109119
# Create a temporary directory for separation files
110120
with tempfile.TemporaryDirectory() as temp_dir:
111-
# Perform separation if needed
121+
logger.info(f"Using temporary directory: {temp_dir}")
122+
123+
# Measure separation time
124+
start_time = time.time()
125+
126+
# Perform separation
112127
logger.info("Performing separation...")
113128
separator = Separator(output_dir=temp_dir)
114129
separator.load_model(model_filename=test_model)
115130
separator.separate(os.path.join(track_path, "mixture.wav"), custom_output_names=stem_mapping)
116131

132+
# Calculate processing time
133+
processing_time = time.time() - start_time
134+
seconds_per_minute = processing_time / track_duration_minutes if track_duration_minutes > 0 else 0
135+
logger.info(f"Separation completed in {processing_time:.2f} seconds")
136+
logger.info(f"Processing speed: {seconds_per_minute:.2f} seconds per minute of audio")
137+
117138
# Check which stems were actually created and pair them appropriately
118139
available_stems = {}
119140
stem_pairs = {"drums": "nodrums", "bass": "nobass", "other": "noother", "vocals": "instrumental"}
@@ -150,25 +171,13 @@ def evaluate_track(track_name, track_path, test_model, mus_db):
150171

151172
# Evaluate using museval
152173
logger.info(f"Evaluating stems: {list(estimates.keys())}")
153-
# Use the temp directory for intermediate results
154174
scores = museval.eval_mus_track(track, estimates, output_dir=temp_dir, mode="v4")
155175

156-
# Move only the final results file to the permanent location
157-
os.makedirs(output_dir, exist_ok=True)
158-
test_results = os.path.join(temp_dir, "test", f"{track_name}.json")
159-
train_results = os.path.join(temp_dir, "train", f"{track_name}.json")
160-
161-
if os.path.exists(test_results):
162-
with open(test_results, "r") as f:
163-
results_data = json.load(f)
164-
with open(results_file, "w") as f:
165-
json.dump(results_data, f)
166-
elif os.path.exists(train_results):
167-
with open(train_results, "r") as f:
168-
results_data = json.load(f)
169-
with open(results_file, "w") as f:
170-
json.dump(results_data, f)
171-
# No need to remove directories as the temp directory will be automatically cleaned up
176+
# Update the combined results file with the new evaluation
177+
if test_model not in museval_results:
178+
museval_results[test_model] = {}
179+
museval_results[test_model][track_name] = scores.scores
180+
save_combined_results(museval_results)
172181

173182
# Calculate aggregate scores for available stems
174183
results_store = museval.EvalStore()
@@ -189,6 +198,11 @@ def evaluate_track(track_name, track_path, test_model, mus_db):
189198
except KeyError:
190199
continue
191200

201+
# Add the seconds_per_minute_m3 metric if it was calculated
202+
if "processing_time" in locals() and track_duration_minutes > 0:
203+
seconds_per_minute = processing_time / track_duration_minutes
204+
model_results["scores"]["seconds_per_minute_m3"] = round(seconds_per_minute, 1)
205+
192206
return scores, model_results if model_results["scores"] else None
193207

194208

@@ -211,22 +225,32 @@ def calculate_median_scores(track_scores):
211225
"drums": {"SDR": [], "SIR": [], "SAR": [], "ISR": []},
212226
"bass": {"SDR": [], "SIR": [], "SAR": [], "ISR": []},
213227
"instrumental": {"SDR": [], "SIR": [], "SAR": [], "ISR": []},
228+
"seconds_per_minute_m3": [],
214229
}
215230

216231
# Collect all scores for each stem and metric
217232
for track_score in track_scores:
218233
if track_score is not None and "scores" in track_score:
234+
# Process audio quality metrics
219235
for stem, metrics in track_score["scores"].items():
220-
if stem in stem_metrics:
236+
if stem in stem_metrics and stem != "seconds_per_minute_m3":
221237
for metric, value in metrics.items():
222238
stem_metrics[stem][metric].append(value)
223239

240+
# Process speed metric separately
241+
if "seconds_per_minute_m3" in track_score["scores"]:
242+
stem_metrics["seconds_per_minute_m3"].append(track_score["scores"]["seconds_per_minute_m3"])
243+
224244
# Calculate medians for each stem and metric
225245
median_scores = {}
226246
for stem, metrics in stem_metrics.items():
227-
if any(metrics.values()): # Only include stems that have scores
247+
if stem != "seconds_per_minute_m3" and any(metrics.values()): # Only include stems that have scores
228248
median_scores[stem] = {metric: float(f"{np.median(values):.6g}") for metric, values in metrics.items() if values} # Only include metrics that have values
229249

250+
# Add median speed metric if available
251+
if stem_metrics["seconds_per_minute_m3"]:
252+
median_scores["seconds_per_minute_m3"] = round(np.median(stem_metrics["seconds_per_minute_m3"]), 1)
253+
230254
return median_scores
231255

232256

@@ -328,6 +352,58 @@ def get_most_evaluated_tracks(museval_results, min_count=10):
328352
return [track for track, count in sorted_tracks if count >= min_count]
329353

330354

355+
def generate_summary_statistics(
356+
start_time, models_processed, tracks_processed, models_with_new_data, tracks_evaluated, total_processing_time, fastest_model=None, slowest_model=None, combined_results_path=None, is_dry_run=False
357+
):
358+
"""Generate a summary of the script's execution"""
359+
end_time = time.time()
360+
total_runtime = end_time - start_time
361+
362+
# Format the runtime
363+
hours, remainder = divmod(total_runtime, 3600)
364+
minutes, seconds = divmod(remainder, 60)
365+
runtime_str = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
366+
367+
# Build the summary
368+
summary = [
369+
"=" * 80,
370+
"DRY RUN SUMMARY - PREVIEW ONLY" if is_dry_run else "EXECUTION SUMMARY",
371+
"=" * 80,
372+
f"Total runtime: {runtime_str}",
373+
f"Models {'that would be' if is_dry_run else ''} processed: {models_processed}",
374+
f"Models {'that would receive' if is_dry_run else 'with'} new data: {len(models_with_new_data)}",
375+
f"Total tracks {'that would be' if is_dry_run else ''} evaluated: {tracks_evaluated}",
376+
f"Average tracks per model: {tracks_evaluated / len(models_with_new_data) if models_with_new_data else 0:.2f}",
377+
]
378+
379+
if fastest_model:
380+
summary.append(f"Fastest model: {fastest_model['name']} ({fastest_model['speed']:.2f} seconds per minute)")
381+
382+
if slowest_model:
383+
summary.append(f"Slowest model: {slowest_model['name']} ({slowest_model['speed']:.2f} seconds per minute)")
384+
385+
if total_processing_time > 0:
386+
summary.append(f"Total audio processing time: {total_processing_time:.2f} seconds")
387+
388+
if combined_results_path and os.path.exists(combined_results_path):
389+
file_size = os.path.getsize(combined_results_path) / (1024 * 1024) # Size in MB
390+
summary.append(f"Results file size: {file_size:.2f} MB")
391+
392+
# Add models with new data
393+
if models_with_new_data:
394+
summary.append(f"\nModels {'that would receive' if is_dry_run else 'with'} new evaluation data:")
395+
for model_name in models_with_new_data:
396+
summary.append(f"- {model_name}")
397+
398+
# Add dry run disclaimer if needed
399+
if is_dry_run:
400+
summary.append("\nNOTE: This is a dry run summary. No actual changes were made.")
401+
summary.append("Run without --dry-run to perform actual evaluations.")
402+
403+
summary.append("=" * 80)
404+
return "\n".join(summary)
405+
406+
331407
def main():
332408
# Add command line argument parsing for dry run mode
333409
parser = argparse.ArgumentParser(description="Run model evaluation on MUSDB18 dataset")
@@ -339,6 +415,14 @@ def main():
339415
# Track start time for progress reporting
340416
start_time = time.time()
341417

418+
# Statistics tracking
419+
models_processed = 0
420+
tracks_processed = 0
421+
models_with_new_data = set()
422+
total_processing_time = 0
423+
fastest_model = {"name": "", "speed": float("inf")} # Initialize with infinity for comparison
424+
slowest_model = {"name": "", "speed": 0} # Initialize with zero for comparison
425+
342426
# Create a results cache manager
343427
class ResultsCache:
344428
def __init__(self):
@@ -511,15 +595,56 @@ def log_with_time(message, level=logging.INFO):
511595
if args.dry_run:
512596
log_with_time(f"[DRY RUN] Would evaluate track {track_name} with model {model_filename}")
513597
tracks_processed += 1
598+
models_with_new_data.add(model_filename)
599+
600+
# Estimate processing time based on model type for dry run
601+
# This is a rough estimate - roformer models are typically slower
602+
estimated_speed = 30.0 # Default estimate: 30 seconds per minute
603+
if "roformer" in model_name.lower():
604+
estimated_speed = 45.0 # Roformer models are typically slower
605+
elif "umx" in model_name.lower():
606+
estimated_speed = 20.0 # UMX models are typically faster
607+
608+
# Update statistics with estimated values
609+
total_processing_time += estimated_speed
610+
611+
# Track fastest and slowest models based on estimates
612+
if estimated_speed < fastest_model["speed"]:
613+
fastest_model = {"name": model_name, "speed": estimated_speed}
614+
if estimated_speed > slowest_model["speed"]:
615+
slowest_model = {"name": model_name, "speed": estimated_speed}
616+
514617
continue
515618

516619
try:
517-
_, model_results = evaluate_track(track_name, track_path, model_filename, mus)
518-
if model_results:
620+
result = evaluate_track(track_name, track_path, model_filename, mus)
621+
622+
# Unpack the result safely
623+
if result and isinstance(result, tuple) and len(result) == 2:
624+
_, model_results = result
625+
else:
626+
model_results = None
627+
628+
# Process the results if they exist and are valid
629+
if model_results is not None and isinstance(model_results, dict):
519630
combined_results[model_filename]["track_scores"].append(model_results)
520631
tracks_processed += 1
632+
models_with_new_data.add(model_filename)
633+
634+
# Track processing time statistics - safely access nested dictionaries
635+
scores = model_results.get("scores", {})
636+
if isinstance(scores, dict):
637+
speed = scores.get("seconds_per_minute_m3")
638+
if speed is not None:
639+
total_processing_time += speed # Accumulate total processing time
640+
641+
# Track fastest and slowest models
642+
if speed < fastest_model["speed"]:
643+
fastest_model = {"name": model_name, "speed": speed}
644+
if speed > slowest_model["speed"]:
645+
slowest_model = {"name": model_name, "speed": speed}
521646
else:
522-
log_with_time(f"Skipping model {model_filename} for track {track_name} due to no evaluatable stems")
647+
log_with_time(f"Skipping model {model_filename} for track {track_name} due to no evaluatable stems or invalid results")
523648
except Exception as e:
524649
log_with_time(f"Error evaluating model {model_filename} with track {track_name}: {str(e)}", logging.ERROR)
525650
logger.exception(f"Exception details: ", exc_info=e)
@@ -572,10 +697,43 @@ def log_with_time(message, level=logging.INFO):
572697

573698
# Move to the next model
574699
model_idx += 1
700+
models_processed += 1
575701

576702
log_with_time("Evaluation complete")
577703
# Final disk space check
578704
check_disk_usage(RESULTS_PATH)
705+
706+
# Generate and display summary statistics
707+
# Reset fastest/slowest models if they weren't updated
708+
if fastest_model["speed"] == float("inf"):
709+
fastest_model = None
710+
if slowest_model["speed"] == 0:
711+
slowest_model = None
712+
713+
summary = generate_summary_statistics(
714+
start_time=start_time,
715+
models_processed=models_processed,
716+
tracks_processed=tracks_processed,
717+
models_with_new_data=models_with_new_data,
718+
tracks_evaluated=tracks_processed,
719+
total_processing_time=total_processing_time,
720+
fastest_model=fastest_model,
721+
slowest_model=slowest_model,
722+
combined_results_path=COMBINED_RESULTS_PATH,
723+
is_dry_run=args.dry_run,
724+
)
725+
726+
log_with_time("\n" + summary)
727+
728+
# Also write summary to a log file
729+
summary_filename = "dry_run_summary.log" if args.dry_run else "evaluation_summary.log"
730+
summary_log_path = os.path.join(os.path.dirname(COMBINED_RESULTS_PATH), summary_filename)
731+
with open(summary_log_path, "w") as f:
732+
f.write(f"{'Dry run' if args.dry_run else 'Evaluation'} completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
733+
f.write(summary)
734+
735+
log_with_time(f"Summary written to {summary_log_path}")
736+
579737
return 0
580738

581739

0 commit comments

Comments
 (0)