Skip to content

Commit a9b315f

Browse files
authored
Add gtest coverage for core problem assembly (#39)
* Add gtest coverage for core problem assembly * Run tests in CI workflow
1 parent cfd8aa3 commit a9b315f

3 files changed

Lines changed: 204 additions & 4 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,7 @@ jobs:
1414
run: sudo ./scripts/setup_dependencies.sh
1515
- name: Build project
1616
run: ./scripts/build.sh
17+
- name: Run tests
18+
run: cd build/release && ctest --output-on-failure
1719
- name: Run examples
1820
run: ./scripts/run.sh

CMakeLists.txt

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
66

77
option(BUILD_EXAMPLES "Build example executables" ON)
88
option(BUILD_SHARED_LIBS "Build as shared library" OFF)
9+
include(CTest)
910

1011
# Dependencies
1112
find_package(Eigen3 3.4 REQUIRED NO_MODULE)
@@ -80,8 +81,26 @@ configure_package_config_file(
8081
INSTALL_DESTINATION lib/cmake/MultiAgentSolver
8182
)
8283

83-
install(FILES
84-
"${CMAKE_CURRENT_BINARY_DIR}/MultiAgentSolverConfig.cmake"
85-
"${CMAKE_CURRENT_BINARY_DIR}/MultiAgentSolverConfigVersion.cmake"
86-
DESTINATION lib/cmake/MultiAgentSolver
84+
install(FILES
85+
"${CMAKE_CURRENT_BINARY_DIR}/MultiAgentSolverConfig.cmake"
86+
"${CMAKE_CURRENT_BINARY_DIR}/MultiAgentSolverConfigVersion.cmake"
87+
DESTINATION lib/cmake/MultiAgentSolver
8788
)
89+
90+
if(BUILD_TESTING)
  # Fetch GoogleTest at configure time, pinned to a release tarball for
  # reproducible builds.
  include(FetchContent)
  FetchContent_Declare(
    googletest
    URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.tar.gz
  )
  # Keep the MSVC runtime selection consistent between gtest and the project.
  set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
  FetchContent_MakeAvailable(googletest)

  # NOTE: the top-level include(CTest) already calls enable_testing() when
  # BUILD_TESTING is ON, so no explicit enable_testing() is needed here.

  # CONFIGURE_DEPENDS re-runs the glob at build time so newly added test
  # sources are picked up without touching this file.
  file(GLOB TEST_SOURCES CONFIGURE_DEPENDS tests/*.cpp)
  add_executable(multi_agent_solver_tests ${TEST_SOURCES})
  target_link_libraries(multi_agent_solver_tests PRIVATE MultiAgentSolver GTest::gtest_main)

  # Register each gtest case with CTest individually.
  include(GoogleTest)
  gtest_discover_tests(multi_agent_solver_tests)
endif()

tests/ocp_tests.cpp

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
#include <gtest/gtest.h>
2+
3+
#include "multi_agent_solver/agent.hpp"
4+
#include "multi_agent_solver/finite_differences.hpp"
5+
#include "multi_agent_solver/multi_agent_problem.hpp"
6+
#include "multi_agent_solver/ocp.hpp"
7+
8+
namespace mas
9+
{
10+
namespace
11+
{
12+
13+
MotionModel
14+
create_integrator()
15+
{
16+
return []( const State& state, const Control& control ) { return control + state * 0.0; };
17+
}
18+
19+
} // namespace
20+
21+
TEST( OCPTest, InitializeProblemSetsDefaultsAndBestCost )
{
  // Assemble a minimal one-dimensional integrator problem.
  OCP problem;
  problem.state_dim     = 1;
  problem.control_dim   = 1;
  problem.horizon_steps = 3;
  problem.dt            = 0.1;
  problem.initial_state = State::Zero( problem.state_dim );
  problem.dynamics      = create_integrator();
  problem.stage_cost    = []( const State& x, const Control& u, size_t ) {
    return x.squaredNorm() + u.squaredNorm();
  };
  problem.terminal_cost = []( const State& x ) { return x.squaredNorm(); };

  problem.initialize_problem();

  // Trajectories are allocated to match the horizon; the incumbent cost is zero.
  EXPECT_EQ( problem.best_states.rows(), problem.state_dim );
  EXPECT_EQ( problem.best_states.cols(), problem.horizon_steps + 1 );
  EXPECT_EQ( problem.best_controls.rows(), problem.control_dim );
  EXPECT_EQ( problem.best_controls.cols(), problem.horizon_steps );
  EXPECT_DOUBLE_EQ( problem.best_cost, 0.0 );

  // Default cost-gradient callbacks must have been installed.
  ASSERT_TRUE( static_cast<bool>( problem.cost_state_gradient ) );
  ASSERT_TRUE( static_cast<bool>( problem.cost_control_gradient ) );

  const auto state_grad   = problem.cost_state_gradient( problem.stage_cost, problem.best_states.col( 0 ),
                                                         problem.best_controls.col( 0 ), 0 );
  const auto control_grad = problem.cost_control_gradient( problem.stage_cost, problem.best_states.col( 0 ),
                                                           problem.best_controls.col( 0 ), 0 );

  EXPECT_EQ( state_grad.size(), problem.state_dim );
  EXPECT_EQ( control_grad.size(), problem.control_dim );
  EXPECT_TRUE( problem.verify_problem() );
}
55+
56+
TEST( OCPTest, UpdateInitialWithBestCopiesTrajectories )
{
  OCP problem;
  problem.state_dim     = 2;
  problem.control_dim   = 2;
  problem.horizon_steps = 2;
  problem.dt            = 1.0;
  problem.initial_state = State::Zero( problem.state_dim );
  problem.dynamics      = create_integrator();
  problem.initialize_problem();

  // Overwrite the incumbent solution, then promote it to the warm start.
  problem.best_controls = ControlTrajectory::Ones( problem.control_dim, problem.horizon_steps );
  problem.best_states   = StateTrajectory::Ones( problem.state_dim, problem.horizon_steps + 1 );

  problem.update_initial_with_best();

  EXPECT_TRUE( problem.initial_controls.isApprox( problem.best_controls ) );
  EXPECT_TRUE( problem.initial_states.isApprox( problem.best_states ) );
}
75+
76+
TEST( MultiAgentProblemTest, BuildGlobalProblemMergesAgents )
{
  // Agent 1: two states driven by a single shared control input.
  auto first_ocp = std::make_shared<OCP>();
  first_ocp->state_dim     = 2;
  first_ocp->control_dim   = 1;
  first_ocp->horizon_steps = 2;
  first_ocp->dt            = 0.5;
  first_ocp->initial_state = State::Ones( first_ocp->state_dim );
  first_ocp->dynamics      = []( const State& x, const Control& u ) {
    return x + u.replicate( x.size(), 1 );
  };
  first_ocp->stage_cost    = []( const State& x, const Control& u, size_t ) { return x.sum() + u.sum(); };
  first_ocp->terminal_cost = []( const State& x ) { return 2.0 * x.sum(); };
  first_ocp->input_lower_bounds = Control::Constant( first_ocp->control_dim, -1.0 );
  first_ocp->input_upper_bounds = Control::Constant( first_ocp->control_dim, 1.0 );
  first_ocp->initialize_problem();

  // Agent 2: one state driven by the sum of two control inputs.
  auto second_ocp = std::make_shared<OCP>();
  second_ocp->state_dim     = 1;
  second_ocp->control_dim   = 2;
  second_ocp->horizon_steps = 2;
  second_ocp->dt            = 0.5;
  second_ocp->initial_state = State::Constant( second_ocp->state_dim, 3.0 );
  second_ocp->dynamics      = []( const State& x, const Control& u ) {
    return x + Control::Constant( x.size(), 2.0 * u.sum() );
  };
  second_ocp->stage_cost = []( const State& x, const Control& u, size_t ) {
    return 2.0 * x.sum() + 3.0 * u.sum();
  };
  second_ocp->terminal_cost = []( const State& x ) { return x.sum(); };
  second_ocp->input_lower_bounds = Control::Constant( second_ocp->control_dim, -2.0 );
  second_ocp->input_upper_bounds = Control::Constant( second_ocp->control_dim, 2.0 );
  second_ocp->initialize_problem();

  // Agents are registered out of order; compute_offsets() is expected to
  // order the blocks by agent id.
  MultiAgentProblem problem;
  problem.add_agent( std::make_shared<Agent>( 2, second_ocp ) );
  problem.add_agent( std::make_shared<Agent>( 1, first_ocp ) );
  problem.compute_offsets();

  ASSERT_EQ( problem.blocks.size(), 2 );
  EXPECT_EQ( problem.blocks.front().agent_id, 1 );
  EXPECT_EQ( problem.blocks.back().agent_id, 2 );
  EXPECT_EQ( problem.blocks.front().state_offset, 0 );
  EXPECT_EQ( problem.blocks.front().control_offset, 0 );
  EXPECT_EQ( problem.blocks.back().state_offset, first_ocp->state_dim );
  EXPECT_EQ( problem.blocks.back().control_offset, first_ocp->control_dim );

  OCP global = problem.build_global_ocp();

  // Dimensions are concatenated; horizon and dt are shared across agents.
  EXPECT_EQ( global.state_dim, first_ocp->state_dim + second_ocp->state_dim );
  EXPECT_EQ( global.control_dim, first_ocp->control_dim + second_ocp->control_dim );
  EXPECT_EQ( global.horizon_steps, first_ocp->horizon_steps );
  EXPECT_DOUBLE_EQ( global.dt, first_ocp->dt );

  // Input bounds are stacked agent-by-agent in id order.
  ASSERT_TRUE( global.input_lower_bounds.has_value() );
  ASSERT_TRUE( global.input_upper_bounds.has_value() );
  EXPECT_DOUBLE_EQ( ( *global.input_lower_bounds )( 0 ), -1.0 );
  EXPECT_DOUBLE_EQ( ( *global.input_lower_bounds )( 1 ), -2.0 );
  EXPECT_DOUBLE_EQ( ( *global.input_lower_bounds )( 2 ), -2.0 );

  State expected_initial( global.state_dim );
  expected_initial << 1.0, 1.0, 3.0;
  EXPECT_TRUE( global.initial_state.isApprox( expected_initial ) );

  // The merged dynamics must route each agent's slice of state and control.
  State state                = State::LinSpaced( global.state_dim, 1.0, 3.0 );
  Control control            = Control::LinSpaced( global.control_dim, -1.0, 1.0 );
  StateDerivative derivative = global.dynamics( state, control );

  EXPECT_EQ( derivative.size(), global.state_dim );
  EXPECT_DOUBLE_EQ( derivative( 0 ), state( 0 ) + control( 0 ) );
  EXPECT_DOUBLE_EQ( derivative( 1 ), state( 1 ) + control( 0 ) );
  EXPECT_DOUBLE_EQ( derivative( 2 ), state( 2 ) + 2.0 * control.tail( 2 ).sum() );

  // Costs are the sum of per-agent contributions over their own slices.
  double expected_stage_cost = ( state.segment( 0, 2 ).sum() + control.segment( 0, 1 ).sum() )
                               + ( 2.0 * state.tail( 1 ).sum()
                                   + 3.0 * control.tail( 2 ).sum() );
  EXPECT_DOUBLE_EQ( global.stage_cost( state, control, 0 ), expected_stage_cost );
  EXPECT_DOUBLE_EQ( global.terminal_cost( state ), 2.0 * state.segment( 0, 2 ).sum() + state.tail( 1 ).sum() );
}
155+
156+
TEST( FiniteDifferencesTest, GradientMatchesAnalyticalForQuadraticObjective )
{
  // Single integrator: the state derivative equals the applied control.
  MotionModel dynamics = []( const State&, const Control& u ) { return u; };

  // Quadratic objective: squared states plus squared controls, summed.
  ObjectiveFunction objective = []( const StateTrajectory& states, const ControlTrajectory& controls ) {
    double state_sum   = states.array().square().sum();
    double control_sum = controls.array().square().sum();
    return state_sum + control_sum;
  };

  State initial_state = State::Zero( 1 );
  ControlTrajectory controls( 1, 2 );
  controls << 1.0, -1.0;

  ControlGradient computed
      = finite_differences_gradient( initial_state, controls, dynamics, objective, 1.0 );

  // Analytical gradient of the quadratic objective at these controls.
  Control analytical( 2 );
  analytical << 4.0, -2.0;

  EXPECT_NEAR( computed( 0, 0 ), analytical( 0 ), 1e-3 );
  EXPECT_NEAR( computed( 0, 1 ), analytical( 1 ), 1e-3 );
}
178+
179+
} // namespace mas

0 commit comments

Comments
 (0)