# --- nerfstudio/configs/method_configs.py: register the "nerfacto-psf" method ---
# NOTE(review): reconstructed from a whitespace-mangled diff; in the real file the
# description below is added as an entry inside the module-level `descriptions`
# dict literal, next to the "nerfacto" entry.
descriptions["nerfacto-psf"] = "Nerfacto with PSF-guided hashgrid selection (no b/grid sweep needed)."

# Nerfacto variant that derives the hashgrid growth factor / max resolution from a
# PSF target at startup instead of requiring a manual sweep over `b` or grid sizes.
method_configs["nerfacto-psf"] = TrainerConfig(
    method_name="nerfacto-psf",
    steps_per_eval_batch=500,
    steps_per_save=2000,
    max_num_iterations=30000,
    mixed_precision=True,
    pipeline=VanillaPipelineConfig(
        datamanager=ParallelDataManagerConfig(
            dataparser=NerfstudioDataParserConfig(),
            train_num_rays_per_batch=4096,
            eval_num_rays_per_batch=4096,
        ),
        model=NerfactoModelConfig(
            eval_num_rays_per_chunk=1 << 15,
            average_init_density=0.01,
            # Turn on PSF-guided selection; the PSF defaults mirror the model config.
            psf_guided_hashgrid=True,
            psf_target_fwhm_pixels=0.5,
            psf_broadening=3.0,
            camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"),
        ),
    ),
    optimizers={
        "proposal_networks": {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000),
        },
        "fields": {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000),
        },
        "camera_opt": {
            "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=5000),
        },
    },
    viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
    vis="viewer",
)
# --- nerfstudio/models/instant_ngp.py: PSF-guided hashgrid support ---
# NOTE(review): reconstructed from a whitespace-mangled diff. This span also held
# the opening of the "nerfacto-big" TrainerConfig from method_configs.py and the
# new imports for this file:
#   from nerfstudio.utils.hashgrid_psf import compute_psf_guided_hashgrid_params
#   from nerfstudio.utils.rich_utils import CONSOLE


@dataclass
class InstantNGPModelConfig(ModelConfig):
    """Instant-NGP model config — only the fields visible in this hunk are shown;
    fields above `grid_levels` and below `cone_angle` are outside this span."""

    grid_levels: int = 4
    """Levels of the grid used for the field."""
    num_levels: int = 16
    """Number of levels of the hashgrid encoding."""
    base_res: int = 16
    """Resolution of the base level of the hashgrid encoding."""
    max_res: int = 2048
    """Maximum resolution of the hashmap for the base mlp."""
    log2_hashmap_size: int = 19
    """Size of the hashmap for the base mlp"""
    psf_guided_hashgrid: bool = False
    """If True, solve hash growth factor from PSF target FWHM and dataset scale instead of sweeping."""
    psf_target_fwhm_pixels: float = 0.5
    """Target empirical PSF FWHM measured in image pixels."""
    psf_broadening: float = 3.0
    """Empirical broadening factor used to map idealized to observed PSF width."""
    alpha_thre: float = 0.01
    """Threshold for opacity skipping."""
    cone_angle: float = 0.004


# Method of the Instant-NGP model class (class header not visible in this span).
def populate_modules(self):
    """Set the fields and modules."""
    super().populate_modules()

    active_max_res = int(self.config.max_res)
    if self.config.psf_guided_hashgrid:
        train_cameras = self.kwargs.get("train_cameras")
        # Duck-type check: anything lacking camera_to_worlds is not a Cameras object.
        if train_cameras is not None and not hasattr(train_cameras, "camera_to_worlds"):
            train_cameras = None
        psf_params = compute_psf_guided_hashgrid_params(
            train_cameras=train_cameras,
            scene_aabb=self.scene_box.aabb,
            base_resolution=self.config.base_res,
            num_levels=self.config.num_levels,
            target_fwhm_pixels=self.config.psf_target_fwhm_pixels,
            dim=3,
            broadening=self.config.psf_broadening,
        )
        if psf_params is None:
            CONSOLE.print(
                "[bold yellow]PSF-guided hashgrid could not compute parameters; "
                f"falling back to max_res={active_max_res}[/bold yellow]"
            )
        else:
            active_max_res = int(psf_params["max_res"])
            # Record the solved resolution back on the config so checkpoints/logs see it.
            self.config.max_res = active_max_res
            growth_factor = float(psf_params["growth_factor"])
            source = "scene-scale fallback" if psf_params["used_scene_scale_fallback"] > 0.5 else "camera scale"
            CONSOLE.print(
                "[green]PSF-guided hashgrid[/green] "
                f"b={growth_factor:.4f} max_res={active_max_res} "
                f"target_fwhm_world={psf_params['target_fwhm_world']:.4e} ({source})"
            )

    if self.config.disable_scene_contraction:
        scene_contraction = None
    else:
        # (contraction construction not visible in this span)
        ...

    # Field construction — the call opening is not visible in this span;
    # NOTE(review): assumed `self.field = NGPField(` from surrounding context — confirm.
    self.field = NGPField(
        aabb=self.scene_box.aabb,
        appearance_embedding_dim=0 if self.config.use_appearance_embedding else 32,
        num_images=self.num_train_data,
        num_levels=self.config.num_levels,
        base_res=self.config.base_res,
        log2_hashmap_size=self.config.log2_hashmap_size,
        max_res=active_max_res,  # PSF-derived resolution replaces raw config.max_res
        spatial_distortion=scene_contraction,
    )
    # (method continues past this span)
