Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions include/multi_agent_solver/constraint_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ namespace mas

// Helper function to compute the augmented cost
inline double
compute_augmented_cost( const OCP& problem, const ConstraintViolations& equality_multipliers,
const ConstraintViolations& inequality_multipliers, double penalty_parameter, const StateTrajectory& states,
compute_augmented_cost( const OCP& problem, const ConstraintViolationsTrajectory& equality_multipliers,
const ConstraintViolationsTrajectory& inequality_multipliers, double penalty_parameter, const StateTrajectory& states,
const ControlTrajectory& controls )
{
double cost = problem.objective_function( states, controls );
Expand All @@ -26,14 +26,22 @@ compute_augmented_cost( const OCP& problem, const ConstraintViolations& equality
if( problem.equality_constraints )
{
ConstraintViolations eq_residuals = problem.equality_constraints( states.col( t ), controls.col( t ) );
cost += equality_multipliers.dot( eq_residuals ) + 0.5 * penalty_parameter * eq_residuals.squaredNorm();
if (equality_multipliers.cols() > t) {
cost += equality_multipliers.col(t).dot( eq_residuals ) + 0.5 * penalty_parameter * eq_residuals.squaredNorm();
}
}

if( problem.inequality_constraints )
{
ConstraintViolations ineq_residuals = problem.inequality_constraints( states.col( t ), controls.col( t ) );
ConstraintViolations slack = ( ineq_residuals.array() > 0 ).select( ineq_residuals, 0 );
cost += inequality_multipliers.dot( slack ) + 0.5 * penalty_parameter * slack.squaredNorm();
if (inequality_multipliers.cols() > t) {
// PHR augmented Lagrangian term for inequalities:
// (1 / 2rho) * ( max(0, lambda + rho * g)^2 - lambda^2 )
const auto& lambda = inequality_multipliers.col(t);
Eigen::VectorXd combined = lambda + penalty_parameter * ineq_residuals;
Eigen::VectorXd combined_plus = combined.cwiseMax(0.0);
cost += (0.5 / penalty_parameter) * (combined_plus.squaredNorm() - lambda.squaredNorm());
}
}
}

Expand All @@ -43,21 +51,26 @@ compute_augmented_cost( const OCP& problem, const ConstraintViolations& equality
// Helper function to update Lagrange multipliers
inline void
update_lagrange_multipliers( const OCP& problem, const StateTrajectory& states, const ControlTrajectory& controls,
ConstraintViolations& equality_multipliers, ConstraintViolations& inequality_multipliers,
ConstraintViolationsTrajectory& equality_multipliers, ConstraintViolationsTrajectory& inequality_multipliers,
double penalty_parameter )
{
for( int t = 0; t < controls.cols(); ++t )
{
if( problem.equality_constraints )
{
ConstraintViolations eq_residuals = problem.equality_constraints( states.col( t ), controls.col( t ) );
equality_multipliers += penalty_parameter * eq_residuals;
if (equality_multipliers.cols() > t) {
equality_multipliers.col(t) += penalty_parameter * eq_residuals;
}
}

if( problem.inequality_constraints )
{
ConstraintViolations ineq_residuals = problem.inequality_constraints( states.col( t ), controls.col( t ) );
inequality_multipliers += penalty_parameter * ( ineq_residuals.array() > 0 ).select( ineq_residuals, 0 );
if (inequality_multipliers.cols() > t) {
// Update rule: lambda_next = max(0, lambda + rho * g)
inequality_multipliers.col(t) = (inequality_multipliers.col(t) + penalty_parameter * ineq_residuals).cwiseMax(0.0);
}
}
}
}
Expand Down
32 changes: 20 additions & 12 deletions include/multi_agent_solver/solvers/cgd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "multi_agent_solver/integrator.hpp"
#include "multi_agent_solver/line_search.hpp"
#include "multi_agent_solver/ocp.hpp"
#include "multi_agent_solver/solvers/solver.hpp"
#include "multi_agent_solver/types.hpp"

namespace mas
Expand Down Expand Up @@ -71,10 +70,15 @@ class CGD
break;
}

// Define augmented cost for this iteration
auto augmented_objective = [&]( const StateTrajectory& s, const ControlTrajectory& c ) {
return compute_augmented_cost( problem, eq_multipliers, ineq_multipliers, penalty_param, s, c );
};

const ControlGradient gradients = finite_differences_gradient( problem.initial_state, controls, problem.dynamics,
problem.objective_function, problem.dt );
augmented_objective, problem.dt );

const double step_size = armijo_line_search( problem.initial_state, controls, gradients, problem.dynamics, problem.objective_function,
const double step_size = armijo_line_search( problem.initial_state, controls, gradients, problem.dynamics, augmented_objective,
problem.dt, {} );

ControlTrajectory trial_controls = controls - step_size * gradients;
Expand All @@ -100,6 +104,7 @@ class CGD
update_lagrange_multipliers( problem, state_trajectory, controls, eq_multipliers, ineq_multipliers, penalty_param );

increase_penalty_parameter( penalty_param, problem, state_trajectory, controls, tolerance );
if (penalty_param > 1e6) penalty_param = 1e6; // Cap penalty

if( std::abs( old_cost - trial_cost ) < tolerance && debug )
{
Expand All @@ -114,24 +119,27 @@ class CGD
void
resize_multipliers( const OCP& problem )
{
int horizon = problem.horizon_steps;
Control default_control = Control::Zero( problem.control_dim );

if( problem.equality_constraints )
{
const auto m = problem.equality_constraints( problem.initial_state, {} ).size();
eq_multipliers.setZero( m );
const auto m = problem.equality_constraints( problem.initial_state, default_control ).size();
eq_multipliers.setZero( m, horizon );
}
else
{
eq_multipliers.resize( 0 );
eq_multipliers.resize( 0, 0 );
}

if( problem.inequality_constraints )
{
const auto p = problem.inequality_constraints( problem.initial_state, {} ).size();
ineq_multipliers.setZero( p );
const auto p = problem.inequality_constraints( problem.initial_state, default_control ).size();
ineq_multipliers.setZero( p, horizon );
}
else
{
ineq_multipliers.resize( 0 );
ineq_multipliers.resize( 0, 0 );
}
}

Expand All @@ -140,9 +148,9 @@ class CGD
double max_ms;
bool debug = false;

ConstraintViolations eq_multipliers;
ConstraintViolations ineq_multipliers;
double penalty_param;
ConstraintViolationsTrajectory eq_multipliers;
ConstraintViolationsTrajectory ineq_multipliers;
double penalty_param;
};

} // namespace mas
1 change: 1 addition & 0 deletions include/multi_agent_solver/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using TerminalCostFunction = std::function<double( const State& )>;

// Constraints
using ConstraintViolations = Eigen::VectorXd;
using ConstraintViolationsTrajectory = Eigen::MatrixXd;
using ConstraintsFunction = std::function<ConstraintViolations( const State&, const Control& )>;
using ConstraintsJacobian = Eigen::MatrixXd;
using ConstraintsJacobianFunction = std::function<ConstraintsJacobian( const State&, const Control& )>;
Expand Down