diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 475aa672e..a696bc6fa 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -92,16 +92,22 @@ def _update_param_with_optimizer( f"Got requires_grad={param.requires_grad}" ) continue + optimizer = optimizers[name] + # Rebind refs to `new_param` before allocating new Adam state, so + # the old parameter tensor is released during optimizer_fn calls. + param_state = optimizer.state.pop(param, {}) for i in range(len(optimizer.param_groups)): - param_state = optimizer.state[param] - del optimizer.state[param] - for key in param_state.keys(): - if key != "step": - v = param_state[key] - param_state[key] = optimizer_fn(key, v) optimizer.param_groups[i]["params"] = [new_param] - optimizer.state[new_param] = param_state + optimizer.state[new_param] = param_state + del param + + for key in list(param_state.keys()): + if key == "step": + continue + old = param_state[key] + param_state[key] = optimizer_fn(key, old) + del old @torch.no_grad()