# --- nerfstudio/models/nerfacto.py: PSF-guided hashgrid support ---
# NOTE(review): reconstructed from a whitespace-mangled diff. New module imports:
#   from nerfstudio.utils.hashgrid_psf import compute_psf_guided_hashgrid_params, growth_factor_to_max_res
#   from nerfstudio.utils.rich_utils import CONSOLE

# New NerfactoModelConfig fields, inserted after `features_per_level` (the class
# header and surrounding fields are outside this span):
psf_guided_hashgrid: bool = False
"""If True, solve hash growth factor from PSF target FWHM and dataset scale instead of sweeping."""
psf_target_fwhm_pixels: float = 0.5
"""Target empirical PSF FWHM measured in image pixels."""
psf_broadening: float = 3.0
"""Empirical broadening factor used to map idealized to observed PSF width."""
psf_adjust_proposal_nets: bool = True
"""If True, apply the solved growth factor to proposal hash pyramids as well."""
num_proposal_samples_per_ray: Tuple[int, ...] = (256, 96)
"""Number of samples per ray for each proposal network."""
num_nerf_samples_per_ray: int = 48


# Method of NerfactoModel (class header not visible in this span).
def populate_modules(self):
    """Set the fields and modules."""
    super().populate_modules()

    active_max_res = int(self.config.max_res)
    # Work on copies so PSF adjustment never mutates the shared config dicts in place.
    proposal_net_args_list = [dict(args) for args in self.config.proposal_net_args_list]
    if self.config.psf_guided_hashgrid:
        train_cameras = self.kwargs.get("train_cameras")
        # Duck-type check: anything lacking camera_to_worlds is not a Cameras object.
        if train_cameras is not None and not hasattr(train_cameras, "camera_to_worlds"):
            train_cameras = None
        psf_params = compute_psf_guided_hashgrid_params(
            train_cameras=train_cameras,
            scene_aabb=self.scene_box.aabb,
            base_resolution=self.config.base_res,
            num_levels=self.config.num_levels,
            target_fwhm_pixels=self.config.psf_target_fwhm_pixels,
            dim=3,
            broadening=self.config.psf_broadening,
        )
        if psf_params is None:
            CONSOLE.print(
                "[bold yellow]PSF-guided hashgrid could not compute parameters; "
                f"falling back to max_res={active_max_res}[/bold yellow]"
            )
        else:
            growth_factor = float(psf_params["growth_factor"])
            active_max_res = int(psf_params["max_res"])
            # Record the solved resolution back on the config so checkpoints/logs see it.
            self.config.max_res = active_max_res
            if self.config.psf_adjust_proposal_nets:
                # Re-derive each proposal pyramid's max_res from the solved growth factor.
                for proposal_args in proposal_net_args_list:
                    proposal_levels = int(proposal_args.get("num_levels", self.config.num_levels))
                    proposal_base_res = int(proposal_args.get("base_res", self.config.base_res))
                    proposal_args["max_res"] = growth_factor_to_max_res(
                        base_resolution=proposal_base_res,
                        num_levels=proposal_levels,
                        growth_factor=growth_factor,
                    )
            source = "scene-scale fallback" if psf_params["used_scene_scale_fallback"] > 0.5 else "camera scale"
            CONSOLE.print(
                "[green]PSF-guided hashgrid[/green] "
                f"b={growth_factor:.4f} max_res={active_max_res} "
                f"target_fwhm_world={psf_params['target_fwhm_world']:.4e} ({source})"
            )

    if self.config.disable_scene_contraction:
        scene_contraction = None
    else:
        # (contraction construction not visible in this span)
        ...

    # Field construction — the call opening is not visible in this span;
    # NOTE(review): assumed `self.field = NerfactoField(` from surrounding context — confirm.
    self.field = NerfactoField(
        self.scene_box.aabb,
        hidden_dim=self.config.hidden_dim,
        num_levels=self.config.num_levels,
        max_res=active_max_res,  # PSF-derived resolution replaces raw config.max_res
        base_res=self.config.base_res,
        features_per_level=self.config.features_per_level,
        log2_hashmap_size=self.config.log2_hashmap_size,
        # (remaining kwargs not visible in this span)
    )

    # Build the proposal network(s) from the locally adjusted args list rather than
    # the (unmodified) config list.
    self.proposal_networks = torch.nn.ModuleList()
    if self.config.use_same_proposal_network:
        assert len(proposal_net_args_list) == 1, "Only one proposal network is allowed."
        # (continues past this span: prop_net_args = proposal_net_args_list[0], ...)
