|
| 1 | +#include <gtest/gtest.h> |
| 2 | + |
| 3 | +#include "multi_agent_solver/agent.hpp" |
| 4 | +#include "multi_agent_solver/finite_differences.hpp" |
| 5 | +#include "multi_agent_solver/multi_agent_problem.hpp" |
| 6 | +#include "multi_agent_solver/ocp.hpp" |
| 7 | + |
| 8 | +namespace mas |
| 9 | +{ |
| 10 | +namespace |
| 11 | +{ |
| 12 | + |
| 13 | +MotionModel |
| 14 | +create_integrator() |
| 15 | +{ |
| 16 | + return []( const State& state, const Control& control ) { return control + state * 0.0; }; |
| 17 | +} |
| 18 | + |
| 19 | +} // namespace |
| 20 | + |
| 21 | +TEST( OCPTest, InitializeProblemSetsDefaultsAndBestCost ) |
| 22 | +{ |
| 23 | + OCP ocp; |
| 24 | + ocp.state_dim = 1; |
| 25 | + ocp.control_dim = 1; |
| 26 | + ocp.horizon_steps = 3; |
| 27 | + ocp.dt = 0.1; |
| 28 | + ocp.initial_state = State::Zero( ocp.state_dim ); |
| 29 | + ocp.dynamics = create_integrator(); |
| 30 | + ocp.stage_cost = []( const State& x, const Control& u, size_t ) { |
| 31 | + return x.squaredNorm() + u.squaredNorm(); |
| 32 | + }; |
| 33 | + ocp.terminal_cost = []( const State& x ) { return x.squaredNorm(); }; |
| 34 | + |
| 35 | + ocp.initialize_problem(); |
| 36 | + |
| 37 | + EXPECT_EQ( ocp.best_states.rows(), ocp.state_dim ); |
| 38 | + EXPECT_EQ( ocp.best_states.cols(), ocp.horizon_steps + 1 ); |
| 39 | + EXPECT_EQ( ocp.best_controls.rows(), ocp.control_dim ); |
| 40 | + EXPECT_EQ( ocp.best_controls.cols(), ocp.horizon_steps ); |
| 41 | + EXPECT_DOUBLE_EQ( ocp.best_cost, 0.0 ); |
| 42 | + |
| 43 | + ASSERT_TRUE( static_cast<bool>( ocp.cost_state_gradient ) ); |
| 44 | + ASSERT_TRUE( static_cast<bool>( ocp.cost_control_gradient ) ); |
| 45 | + |
| 46 | + auto state_grad = ocp.cost_state_gradient( ocp.stage_cost, ocp.best_states.col( 0 ), |
| 47 | + ocp.best_controls.col( 0 ), 0 ); |
| 48 | + auto control_grad = ocp.cost_control_gradient( ocp.stage_cost, ocp.best_states.col( 0 ), |
| 49 | + ocp.best_controls.col( 0 ), 0 ); |
| 50 | + |
| 51 | + EXPECT_EQ( state_grad.size(), ocp.state_dim ); |
| 52 | + EXPECT_EQ( control_grad.size(), ocp.control_dim ); |
| 53 | + EXPECT_TRUE( ocp.verify_problem() ); |
| 54 | +} |
| 55 | + |
| 56 | +TEST( OCPTest, UpdateInitialWithBestCopiesTrajectories ) |
| 57 | +{ |
| 58 | + OCP ocp; |
| 59 | + ocp.state_dim = 2; |
| 60 | + ocp.control_dim = 2; |
| 61 | + ocp.horizon_steps = 2; |
| 62 | + ocp.dt = 1.0; |
| 63 | + ocp.initial_state = State::Zero( ocp.state_dim ); |
| 64 | + ocp.dynamics = create_integrator(); |
| 65 | + ocp.initialize_problem(); |
| 66 | + |
| 67 | + ocp.best_controls = ControlTrajectory::Ones( ocp.control_dim, ocp.horizon_steps ); |
| 68 | + ocp.best_states = StateTrajectory::Ones( ocp.state_dim, ocp.horizon_steps + 1 ); |
| 69 | + |
| 70 | + ocp.update_initial_with_best(); |
| 71 | + |
| 72 | + EXPECT_TRUE( ocp.initial_controls.isApprox( ocp.best_controls ) ); |
| 73 | + EXPECT_TRUE( ocp.initial_states.isApprox( ocp.best_states ) ); |
| 74 | +} |
| 75 | + |
| 76 | +TEST( MultiAgentProblemTest, BuildGlobalProblemMergesAgents ) |
| 77 | +{ |
| 78 | + auto ocp_a = std::make_shared<OCP>(); |
| 79 | + ocp_a->state_dim = 2; |
| 80 | + ocp_a->control_dim = 1; |
| 81 | + ocp_a->horizon_steps = 2; |
| 82 | + ocp_a->dt = 0.5; |
| 83 | + ocp_a->initial_state = State::Ones( ocp_a->state_dim ); |
| 84 | + ocp_a->dynamics = []( const State& x, const Control& u ) { |
| 85 | + return x + u.replicate( x.size(), 1 ); |
| 86 | + }; |
| 87 | + ocp_a->stage_cost = []( const State& x, const Control& u, size_t ) { return x.sum() + u.sum(); }; |
| 88 | + ocp_a->terminal_cost = []( const State& x ) { return 2.0 * x.sum(); }; |
| 89 | + ocp_a->input_lower_bounds = Control::Constant( ocp_a->control_dim, -1.0 ); |
| 90 | + ocp_a->input_upper_bounds = Control::Constant( ocp_a->control_dim, 1.0 ); |
| 91 | + ocp_a->initialize_problem(); |
| 92 | + |
| 93 | + auto ocp_b = std::make_shared<OCP>(); |
| 94 | + ocp_b->state_dim = 1; |
| 95 | + ocp_b->control_dim = 2; |
| 96 | + ocp_b->horizon_steps = 2; |
| 97 | + ocp_b->dt = 0.5; |
| 98 | + ocp_b->initial_state = State::Constant( ocp_b->state_dim, 3.0 ); |
| 99 | + ocp_b->dynamics = []( const State& x, const Control& u ) { |
| 100 | + return x + Control::Constant( x.size(), 2.0 * u.sum() ); |
| 101 | + }; |
| 102 | + ocp_b->stage_cost = []( const State& x, const Control& u, size_t ) { |
| 103 | + return 2.0 * x.sum() + 3.0 * u.sum(); |
| 104 | + }; |
| 105 | + ocp_b->terminal_cost = []( const State& x ) { return x.sum(); }; |
| 106 | + ocp_b->input_lower_bounds = Control::Constant( ocp_b->control_dim, -2.0 ); |
| 107 | + ocp_b->input_upper_bounds = Control::Constant( ocp_b->control_dim, 2.0 ); |
| 108 | + ocp_b->initialize_problem(); |
| 109 | + |
| 110 | + MultiAgentProblem problem; |
| 111 | + problem.add_agent( std::make_shared<Agent>( 2, ocp_b ) ); |
| 112 | + problem.add_agent( std::make_shared<Agent>( 1, ocp_a ) ); |
| 113 | + problem.compute_offsets(); |
| 114 | + |
| 115 | + ASSERT_EQ( problem.blocks.size(), 2 ); |
| 116 | + EXPECT_EQ( problem.blocks.front().agent_id, 1 ); |
| 117 | + EXPECT_EQ( problem.blocks.back().agent_id, 2 ); |
| 118 | + EXPECT_EQ( problem.blocks.front().state_offset, 0 ); |
| 119 | + EXPECT_EQ( problem.blocks.front().control_offset, 0 ); |
| 120 | + EXPECT_EQ( problem.blocks.back().state_offset, ocp_a->state_dim ); |
| 121 | + EXPECT_EQ( problem.blocks.back().control_offset, ocp_a->control_dim ); |
| 122 | + |
| 123 | + OCP global = problem.build_global_ocp(); |
| 124 | + |
| 125 | + EXPECT_EQ( global.state_dim, ocp_a->state_dim + ocp_b->state_dim ); |
| 126 | + EXPECT_EQ( global.control_dim, ocp_a->control_dim + ocp_b->control_dim ); |
| 127 | + EXPECT_EQ( global.horizon_steps, ocp_a->horizon_steps ); |
| 128 | + EXPECT_DOUBLE_EQ( global.dt, ocp_a->dt ); |
| 129 | + |
| 130 | + ASSERT_TRUE( global.input_lower_bounds.has_value() ); |
| 131 | + ASSERT_TRUE( global.input_upper_bounds.has_value() ); |
| 132 | + EXPECT_DOUBLE_EQ( ( *global.input_lower_bounds )( 0 ), -1.0 ); |
| 133 | + EXPECT_DOUBLE_EQ( ( *global.input_lower_bounds )( 1 ), -2.0 ); |
| 134 | + EXPECT_DOUBLE_EQ( ( *global.input_lower_bounds )( 2 ), -2.0 ); |
| 135 | + |
| 136 | + State expected_initial( global.state_dim ); |
| 137 | + expected_initial << 1.0, 1.0, 3.0; |
| 138 | + EXPECT_TRUE( global.initial_state.isApprox( expected_initial ) ); |
| 139 | + |
| 140 | + State state = State::LinSpaced( global.state_dim, 1.0, 3.0 ); |
| 141 | + Control control = Control::LinSpaced( global.control_dim, -1.0, 1.0 ); |
| 142 | + StateDerivative derivative = global.dynamics( state, control ); |
| 143 | + |
| 144 | + EXPECT_EQ( derivative.size(), global.state_dim ); |
| 145 | + EXPECT_DOUBLE_EQ( derivative( 0 ), state( 0 ) + control( 0 ) ); |
| 146 | + EXPECT_DOUBLE_EQ( derivative( 1 ), state( 1 ) + control( 0 ) ); |
| 147 | + EXPECT_DOUBLE_EQ( derivative( 2 ), state( 2 ) + 2.0 * control.tail( 2 ).sum() ); |
| 148 | + |
| 149 | + double expected_stage_cost = ( state.segment( 0, 2 ).sum() + control.segment( 0, 1 ).sum() ) |
| 150 | + + ( 2.0 * state.tail( 1 ).sum() |
| 151 | + + 3.0 * control.tail( 2 ).sum() ); |
| 152 | + EXPECT_DOUBLE_EQ( global.stage_cost( state, control, 0 ), expected_stage_cost ); |
| 153 | + EXPECT_DOUBLE_EQ( global.terminal_cost( state ), 2.0 * state.segment( 0, 2 ).sum() + state.tail( 1 ).sum() ); |
| 154 | +} |
| 155 | + |
| 156 | +TEST( FiniteDifferencesTest, GradientMatchesAnalyticalForQuadraticObjective ) |
| 157 | +{ |
| 158 | + MotionModel dynamics = []( const State&, const Control& u ) { return u; }; |
| 159 | + ObjectiveFunction objective = []( const StateTrajectory& states, const ControlTrajectory& controls ) { |
| 160 | + double state_sum = states.array().square().sum(); |
| 161 | + double control_sum = controls.array().square().sum(); |
| 162 | + return state_sum + control_sum; |
| 163 | + }; |
| 164 | + |
| 165 | + State initial_state = State::Zero( 1 ); |
| 166 | + ControlTrajectory controls( 1, 2 ); |
| 167 | + controls << 1.0, -1.0; |
| 168 | + |
| 169 | + ControlGradient gradient |
| 170 | + = finite_differences_gradient( initial_state, controls, dynamics, objective, 1.0 ); |
| 171 | + |
| 172 | + Control expected( 2 ); |
| 173 | + expected << 4.0, -2.0; |
| 174 | + |
| 175 | + EXPECT_NEAR( gradient( 0, 0 ), expected( 0 ), 1e-3 ); |
| 176 | + EXPECT_NEAR( gradient( 0, 1 ), expected( 1 ), 1e-3 ); |
| 177 | +} |
| 178 | + |
| 179 | +} // namespace mas |
0 commit comments