|
1 | 1 | #include "drivenet.h" |
| 2 | +#include "error.h" |
| 3 | +#include "libgen.h" |
| 4 | +#include "../env_config.h" |
2 | 5 | #include <string.h> |
3 | 6 |
|
4 | 7 | // Use this test if the network changes to ensure that the forward pass |
@@ -31,37 +34,58 @@ void test_drivenet() { |
31 | 34 | free(weights); |
32 | 35 | } |
33 | 36 |
|
34 | | -void demo() { |
| 37 | +int demo(const char *map_name, const char *policy_name, int show_grid, int obs_only, int lasers, int show_human_logs, |
| 38 | + int frame_skip, const char *view_mode, const char *output_topdown, const char *output_agent, int num_maps, |
| 39 | + int zoom_in) { |
35 | 40 |
|
36 | | - // Note: The settings below are hardcoded for demo purposes. Since the policy was |
37 | | - // trained with these exact settings, that changing them may lead to |
38 | | - // weird behavior. |
| 41 | + // Parse configuration from INI file |
| 42 | + env_init_config conf = {0}; |
| 43 | + const char *ini_file = "pufferlib/config/ocean/drive.ini"; |
| 44 | + if (ini_parse(ini_file, handler, &conf) < 0) { |
| 45 | + fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file); |
| 46 | + return -1; |
| 47 | + } |
| 48 | + |
| 49 | + char map_buffer[100]; |
| 50 | + if (map_name == NULL) { |
| 51 | + srand(time(NULL)); |
| 52 | + int random_map = rand() % num_maps; |
| 53 | + sprintf(map_buffer, "%s/map_%03d.bin", conf.map_dir, random_map); |
| 54 | + map_name = map_buffer; |
| 55 | + } |
| 56 | + |
| 57 | + // Initialize environment with all config values from INI [env] section |
39 | 58 | Drive env = { |
40 | | - .human_agent_idx = 0, |
41 | | - .action_type = 0, // Discrete |
42 | | - .dynamics_model = CLASSIC, // Classic dynamics |
43 | | - .reward_vehicle_collision = -1.0f, |
44 | | - .reward_offroad_collision = -1.0f, |
45 | | - .reward_goal = 1.0f, |
46 | | - .reward_goal_post_respawn = 0.25f, |
47 | | - .goal_radius = 2.0f, |
48 | | - .goal_behavior = 1, |
49 | | - .goal_target_distance = 30.0f, |
50 | | - .goal_speed = 10.0f, |
51 | | - .dt = 0.1f, |
52 | | - .episode_length = 300, |
53 | | - .termination_mode = 0, |
54 | | - .collision_behavior = 0, |
55 | | - .offroad_behavior = 0, |
56 | | - .init_steps = 0, |
57 | | - .init_mode = 0, |
58 | | - .control_mode = 0, |
59 | | - .map_name = "resources/drive/map_town_02_carla.bin", |
| 59 | + .action_type = conf.action_type, |
| 60 | + .dynamics_model = conf.dynamics_model, |
| 61 | + .reward_vehicle_collision = conf.reward_vehicle_collision, |
| 62 | + .reward_offroad_collision = conf.reward_offroad_collision, |
| 63 | + .reward_goal = conf.reward_goal, |
| 64 | + .reward_goal_post_respawn = conf.reward_goal_post_respawn, |
| 65 | + .goal_radius = conf.goal_radius, |
| 66 | + .goal_behavior = conf.goal_behavior, |
| 67 | + .goal_target_distance = conf.goal_target_distance, |
| 68 | + .goal_speed = conf.goal_speed, |
| 69 | + .dt = conf.dt, |
| 70 | + .episode_length = conf.episode_length, |
| 71 | + .termination_mode = conf.termination_mode, |
| 72 | + .collision_behavior = conf.collision_behavior, |
| 73 | + .offroad_behavior = conf.offroad_behavior, |
| 74 | + .init_steps = conf.init_steps, |
| 75 | + .init_mode = conf.init_mode, |
| 76 | + .control_mode = conf.control_mode, |
| 77 | + .map_name = (char *)map_name, |
60 | 78 | }; |
61 | 79 | allocate(&env); |
| 80 | + if (env.active_agent_count == 0) { |
| 81 | + fprintf(stderr, "Error: No active agents found in map '%s' with init_mode=%d. Cannot run demo.\n", env.map_name, |
| 82 | + conf.init_mode); |
| 83 | + free_allocated(&env); |
| 84 | + return -1; |
| 85 | + } |
62 | 86 | c_reset(&env); |
63 | 87 | c_render(&env); |
64 | | - Weights *weights = load_weights("resources/drive/puffer_drive_weights_carla_town12.bin"); |
| 88 | + Weights *weights = load_weights((char *)policy_name); |
65 | 89 | DriveNet *net = init_drivenet(weights, env.active_agent_count, env.dynamics_model); |
66 | 90 |
|
67 | 91 | int accel_delta = 2; |
@@ -134,6 +158,7 @@ void demo() { |
134 | 158 | free_allocated(&env); |
135 | 159 | free_drivenet(net); |
136 | 160 | free(weights); |
| 161 | + return 0; |
137 | 162 | } |
138 | 163 |
|
139 | 164 | void performance_test() { |
@@ -177,9 +202,93 @@ void performance_test() { |
177 | 202 | free_allocated(&env); |
178 | 203 | } |
179 | 204 |
|
180 | | -int main() { |
| 205 | +int main(int argc, char *argv[]) { |
| 206 | + // Visualization-only parameters (not in [env] section) |
| 207 | + int show_grid = 0; |
| 208 | + int obs_only = 0; |
| 209 | + int lasers = 0; |
| 210 | + int show_human_logs = 0; |
| 211 | + int frame_skip = 1; |
| 212 | + int zoom_in = 0; |
| 213 | + const char *view_mode = "both"; |
| 214 | + |
| 215 | + // File paths and num_maps (not in [env] section) |
| 216 | + const char *map_name = NULL; |
| 217 | + const char *policy_name = "resources/drive/puffer_drive_weights.bin"; |
| 218 | + const char *output_topdown = NULL; |
| 219 | + const char *output_agent = NULL; |
| 220 | + int num_maps = 1; |
| 221 | + |
| 222 | + // Parse command line arguments |
| 223 | + for (int i = 1; i < argc; i++) { |
| 224 | + if (strcmp(argv[i], "--show-grid") == 0) { |
| 225 | + show_grid = 1; |
| 226 | + } else if (strcmp(argv[i], "--obs-only") == 0) { |
| 227 | + obs_only = 1; |
| 228 | + } else if (strcmp(argv[i], "--lasers") == 0) { |
| 229 | + lasers = 1; |
| 230 | + } else if (strcmp(argv[i], "--log-trajectories") == 0) { |
| 231 | + show_human_logs = 1; |
| 232 | + } else if (strcmp(argv[i], "--frame-skip") == 0) { |
| 233 | + if (i + 1 < argc) { |
| 234 | + frame_skip = atoi(argv[i + 1]); |
| 235 | + i++; |
| 236 | + if (frame_skip <= 0) { |
| 237 | + frame_skip = 1; |
| 238 | + } |
| 239 | + } |
| 240 | + } else if (strcmp(argv[i], "--zoom-in") == 0) { |
| 241 | + zoom_in = 1; |
| 242 | + } else if (strcmp(argv[i], "--view") == 0) { |
| 243 | + if (i + 1 < argc) { |
| 244 | + view_mode = argv[i + 1]; |
| 245 | + i++; |
| 246 | + if (strcmp(view_mode, "both") != 0 && strcmp(view_mode, "topdown") != 0 && |
| 247 | + strcmp(view_mode, "agent") != 0) { |
| 248 | + fprintf(stderr, "Error: --view must be 'both', 'topdown', or 'agent'\n"); |
| 249 | + return 1; |
| 250 | + } |
| 251 | + } else { |
| 252 | + fprintf(stderr, "Error: --view option requires a value (both/topdown/agent)\n"); |
| 253 | + return 1; |
| 254 | + } |
| 255 | + } else if (strcmp(argv[i], "--map-name") == 0) { |
| 256 | + if (i + 1 < argc) { |
| 257 | + map_name = argv[i + 1]; |
| 258 | + i++; |
| 259 | + } else { |
| 260 | + fprintf(stderr, "Error: --map-name option requires a map file path\n"); |
| 261 | + return 1; |
| 262 | + } |
| 263 | + } else if (strcmp(argv[i], "--policy-name") == 0) { |
| 264 | + if (i + 1 < argc) { |
| 265 | + policy_name = argv[i + 1]; |
| 266 | + i++; |
| 267 | + } else { |
| 268 | + fprintf(stderr, "Error: --policy-name option requires a policy file path\n"); |
| 269 | + return 1; |
| 270 | + } |
| 271 | + } else if (strcmp(argv[i], "--output-topdown") == 0) { |
| 272 | + if (i + 1 < argc) { |
| 273 | + output_topdown = argv[i + 1]; |
| 274 | + i++; |
| 275 | + } |
| 276 | + } else if (strcmp(argv[i], "--output-agent") == 0) { |
| 277 | + if (i + 1 < argc) { |
| 278 | + output_agent = argv[i + 1]; |
| 279 | + i++; |
| 280 | + } |
| 281 | + } else if (strcmp(argv[i], "--num-maps") == 0) { |
| 282 | + if (i + 1 < argc) { |
| 283 | + num_maps = atoi(argv[i + 1]); |
| 284 | + i++; |
| 285 | + } |
| 286 | + } |
| 287 | + } |
| 288 | + |
181 | 289 | // performance_test(); |
182 | | - demo(); |
| 290 | + demo(map_name, policy_name, show_grid, obs_only, lasers, show_human_logs, frame_skip, view_mode, output_topdown, |
| 291 | + output_agent, num_maps, zoom_in); |
183 | 292 | // test_drivenet(); |
184 | 293 | return 0; |
185 | 294 | } |
0 commit comments