From 6c7186377cf36d600a7b8a0ef1539ed9bbd533a5 Mon Sep 17 00:00:00 2001
From: Walle
Date: Fri, 24 Apr 2026 08:08:22 +0800
Subject: [PATCH] feat: add auto export metrics to CSV after training with
 config switch

---
 nerfstudio/engine/trainer.py | 61 ++++++++++++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/nerfstudio/engine/trainer.py b/nerfstudio/engine/trainer.py
index a653e1de8d..a809acd8f6 100644
--- a/nerfstudio/engine/trainer.py
+++ b/nerfstudio/engine/trainer.py
@@ -342,9 +342,70 @@ def _after_train(self) -> None:
         for callback in self.callbacks:
             callback.run_callback_at_location(step=self.step, location=TrainingCallbackLocation.AFTER_TRAIN)
 
+        # export metrics to CSV
+        self._export_metrics_to_csv()
+
         if not self.config.viewer.quit_on_train_completion:
             self._train_complete_viewer()
 
+    def _export_metrics_to_csv(self) -> None:
+        """Export final evaluation metrics to ``base_dir/final_metrics.csv``.
+
+        Called once at the end of training. Best-effort: any failure is logged
+        and swallowed so it can never abort a finished training run.
+        """
+        import csv
+        import traceback
+        from pathlib import Path
+
+        # Feature switch promised by this change. Read defensively so configs
+        # that do not declare the flag keep exporting (current behavior).
+        # NOTE(review): consider declaring `export_final_metrics_to_csv: bool = True`
+        # on TrainerConfig so the switch is discoverable.
+        if not getattr(self.config, "export_final_metrics_to_csv", True):
+            return
+
+        # Without an eval dataset there are no metrics to export.
+        if not self.pipeline.datamanager.eval_dataset:
+            CONSOLE.log("[yellow]No evaluation dataset found, skipping metrics export.[/yellow]")
+            return
+
+        output_path = Path(self.base_dir) / "final_metrics.csv"
+        try:
+            # Average eval metrics over all eval images at the final step.
+            CONSOLE.log("[cyan]Calculating final evaluation metrics for export...[/cyan]")
+            metrics_dict = self.pipeline.get_average_eval_image_metrics(step=self.step)
+
+            # Flatten one level of nesting into "outer/inner" keys for CSV rows.
+            flattened_metrics = {}
+            for key, value in metrics_dict.items():
+                if isinstance(value, dict):
+                    for sub_key, sub_value in value.items():
+                        flattened_metrics[f"{key}/{sub_key}"] = sub_value
+                else:
+                    flattened_metrics[key] = value
+
+            with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
+                csv_writer = csv.writer(csvfile)
+                csv_writer.writerow(["Metric", "Value"])
+                for metric_name, metric_value in flattened_metrics.items():
+                    try:
+                        if isinstance(metric_value, (int, float)):
+                            csv_writer.writerow([metric_name, metric_value])
+                        elif hasattr(metric_value, "item"):
+                            # 0-dim torch.Tensor / numpy scalars expose .item().
+                            csv_writer.writerow([metric_name, metric_value.item()])
+                        else:
+                            csv_writer.writerow([metric_name, float(metric_value)])
+                    except (ValueError, TypeError):
+                        # Not numeric: fall back to its string representation.
+                        csv_writer.writerow([metric_name, str(metric_value)])
+
+            CONSOLE.log(f"[bold][green]Successfully exported metrics to: {output_path}[/bold][/green]")
+        except Exception as e:  # best-effort: never break training teardown
+            CONSOLE.log(f"[bold][red]Failed to export metrics to CSV: {e}[/bold][/red]")
+            traceback.print_exc()
+
     @check_main_thread
     def _check_viewer_warnings(self) -> None:
         """Helper to print out any warnings regarding the way the viewer/loggers are enabled"""