From f2b3868013e57db46621d19eb6c48924b41793c9 Mon Sep 17 00:00:00 2001 From: Meraj Hashemizadeh Date: Sun, 7 Sep 2025 23:09:44 -0400 Subject: [PATCH 1/5] Make sparse gradients configurable in `IndexedMultiplier` --- src/cooper/multipliers/multipliers.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/cooper/multipliers/multipliers.py b/src/cooper/multipliers/multipliers.py index 51425056..c398ad37 100644 --- a/src/cooper/multipliers/multipliers.py +++ b/src/cooper/multipliers/multipliers.py @@ -137,6 +137,15 @@ def forward(self) -> torch.Tensor: class IndexedMultiplier(ExplicitMultiplier): r""":py:class:`~cooper.multipliers.ExplicitMultiplier` for indexed constraints which are evaluated only for a subset of constraints on every optimization step. + + Args: + num_constraints: Number of constraints associated with the multiplier. + init: Tensor used to initialize the multiplier values. If both ``init`` and + ``num_constraints`` are provided, ``init`` must have shape ``(num_constraints,)``. + device: Device for the multiplier. If ``None``, the device is inferred from the + ``init`` tensor or the default device. + dtype: Data type for the multiplier. Default is ``torch.float32``. + sparse: Whether to use sparse gradients during indexing. Default is True. """ expects_constraint_features = True @@ -147,8 +156,12 @@ def __init__( init: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, + *, + sparse: bool = True, ) -> None: super().__init__(num_constraints, init, device, dtype) + self.sparse = sparse + if self.weight.dim() == 1: # To use the forward call in F.embedding, we must reshape the weight to be a # 2-dim tensor @@ -171,7 +184,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: # TODO(gallego-posada): Document sparse gradients are expected for stateful # optimizers (having buffers) - multiplier_values = torch.nn.functional.embedding(indices, self.weight, sparse=True) + multiplier_values = torch.nn.functional.embedding(indices, self.weight, sparse=self.sparse) # Flatten multiplier values to 1D since Embedding works with 2D tensors. return torch.flatten(multiplier_values) From 3eba52c1390910e0602c51bafd04ad81e389cbb9 Mon Sep 17 00:00:00 2001 From: Meraj Hashemizadeh Date: Mon, 8 Sep 2025 17:06:21 -0400 Subject: [PATCH 2/5] Add tests --- src/cooper/multipliers/multipliers.py | 26 +----- tests/multipliers/conftest.py | 13 +++ .../multipliers/test_explicit_multipliers.py | 85 +++++++------------ tests/optim/torch_optimizers/test_nupi.py | 4 +- 4 files changed, 52 insertions(+), 76 deletions(-) diff --git a/src/cooper/multipliers/multipliers.py b/src/cooper/multipliers/multipliers.py index c398ad37..90fa580a 100644 --- a/src/cooper/multipliers/multipliers.py +++ b/src/cooper/multipliers/multipliers.py @@ -145,7 +145,7 @@ class IndexedMultiplier(ExplicitMultiplier): device: Device for the multiplier. If ``None``, the device is inferred from the ``init`` tensor or the default device. dtype: Data type for the multiplier. Default is ``torch.float32``. - sparse: Whether to use sparse gradients during indexing. Default is True. + sparse_grad: Whether to use sparse gradients. Default is True. """ expects_constraint_features = True @@ -157,15 +157,10 @@ def __init__( device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, *, - sparse: bool = True, + sparse_grad: bool = True, ) -> None: super().__init__(num_constraints, init, device, dtype) - self.sparse = sparse - - if self.weight.dim() == 1: - # To use the forward call in F.embedding, we must reshape the weight to be a - # 2-dim tensor - self.weight.data = self.weight.data.unsqueeze(-1) + self.sparse_grad = sparse_grad def forward(self, indices: torch.Tensor) -> torch.Tensor: """Return the current value of the multiplier at the provided indices. @@ -173,21 +168,8 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: Args: indices: Indices of the multipliers to return. The shape of ``indices`` must be ``(num_indices,)``. - - Raises: - ValueError: If ``indices`` dtype is not ``torch.long``. """ - if indices.dtype != torch.long: - # Not allowing for boolean "indices", which are treated as indices by - # torch.nn.functional.embedding and *not* as masks. - raise ValueError("Indices must be of type torch.long.") - - # TODO(gallego-posada): Document sparse gradients are expected for stateful - # optimizers (having buffers) - multiplier_values = torch.nn.functional.embedding(indices, self.weight, sparse=self.sparse) - - # Flatten multiplier values to 1D since Embedding works with 2D tensors. - return torch.flatten(multiplier_values) + return self.weight.gather(0, indices, sparse_grad=self.sparse_grad) class ImplicitMultiplier(Multiplier): diff --git a/tests/multipliers/conftest.py b/tests/multipliers/conftest.py index 18ad3172..80ca8443 100644 --- a/tests/multipliers/conftest.py +++ b/tests/multipliers/conftest.py @@ -37,6 +37,19 @@ def init_multiplier_tensor(constraint_type, num_constraints, random_seed): return raw_init +@pytest.fixture(params=[True, False]) +def sparse_grad(request): + return request.param + + +@pytest.fixture +def multiplier(multiplier_class, num_constraints, init_multiplier_tensor, device, sparse_grad): + kwargs = {"num_constraints": num_constraints, "init": init_multiplier_tensor, "device": device} + if multiplier_class == cooper.multipliers.IndexedMultiplier: + kwargs["sparse_grad"] = sparse_grad + return multiplier_class(**kwargs) + + @pytest.fixture def all_indices(num_constraints): return torch.arange(num_constraints, dtype=torch.long) diff --git a/tests/multipliers/test_explicit_multipliers.py b/tests/multipliers/test_explicit_multipliers.py index 1d31ea3d..c643d797 100644 --- a/tests/multipliers/test_explicit_multipliers.py +++ b/tests/multipliers/test_explicit_multipliers.py @@ -16,14 +16,12 @@ def evaluate_multiplier(multiplier, all_indices): return multiplier(all_indices) if multiplier.expects_constraint_features else multiplier() -def test_multiplier_initialization_with_init(multiplier_class, init_multiplier_tensor, device): - multiplier = multiplier_class(init=init_multiplier_tensor, device=device) +def test_multiplier_initialization_with_init(multiplier, init_multiplier_tensor, device): assert torch.equal(multiplier.weight.view(-1), init_multiplier_tensor.to(device).view(-1)) assert multiplier.device.type == device.type -def test_multiplier_initialization_with_num_constraints(multiplier_class, num_constraints, device): - multiplier = multiplier_class(num_constraints=num_constraints, device=device) +def test_multiplier_initialization_with_num_constraints(multiplier, num_constraints, device): assert multiplier.weight.numel() == num_constraints assert multiplier.device.type == device.type @@ -43,8 +41,7 @@ def test_multiplier_initialization_with_init_dim(multiplier_class, num_constrain multiplier_class(num_constraints=num_constraints, init=torch.zeros(num_constraints, 1)) -def test_multiplier_repr(multiplier_class, num_constraints): - multiplier = multiplier_class(num_constraints=num_constraints) +def test_multiplier_repr(multiplier, multiplier_class, num_constraints): assert repr(multiplier) == f"{multiplier_class.__name__}(num_constraints={num_constraints})" @@ -53,79 +50,57 @@ def test_multiplier_sanity_check(constraint_type, multiplier_class, init_multipl if constraint_type == cooper.ConstraintType.EQUALITY: pytest.skip("") - multiplier = multiplier_class(init=init_multiplier_tensor.abs().neg()) + neg_multiplier = multiplier_class(init=init_multiplier_tensor.abs().neg()) with pytest.raises(ValueError, match=r"For inequality constraint, all entries in multiplier must be non-negative."): - multiplier.set_constraint_type(cooper.ConstraintType.INEQUALITY) + neg_multiplier.set_constraint_type(cooper.ConstraintType.INEQUALITY) -def test_multiplier_init_and_forward(multiplier_class, init_multiplier_tensor, all_indices): +def test_multiplier_init_and_forward(multiplier, init_multiplier_tensor, all_indices): # Ensure that the multiplier returns the correct value when called - ineq_multiplier = multiplier_class(init=init_multiplier_tensor) - multiplier_values = evaluate_multiplier(ineq_multiplier, all_indices) + multiplier_values = evaluate_multiplier(multiplier, all_indices) target_tensor = init_multiplier_tensor.reshape(multiplier_values.shape) assert torch.allclose(multiplier_values, target_tensor) -def test_indexed_multiplier_forward_invalid_indices(init_multiplier_tensor): - multiplier = cooper.multipliers.IndexedMultiplier(init=init_multiplier_tensor) - indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32) - - with pytest.raises(ValueError, match=r"Indices must be of type torch.long."): - multiplier.forward(indices) - - -def test_equality_post_step_(constraint_type, multiplier_class, init_multiplier_tensor, all_indices): +def test_equality_post_step_(constraint_type, multiplier, init_multiplier_tensor, all_indices): """Post-step for equality multipliers should be a no-op. Check that multiplier values remain unchanged after calling post_step_. """ if constraint_type == cooper.ConstraintType.INEQUALITY: pytest.skip("") - eq_multiplier = multiplier_class(init=init_multiplier_tensor) - eq_multiplier.set_constraint_type(cooper.ConstraintType.EQUALITY) - eq_multiplier.post_step_() - multiplier_values = evaluate_multiplier(eq_multiplier, all_indices) + multiplier.set_constraint_type(constraint_type) + multiplier.post_step_() + multiplier_values = evaluate_multiplier(multiplier, all_indices) target_tensor = init_multiplier_tensor.reshape(multiplier_values.shape) assert torch.allclose(multiplier_values, target_tensor) -def test_ineq_post_step_(constraint_type, multiplier_class, init_multiplier_tensor, all_indices): +def test_ineq_post_step_(constraint_type, multiplier, all_indices): """Ensure that the inequality multipliers remain non-negative after post-step.""" if constraint_type == cooper.ConstraintType.EQUALITY: pytest.skip("") - ineq_multiplier = multiplier_class(init=init_multiplier_tensor) - ineq_multiplier.set_constraint_type(cooper.ConstraintType.INEQUALITY) + multiplier.set_constraint_type(constraint_type) # Overwrite the multiplier to have some *negative* entries and gradients - hard_coded_weight_data = torch.randn_like(ineq_multiplier.weight) - ineq_multiplier.weight.data = hard_coded_weight_data + hard_coded_weight_data = torch.randn_like(multiplier.weight) + multiplier.weight.data = hard_coded_weight_data - hard_coded_gradient_data = torch.randn_like(ineq_multiplier.weight) - ineq_multiplier.weight.grad = hard_coded_gradient_data - if isinstance(ineq_multiplier, cooper.multipliers.IndexedMultiplier): - ineq_multiplier.weight.grad = ineq_multiplier.weight.grad.to_sparse(sparse_dim=1) + hard_coded_gradient_data = torch.randn_like(multiplier.weight) + if isinstance(multiplier, cooper.multipliers.IndexedMultiplier) and multiplier.sparse_grad: + hard_coded_gradient_data = hard_coded_gradient_data.to_sparse(sparse_dim=1) + multiplier.weight.grad = hard_coded_gradient_data - # Post-step should ensure non-negativity. Note that no feasible indices are passed, - # so "feasible" multipliers and their gradients are not reset. - ineq_multiplier.post_step_() + # Post-step should ensure non-negativity + multiplier.post_step_() + multiplier_values = evaluate_multiplier(multiplier, all_indices) - multiplier_values = evaluate_multiplier(ineq_multiplier, all_indices) + target_weight_data = hard_coded_weight_data.relu() + current_grad = multiplier.weight.grad - target_weight_data = hard_coded_weight_data.relu().reshape_as(multiplier_values) - current_grad = ineq_multiplier.weight.grad.to_dense() assert torch.allclose(multiplier_values, target_weight_data) - assert torch.allclose(current_grad, hard_coded_gradient_data) - - # Perform post-step again, this time with feasible indices - ineq_multiplier.post_step_() - - multiplier_values = evaluate_multiplier(ineq_multiplier, all_indices) - - current_grad = ineq_multiplier.weight.grad.to_dense() - # Latest post-step is a no-op - assert torch.allclose(multiplier_values, target_weight_data) - assert torch.allclose(current_grad, hard_coded_gradient_data) + assert torch.allclose(current_grad.to_dense(), hard_coded_gradient_data.to_dense()) def check_save_load_state_dict(multiplier, explicit_multiplier_class, num_constraints, random_seed): @@ -144,7 +119,13 @@ def check_save_load_state_dict(multiplier, explicit_multiplier_class, num_constr assert torch.equal(multiplier.weight, new_multiplier.weight) -def test_save_load_multiplier(multiplier_class, init_multiplier_tensor, num_constraints, random_seed): +def test_save_load_multiplier(multiplier, multiplier_class, num_constraints, random_seed): """Test that the state_dict of a multiplier can be saved and loaded correctly.""" - multiplier = multiplier_class(init=init_multiplier_tensor) check_save_load_state_dict(multiplier, multiplier_class, num_constraints, random_seed) + + +def test_multiplier_grad(multiplier, all_indices): + evaluate_multiplier(multiplier, all_indices).sum().backward() + assert multiplier.weight.grad.is_sparse == ( + isinstance(multiplier, cooper.multipliers.IndexedMultiplier) and multiplier.sparse_grad + ) diff --git a/tests/optim/torch_optimizers/test_nupi.py b/tests/optim/torch_optimizers/test_nupi.py index b7e66e5d..a31f1f6a 100644 --- a/tests/optim/torch_optimizers/test_nupi.py +++ b/tests/optim/torch_optimizers/test_nupi.py @@ -158,7 +158,7 @@ def loss_fn(indices): def compute_analytic_gradient(indices): # For the quadratic loss, the gradient is simply the current value of p. - return multiplier_module(indices).reshape(-1, 1).clone().detach() + return multiplier_module(indices).clone().detach() def recursive_nuPI_direction(error, previous_xi): return (Ki + (1 - ema_nu) * Kp) * error - (1 - ema_nu) * Kp * previous_xi @@ -275,7 +275,7 @@ def loss_fn(indices): def compute_analytic_gradient(indices): # For the quadratic loss, the gradient is simply the current value of p. - return multiplier_module(indices).reshape(-1, 1).clone().detach() + return multiplier_module(indices).clone().detach() optimizer = nuPI( multiplier_module.parameters(), From 6de711354110e95b7a3fb90fad27a172618096fe Mon Sep 17 00:00:00 2001 From: Meraj Hashemizadeh Date: Mon, 8 Sep 2025 19:07:44 -0400 Subject: [PATCH 3/5] Fix tests --- src/cooper/optim/torch_optimizers/nupi_optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cooper/optim/torch_optimizers/nupi_optimizer.py b/src/cooper/optim/torch_optimizers/nupi_optimizer.py index c759fb40..54acc047 100644 --- a/src/cooper/optim/torch_optimizers/nupi_optimizer.py +++ b/src/cooper/optim/torch_optimizers/nupi_optimizer.py @@ -301,7 +301,7 @@ def _sparse_nupi_zero_init( nupi_update_values.add_(detached_error_values.mul(et_coef_values)) if xit_m1_coef_values.ne(0).any(): - xi_values = state["xi"].sparse_mask(error)._values() + xi_values = state["xi"][tuple(error_indices)] nupi_update_values.sub_(xi_values.mul(xit_m1_coef)) nupi_update = torch.sparse_coo_tensor(error_indices, nupi_update_values, size=param.shape) @@ -404,9 +404,9 @@ def _sparse_nupi_sgd_init( nupi_update_values.add_(detached_error_values.mul(filtered_Ki_values)) if uses_kp_term: - previous_xi_values = state["xi"].sparse_mask(error)._values() + previous_xi_values = state["xi"][tuple(error_indices)] proportional_term_contribution = torch.where( - state["needs_error_initialization_mask"].sparse_mask(error)._values(), + state["needs_error_initialization_mask"][tuple(error_indices)], torch.zeros_like(detached_error_values), # If state has not been initialized, xi_0 = 0 (1 - ema_nu) * (detached_error_values - previous_xi_values), # Else, we use recursive update ) From 99bd7daa9caaa3e07547261b63f2b95befa22332 Mon Sep 17 00:00:00 2001 From: Meraj Hashemizadeh Date: Mon, 15 Sep 2025 02:07:48 -0400 Subject: [PATCH 4/5] Add note about stateful optimizers --- src/cooper/multipliers/multipliers.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/cooper/multipliers/multipliers.py b/src/cooper/multipliers/multipliers.py index 90fa580a..4517bd15 100644 --- a/src/cooper/multipliers/multipliers.py +++ b/src/cooper/multipliers/multipliers.py @@ -145,7 +145,15 @@ class IndexedMultiplier(ExplicitMultiplier): device: Device for the multiplier. If ``None``, the device is inferred from the ``init`` tensor or the default device. dtype: Data type for the multiplier. Default is ``torch.float32``. - sparse_grad: Whether to use sparse gradients. Default is True. + sparse_grad: Whether to use sparse gradients. Default is ``True``. When set to + ``False`` with stateful optimizers (e.g., Adam), optimizer states will be + updated for all parameters, assuming zero gradients for non-sampled indices. + This may lead to incorrect optimization behavior as these values should not + be updated at all. + + Note: + The default value of ``sparse_grad=True`` is recommended for stateful optimizers. + Set ``sparse_grad=False`` only when necessary (e.g., when using DDP) and with caution. """ expects_constraint_features = True From 4d76e776b6768d0651add9d151de3f62bae6090f Mon Sep 17 00:00:00 2001 From: Meraj Hashemizadeh Date: Wed, 17 Sep 2025 16:04:34 -0400 Subject: [PATCH 5/5] Add dimension check for nuPI optimizer with TODO for multidimensional support --- src/cooper/optim/torch_optimizers/nupi_optimizer.py | 4 ++++ tests/optim/torch_optimizers/test_nupi.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/src/cooper/optim/torch_optimizers/nupi_optimizer.py b/src/cooper/optim/torch_optimizers/nupi_optimizer.py index 54acc047..c8948404 100644 --- a/src/cooper/optim/torch_optimizers/nupi_optimizer.py +++ b/src/cooper/optim/torch_optimizers/nupi_optimizer.py @@ -197,6 +197,10 @@ def step(self, closure: Optional[Callable] = None) -> Optional[float]: if p.grad is None: continue + if p.grad.ndim != 1: + # TODO(juan43ramirez): Implement support for multidimensional parameters + raise NotImplementedError("nuPI optimizer only supports 1D parameters.") + update_function = self.disambiguate_update_function(p.grad.is_sparse, group["init_type"]) update_function( param=p, diff --git a/tests/optim/torch_optimizers/test_nupi.py b/tests/optim/torch_optimizers/test_nupi.py index a31f1f6a..1f0c0772 100644 --- a/tests/optim/torch_optimizers/test_nupi.py +++ b/tests/optim/torch_optimizers/test_nupi.py @@ -352,3 +352,11 @@ def do_optimizer_step(indices): # Check state entries that have not been updated yet unseen_indices = torch.tensor([4, 8, 9], device=device) assert torch.allclose(buffer[unseen_indices], torch.zeros_like(buffer[unseen_indices])) + + +def test_nupi_multi_dimensional_raises(): + param = torch.ones(2, 3, requires_grad=True) + param.sum().backward() + optimizer = nuPI([param], lr=0.01) + with pytest.raises(NotImplementedError, match="nuPI optimizer only supports 1D parameters"): + optimizer.step()