# NOTE(review): this span of the mangled paste also contained diff fragments from
# nerfstudio/models/nerfacto.py (proposal networks built from the locally copied
# `proposal_net_args_list`), nerfstudio/pipelines/base_pipeline.py (passing
# `train_cameras=self.datamanager.train_dataset.cameras` into the model kwargs),
# and the diff header of tests/utils/test_hashgrid_psf.py. Only the new module
# nerfstudio/utils/hashgrid_psf.py is reconstructed below.
"""PSF-guided helpers for hashgrid parameter selection."""

from typing import Any, Dict, Optional

import numpy as np
import torch

# nerfstudio CameraType.PERSPECTIVE — presumably 1; TODO confirm against the CameraType enum.
_PERSPECTIVE_CAMERA_TYPE = 1


def _get_resolutions(base_resolution: int, growth_factor: float, num_levels: int) -> np.ndarray:
    """Return per-level hash resolutions N_l = N_min * b^l for l = 0..num_levels-1."""
    levels = np.arange(num_levels, dtype=np.float64)
    return float(base_resolution) * (float(growth_factor) ** levels)


def calculate_axis_fwhm(
    base_resolution: int,
    growth_factor: float,
    num_levels: int,
    *,
    dim: int = 3,
) -> float:
    """Estimate axis FWHM using the revised weighted-kernel approximation.

    Computes sum(N^dim) / sum(N^(dim+1)) over the level resolutions — a weighted
    average of the per-level cell sizes 1/N with weights N^dim, so finer levels
    dominate. Monotonically decreasing in `growth_factor`, which the bisection in
    `solve_growth_factor_for_target_fwhm` relies on. Tiny epsilons guard the
    degenerate all-zero case.
    """
    resolutions = _get_resolutions(
        base_resolution=base_resolution, growth_factor=growth_factor, num_levels=num_levels
    )
    sum_dim = float(np.sum(resolutions**dim)) + 1e-20
    sum_dim_p1 = float(np.sum(resolutions ** (dim + 1))) + 1e-20
    return sum_dim / sum_dim_p1


def solve_growth_factor_for_target_fwhm(
    base_resolution: int,
    num_levels: int,
    target_fwhm_world: float,
    *,
    dim: int = 3,
    broadening: float = 1.0,
    tol: float = 1e-6,
    max_iter: int = 64,
) -> float:
    """Binary-search the hashgrid growth factor `b` for a target empirical FWHM.

    Args:
        base_resolution: Base (coarsest) level resolution N_min.
        num_levels: Number of hashgrid levels.
        target_fwhm_world: Desired empirical FWHM in world units.
        dim: Spatial dimension of the grid.
        broadening: Empirical-to-ideal PSF width ratio; the ideal target is
            target_fwhm_world / broadening.
        tol: Absolute tolerance on the solved FWHM.
        max_iter: Maximum bisection iterations.

    Returns:
        The growth factor b > 1 whose idealized FWHM matches the target.

    Raises:
        ValueError: On non-positive inputs or when the target cannot be bracketed.
    """
    if target_fwhm_world <= 0.0:
        raise ValueError("target_fwhm_world must be > 0")
    if base_resolution <= 0:
        raise ValueError("base_resolution must be > 0")
    if num_levels < 1:
        raise ValueError("num_levels must be >= 1")
    if dim < 1:
        raise ValueError("dim must be >= 1")
    if broadening <= 0.0:
        raise ValueError("broadening must be > 0")

    # Empirical PSF is wider than idealized theory by this broadening factor.
    target_fwhm_world = float(target_fwhm_world) / float(broadening)

    def fwhm_from_b(growth_factor: float) -> float:
        return calculate_axis_fwhm(
            base_resolution=base_resolution,
            growth_factor=growth_factor,
            num_levels=num_levels,
            dim=dim,
        )

    # FWHM is largest at b -> 1 (all levels at N_min); if even that is at or below
    # the target, return the smallest admissible b.
    lo = 1.0 + 1e-6
    hi = lo * 2.0
    f_lo = fwhm_from_b(lo)
    if target_fwhm_world >= f_lo:
        return lo

    # Grow the upper bracket until its FWHM drops below the target.
    f_hi = fwhm_from_b(hi)
    growth_guard = 0
    while f_hi > target_fwhm_world and growth_guard < 64:
        hi *= 1.5
        f_hi = fwhm_from_b(hi)
        growth_guard += 1
        if hi > 1e6:
            raise ValueError("Unable to bracket target FWHM; try different inputs.")

    # Bisection on the monotone-decreasing fwhm_from_b.
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        f_mid = fwhm_from_b(mid)
        if abs(f_mid - target_fwhm_world) < tol:
            return mid
        if f_mid > target_fwhm_world:
            lo = mid
        else:
            hi = mid

    return 0.5 * (lo + hi)


