65 changes: 65 additions & 0 deletions nerfstudio/engine/trainer.py
@@ -342,9 +342,74 @@
        for callback in self.callbacks:
            callback.run_callback_at_location(step=self.step, location=TrainingCallbackLocation.AFTER_TRAIN)

        # Export final metrics to CSV
        self._export_metrics_to_csv()

        if not self.config.viewer.quit_on_train_completion:
            self._train_complete_viewer()

    def _export_metrics_to_csv(self) -> None:
        """Export final training metrics to a CSV file.

        Called at the end of training to write the final evaluation metrics to
        base_dir/final_metrics.csv.
        """
        import csv
        from pathlib import Path

        # Define output path
        output_path = Path(self.base_dir) / "final_metrics.csv"

        # Check if we have an eval dataset
        if not self.pipeline.datamanager.eval_dataset:
            CONSOLE.log("[yellow]No evaluation dataset found, skipping metrics export.[/yellow]")
            return

        try:
            # Get average evaluation metrics across all images
            CONSOLE.log("[cyan]Calculating final evaluation metrics for export...[/cyan]")

            metrics_dict = self.pipeline.get_average_eval_image_metrics(step=self.step)

            # Convert metrics to a format suitable for CSV
            # Handle nested dictionaries if any
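            # For example (hypothetical metric names), one level of nesting is flattened:
            #   {"psnr": 31.2, "lpips": {"vgg": 0.21}} -> {"psnr": 31.2, "lpips/vgg": 0.21}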
            flattened_metrics = {}
            for key, value in metrics_dict.items():
                if isinstance(value, dict):
                    for sub_key, sub_value in value.items():
                        flattened_metrics[f"{key}/{sub_key}"] = sub_value
                else:
                    flattened_metrics[key] = value

            # Write to CSV file
            with open(output_path, "w", newline="") as csvfile:
                csv_writer = csv.writer(csvfile)

                # Write header
                csv_writer.writerow(["Metric", "Value"])

                # Write metrics
                for metric_name, metric_value in flattened_metrics.items():
                    # Convert to float/int if possible, otherwise keep as string
                    try:
                        if isinstance(metric_value, (int, float)):
                            csv_writer.writerow([metric_name, metric_value])
                        elif hasattr(metric_value, "item"):
                            # Handle torch.Tensor (scalar-like objects)
                            csv_writer.writerow([metric_name, metric_value.item()])
                        else:
                            # Try to convert to float
                            csv_writer.writerow([metric_name, float(metric_value)])
                    except (ValueError, TypeError, RuntimeError):
                        # If conversion fails (e.g. .item() on a non-scalar tensor), write as string
                        csv_writer.writerow([metric_name, str(metric_value)])

CONSOLE.log(f"[bold][green]Successfully exported metrics to: {output_path}[/bold][/green]")

        except Exception as e:
            # Metrics export is best-effort; never let it abort the end of training.
            CONSOLE.log(f"[bold red]Failed to export metrics to CSV: {e}[/bold red]")
            import traceback
            traceback.print_exc()

    @check_main_thread
    def _check_viewer_warnings(self) -> None:
        """Helper to print out any warnings regarding the way the viewer/loggers are enabled"""
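For reference, a minimal sketch of reading the exported file back, assuming the two-column Metric/Value layout written above (the run directory path and metric names are hypothetical):

import csv
from pathlib import Path

# Hypothetical run directory; substitute the base_dir of your experiment.
run_dir = Path("outputs/experiment")

metrics = {}
with open(run_dir / "final_metrics.csv", newline="") as csvfile:
    reader = csv.reader(csvfile)
    next(reader)  # skip the "Metric,Value" header row
    for name, value in reader:
        try:
            metrics[name] = float(value)  # numeric metrics, e.g. psnr or lpips/vgg
        except ValueError:
            metrics[name] = value  # non-numeric values stay as strings

print(metrics)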