def growth_factor_to_max_res(base_resolution: int, num_levels: int, growth_factor: float) -> int:
    """Convert growth factor `b` to the integer max resolution used by Nerfstudio configs.

    max_res = N_min * b^(L-1), rounded, and never below the base resolution.
    """
    if num_levels <= 1:
        return int(base_resolution)
    max_res = float(base_resolution) * (float(growth_factor) ** float(num_levels - 1))
    return int(max(base_resolution, round(max_res)))


def estimate_pixel_world_size(cameras: Optional[Any]) -> Optional[float]:
    """Estimate world-units-per-pixel from perspective cameras.

    Returns None when no usable perspective camera is found, so callers can fall
    back to a scene-scale heuristic.
    """
    if cameras is None:
        return None

    # BUGFIX: validate the duck-typed camera object *before* touching any attribute.
    # The original read `cameras.camera_to_worlds.numel()` ahead of this hasattr
    # check, so a malformed object raised AttributeError instead of returning None.
    required_attrs = ("camera_to_worlds", "camera_type", "fx", "fy")
    if not all(hasattr(cameras, attr) for attr in required_attrs):
        return None
    if cameras.camera_to_worlds.numel() == 0:
        return None

    camera_type = cameras.camera_type.reshape(-1)
    perspective_mask = camera_type == _PERSPECTIVE_CAMERA_TYPE
    if not torch.any(perspective_mask):
        return None

    origins = cameras.camera_to_worlds[..., :3, 3].reshape(-1, 3).to(dtype=torch.float64)
    fx = cameras.fx.reshape(-1).to(dtype=torch.float64)
    fy = cameras.fy.reshape(-1).to(dtype=torch.float64)
    base_valid = perspective_mask & torch.isfinite(fx) & torch.isfinite(fy) & (fx > 0.0) & (fy > 0.0)
    if not torch.any(base_valid):
        return None

    origins = origins[base_valid]
    fx = fx[base_valid]
    fy = fy[base_valid]
    # Distance of each camera from the world origin; assumes the scene is roughly
    # centered at the origin — TODO confirm against the dataparser convention.
    distances = torch.linalg.norm(origins, dim=-1)
    valid = torch.isfinite(distances) & (distances > 0.0)
    if not torch.any(valid):
        return None

    # Pinhole approximation: one pixel subtends z/f world units at depth z.
    distances = distances[valid]
    px_x = distances / fx[valid]
    px_y = distances / fy[valid]
    pixel_world = 0.5 * (px_x + px_y)
    if pixel_world.numel() == 0:
        return None
    # Median is robust to a few outlier cameras.
    return float(torch.median(pixel_world).item())


def estimate_scene_scale_from_aabb(aabb: torch.Tensor) -> float:
    """Fallback scene scale based on the longest aabb side, in world units."""
    if aabb.numel() != 6:
        return 1.0
    extents = (aabb[1] - aabb[0]).to(dtype=torch.float64)
    max_extent = float(torch.max(torch.abs(extents)).item())
    # Clamp away from zero so downstream division/solves stay finite.
    return max(max_extent, 1e-6)


def compute_psf_guided_hashgrid_params(
    *,
    train_cameras: Optional[Any],
    scene_aabb: torch.Tensor,
    base_resolution: int,
    num_levels: int,
    target_fwhm_pixels: float,
    dim: int = 3,
    broadening: float = 3.0,
) -> Optional[Dict[str, float]]:
    """Compute PSF-guided hashgrid growth/max-res from dataset scale and image resolution.

    Returns a dict with keys growth_factor, max_res, target_fwhm_world,
    pixel_world, used_scene_scale_fallback (1.0 when the camera-based pixel size
    was unavailable and the aabb extent was used), or None when no target can be
    formed.
    """
    if target_fwhm_pixels <= 0.0:
        return None

    pixel_world = estimate_pixel_world_size(train_cameras) if train_cameras is not None else None
    used_scene_scale_fallback = pixel_world is None
    if pixel_world is None:
        pixel_world = estimate_scene_scale_from_aabb(scene_aabb)

    target_fwhm_world = float(target_fwhm_pixels) * float(pixel_world)
    if target_fwhm_world <= 0.0:
        return None

    growth_factor = solve_growth_factor_for_target_fwhm(
        base_resolution=base_resolution,
        num_levels=num_levels,
        target_fwhm_world=target_fwhm_world,
        dim=dim,
        broadening=broadening,
    )
    max_res = growth_factor_to_max_res(
        base_resolution=base_resolution,
        num_levels=num_levels,
        growth_factor=growth_factor,
    )
    # All-float dict so callers can log/serialize it uniformly.
    return {
        "growth_factor": float(growth_factor),
        "max_res": float(max_res),
        "target_fwhm_world": float(target_fwhm_world),
        "pixel_world": float(pixel_world),
        "used_scene_scale_fallback": float(1.0 if used_scene_scale_fallback else 0.0),
    }
"""Tests for PSF-guided hashgrid parameter selection helpers."""

from types import SimpleNamespace

import pytest
import torch

from nerfstudio.utils.hashgrid_psf import (
    calculate_axis_fwhm,
    compute_psf_guided_hashgrid_params,
    estimate_pixel_world_size,
    growth_factor_to_max_res,
    solve_growth_factor_for_target_fwhm,
)


def _make_cameras(distance: float = 2.0, focal: float = 500.0):
    """Build a minimal duck-typed single-camera rig sitting `distance` along +z."""
    camera_to_worlds = torch.eye(4, dtype=torch.float32)[:3, :].unsqueeze(0)
    camera_to_worlds[..., 2, 3] = distance
    return SimpleNamespace(
        camera_to_worlds=camera_to_worlds,
        camera_type=torch.ones((1, 1), dtype=torch.int64),
        fx=torch.tensor([[focal]], dtype=torch.float32),
        fy=torch.tensor([[focal]], dtype=torch.float32),
    )


def test_solve_growth_factor_for_target_fwhm():
    """Solved growth factor must reproduce the target FWHM after broadening."""
    n_min, levels = 16, 16
    target, widen = 0.01, 3.0

    b = solve_growth_factor_for_target_fwhm(
        base_resolution=n_min,
        num_levels=levels,
        target_fwhm_world=target,
        dim=3,
        broadening=widen,
    )
    recovered = widen * calculate_axis_fwhm(
        base_resolution=n_min,
        growth_factor=b,
        num_levels=levels,
        dim=3,
    )
    assert b > 1.0
    assert recovered == pytest.approx(target, rel=1e-3, abs=1e-6)


def test_growth_factor_to_max_res():
    """max_res must never fall below the base resolution."""
    assert growth_factor_to_max_res(base_resolution=16, num_levels=16, growth_factor=1.38) >= 16


def test_estimate_pixel_world_size():
    """A camera at depth z with focal f sees pixels of size z/f world units."""
    rig = _make_cameras(distance=2.0, focal=500.0)
    size = estimate_pixel_world_size(rig)
    assert size is not None
    assert size == pytest.approx(2.0 / 500.0, rel=1e-5, abs=1e-8)


def test_compute_psf_guided_hashgrid_params():
    """Camera path uses camera scale; no-camera path falls back to the aabb scale."""
    box = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]], dtype=torch.float32)
    rig = _make_cameras(distance=2.0, focal=500.0)

    with_cams = compute_psf_guided_hashgrid_params(
        train_cameras=rig,
        scene_aabb=box,
        base_resolution=16,
        num_levels=16,
        target_fwhm_pixels=0.5,
        dim=3,
        broadening=3.0,
    )
    assert with_cams is not None
    assert with_cams["growth_factor"] > 1.0
    assert int(with_cams["max_res"]) >= 16
    assert with_cams["used_scene_scale_fallback"] == 0.0

    without_cams = compute_psf_guided_hashgrid_params(
        train_cameras=None,
        scene_aabb=box,
        base_resolution=16,
        num_levels=16,
        target_fwhm_pixels=0.5,
        dim=3,
        broadening=3.0,
    )
    assert without_cams is not None
    assert without_cams["used_scene_scale_fallback"] == 1.0