Skip to content

Commit b6ed82f

Browse files
daphne-cornelisse and Daphne authored
Allow human to drive with agents through classic and jerk dynamics model (#206)
* Fix human control with joint action space & classic model: Was still assuming multi-discrete. * Enable human control with jerks dynamics model. * Color actions yellow when controlling. * Slightly easier control problem? * Add tiny jerk penalty: Results in smooth behavior. * Pre-commit * Minor edits. * Revert ini changes. --------- Co-authored-by: Daphne <daphn3cor@gmail.com>
1 parent 9a58142 commit b6ed82f

4 files changed

Lines changed: 167 additions & 45 deletions

File tree

docs/interact-with-agents.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
## Drive with trained agents
2+
3+
You can take manual control of an agent in the simulator by holding **LEFT SHIFT** and using the keyboard controls. When you're in control, the action values displayed on screen will turn **yellow**.
4+
5+
### Local rendering
6+
7+
To launch an interactive renderer, first build:
8+
```bash
9+
bash scripts/build_ocean.sh drive local
10+
```
11+
12+
then launch:
13+
```bash
14+
./drive
15+
```
16+
17+
This will run `demo()` with an existing model checkpoint.
18+
19+
[TODO: Add demo video/gif here]
20+
21+
### Controls
22+
23+
**General:**
24+
- **LEFT SHIFT + Arrow Keys/WASD** - Take manual control
25+
- **TAB** - Switch between agents
26+
- **SPACE** - First-person camera view
27+
- **Mouse Drag** - Pan camera
28+
- **Mouse Wheel** - Zoom
29+
30+
**Classic dynamics model**
31+
32+
- **SHIFT + UP/W** - Increase acceleration
33+
- **SHIFT + DOWN/S** - Decrease acceleration (brake)
34+
- **SHIFT + LEFT/A** - Steer left
35+
- **SHIFT + RIGHT/D** - Steer right
36+
37+
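While LEFT SHIFT is held, the action is recomputed every frame starting from neutral (acceleration index 3, steering index 6), and each held key offsets it by a fixed step; key presses do not accumulate. For example, holding W selects acceleration index 5 (neutral index 3 plus a step of 2; the maximum index is 6), holding S selects index 1, and holding W and S together cancels back to neutral. Steering works the same way with a step of 4 (A selects index 10, D selects index 2). We assume **no friction**, so holding LEFT SHIFT with no direction key applies zero acceleration and zero steering, which maintains constant speed and heading.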
38+
39+
**Jerk dynamics model**
40+
41+
- **SHIFT + UP/W** - Accelerate (+4.0 m/s³ jerk)
42+
- **SHIFT + DOWN/S** - Brake (-15.0 m/s³ jerk)
43+
- **SHIFT + LEFT/A** - Turn left (+4.0 m/s³ lateral jerk)
44+
- **SHIFT + RIGHT/D** - Turn right (-4.0 m/s³ lateral jerk)
45+
46+
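Actions are applied directly while keys are held. Holding W always applies +4.0 m/s³ longitudinal jerk, regardless of how long the key is held. If opposing keys are held together, the brake (S) overrides acceleration (W) and right (D) overrides left (A).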

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ nav:
2828
- Docs:
2929
- Getting started: getting-started.md
3030
- Training agents: train.md
31+
- Interact with agents: interact-with-agents.md
3132
- Simulator: simulator.md
3233
- Interactive scenario editor: scene-editor.md
3334
- Visualizer: visualizer.md

pufferlib/ocean/drive/drive.c

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -71,44 +71,65 @@ void demo() {
7171
int accel_delta = 2;
7272
int steer_delta = 4;
7373
while (!WindowShouldClose()) {
74-
// Handle camera controls
75-
int (*actions)[2] = (int (*)[2])env.actions;
74+
int *actions = (int *)env.actions; // Single integer per agent
75+
7676
forward(net, env.observations, env.actions);
77+
7778
if (IsKeyDown(KEY_LEFT_SHIFT)) {
78-
actions[env.human_agent_idx][0] = 3;
79-
actions[env.human_agent_idx][1] = 6;
80-
if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) {
81-
actions[env.human_agent_idx][0] += accel_delta;
82-
// Cap acceleration to maximum of 6
83-
if (actions[env.human_agent_idx][0] > 6) {
84-
actions[env.human_agent_idx][0] = 6;
79+
if (env.dynamics_model == CLASSIC) {
80+
// Classic dynamics: acceleration and steering
81+
int accel_idx = 3; // neutral (0 m/s²)
82+
int steer_idx = 6; // neutral (0.0 steering)
83+
84+
if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) {
85+
accel_idx += accel_delta;
86+
if (accel_idx > 6)
87+
accel_idx = 6;
8588
}
86-
}
87-
if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) {
88-
actions[env.human_agent_idx][0] -= accel_delta;
89-
// Cap acceleration to minimum of 0
90-
if (actions[env.human_agent_idx][0] < 0) {
91-
actions[env.human_agent_idx][0] = 0;
89+
if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) {
90+
accel_idx -= accel_delta;
91+
if (accel_idx < 0)
92+
accel_idx = 0;
9293
}
93-
}
94-
if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) {
95-
actions[env.human_agent_idx][1] += steer_delta;
96-
// Cap steering to minimum of 0
97-
if (actions[env.human_agent_idx][1] < 0) {
98-
actions[env.human_agent_idx][1] = 0;
94+
if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) {
95+
steer_idx += steer_delta; // Increase steering index for left turn
96+
if (steer_idx > 12)
97+
steer_idx = 12;
9998
}
100-
}
101-
if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) {
102-
actions[env.human_agent_idx][1] -= steer_delta;
103-
// Cap steering to maximum of 12
104-
if (actions[env.human_agent_idx][1] > 12) {
105-
actions[env.human_agent_idx][1] = 12;
99+
if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) {
100+
steer_idx -= steer_delta; // Decrease steering index for right turn
101+
if (steer_idx < 0)
102+
steer_idx = 0;
106103
}
107-
}
108-
if (IsKeyPressed(KEY_TAB)) {
109-
env.human_agent_idx = (env.human_agent_idx + 1) % env.active_agent_count;
104+
105+
// Encode into single integer: action = accel_idx * 13 + steer_idx
106+
actions[env.human_agent_idx] = accel_idx * 13 + steer_idx;
107+
108+
} else if (env.dynamics_model == JERK) {
109+
// Jerk dynamics: longitudinal and lateral jerk
110+
// JERK_LONG[4] = {-15.0f, -4.0f, 0.0f, 4.0f}
111+
// JERK_LAT[3] = {-4.0f, 0.0f, 4.0f}
112+
int jerk_long_idx = 2; // neutral (0.0)
113+
int jerk_lat_idx = 1; // neutral (0.0)
114+
115+
if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) {
116+
jerk_long_idx = 3; // acceleration (4.0)
117+
}
118+
if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) {
119+
jerk_long_idx = 0; // hard braking (-15.0)
120+
}
121+
if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) {
122+
jerk_lat_idx = 2; // left turn (4.0)
123+
}
124+
if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) {
125+
jerk_lat_idx = 0; // right turn (-4.0)
126+
}
127+
128+
// Encode into single integer: action = jerk_long_idx * 3 + jerk_lat_idx
129+
actions[env.human_agent_idx] = jerk_long_idx * 3 + jerk_lat_idx;
110130
}
111131
}
132+
112133
c_step(&env);
113134
c_render(&env);
114135
}

pufferlib/ocean/drive/drive.h

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2020,8 +2020,22 @@ void c_step(Drive *env) {
20202020
env->logs[i].episode_length += 1;
20212021
int agent_idx = env->active_agent_indices[i];
20222022
env->entities[agent_idx].collision_state = 0;
2023+
float prev_vx = env->entities[agent_idx].vx;
2024+
float prev_vy = env->entities[agent_idx].vy;
2025+
20232026
move_dynamics(env, i, agent_idx);
2027+
2028+
// Tiny jerk penalty for smoothness
2029+
if (env->dynamics_model == CLASSIC) {
2030+
float delta_vx = env->entities[agent_idx].vx - prev_vx;
2031+
float delta_vy = env->entities[agent_idx].vy - prev_vy;
2032+
float jerk_penalty = -0.0002f * sqrtf(delta_vx * delta_vx + delta_vy * delta_vy) / env->dt;
2033+
env->rewards[i] += jerk_penalty;
2034+
env->logs[i].episode_return += jerk_penalty;
2035+
}
20242036
}
2037+
2038+
// Compute rewards
20252039
for (int i = 0; i < env->active_agent_count; i++) {
20262040
int agent_idx = env->active_agent_indices[i];
20272041
env->entities[agent_idx].collision_state = 0;
@@ -2849,6 +2863,7 @@ void c_render(Drive *env) {
28492863
BeginMode3D(client->camera);
28502864
handle_camera_controls(env->client);
28512865
draw_scene(env, client, 0, 0, 0, 0);
2866+
28522867
// Draw debug info
28532868
DrawText(TextFormat("Camera Position: (%.2f, %.2f, %.2f)", client->camera.position.x, client->camera.position.y,
28542869
client->camera.position.z),
@@ -2857,25 +2872,64 @@ void c_render(Drive *env) {
28572872
client->camera.target.z),
28582873
10, 30, 20, PUFF_WHITE);
28592874
DrawText(TextFormat("Timestep: %d", env->timestep), 10, 50, 20, PUFF_WHITE);
2860-
// acceleration & steering
2875+
28612876
int human_idx = env->active_agent_indices[env->human_agent_idx];
28622877
DrawText(TextFormat("Controlling Agent: %d", env->human_agent_idx), 10, 70, 20, PUFF_WHITE);
28632878
DrawText(TextFormat("Agent Index: %d", human_idx), 10, 90, 20, PUFF_WHITE);
2864-
// Controls help
2865-
DrawText("Controls: W/S - Accelerate/Brake, A/D - Steer, 1-4 - Switch Agent", 10, client->height - 30, 20,
2866-
PUFF_WHITE);
2867-
// acceleration & steering
2868-
if (env->action_type == 1) { // continuous (float)
2879+
2880+
// Display current action values - yellow when controlling, white otherwise
2881+
Color action_color = IsKeyDown(KEY_LEFT_SHIFT) ? YELLOW : PUFF_WHITE;
2882+
2883+
if (env->action_type == 0) { // discrete
2884+
int *action_array = (int *)env->actions;
2885+
int action_val = action_array[env->human_agent_idx];
2886+
2887+
if (env->dynamics_model == CLASSIC) {
2888+
int num_steer = 13;
2889+
int accel_idx = action_val / num_steer;
2890+
int steer_idx = action_val % num_steer;
2891+
float accel_value = ACCELERATION_VALUES[accel_idx];
2892+
float steer_value = STEERING_VALUES[steer_idx];
2893+
2894+
DrawText(TextFormat("Acceleration: %.2f m/s^2", accel_value), 10, 110, 20, action_color);
2895+
DrawText(TextFormat("Steering: %.3f", steer_value), 10, 130, 20, action_color);
2896+
} else if (env->dynamics_model == JERK) {
2897+
int num_lat = 3;
2898+
int jerk_long_idx = action_val / num_lat;
2899+
int jerk_lat_idx = action_val % num_lat;
2900+
float jerk_long_value = JERK_LONG[jerk_long_idx];
2901+
float jerk_lat_value = JERK_LAT[jerk_lat_idx];
2902+
2903+
DrawText(TextFormat("Longitudinal Jerk: %.2f m/s^3", jerk_long_value), 10, 110, 20, action_color);
2904+
DrawText(TextFormat("Lateral Jerk: %.2f m/s^3", jerk_lat_value), 10, 130, 20, action_color);
2905+
}
2906+
} else { // continuous
28692907
float (*action_array_f)[2] = (float (*)[2])env->actions;
2870-
DrawText(TextFormat("Acceleration: %.2f", action_array_f[env->human_agent_idx][0]), 10, 110, 20, PUFF_WHITE);
2871-
DrawText(TextFormat("Steering: %.2f", action_array_f[env->human_agent_idx][1]), 10, 130, 20, PUFF_WHITE);
2872-
} else { // discrete (int)
2873-
int (*action_array)[2] = (int (*)[2])env->actions;
2874-
DrawText(TextFormat("Acceleration: %d", action_array[env->human_agent_idx][0]), 10, 110, 20, PUFF_WHITE);
2875-
DrawText(TextFormat("Steering: %d", action_array[env->human_agent_idx][1]), 10, 130, 20, PUFF_WHITE);
2876-
}
2877-
DrawText(TextFormat("Grid Rows: %d", env->grid_map->grid_rows), 10, 150, 20, PUFF_WHITE);
2878-
DrawText(TextFormat("Grid Cols: %d", env->grid_map->grid_cols), 10, 170, 20, PUFF_WHITE);
2908+
DrawText(TextFormat("Acceleration: %.2f", action_array_f[env->human_agent_idx][0]), 10, 110, 20, action_color);
2909+
DrawText(TextFormat("Steering: %.2f", action_array_f[env->human_agent_idx][1]), 10, 130, 20, action_color);
2910+
}
2911+
2912+
// Show key press status
2913+
int status_y = 150;
2914+
if (IsKeyDown(KEY_LEFT_SHIFT)) {
2915+
DrawText("[shift pressed]", 10, status_y, 20, YELLOW);
2916+
status_y += 20;
2917+
}
2918+
if (IsKeyDown(KEY_SPACE)) {
2919+
DrawText("[space pressed]", 10, status_y, 20, YELLOW);
2920+
status_y += 20;
2921+
}
2922+
if (IsKeyDown(KEY_LEFT_CONTROL)) {
2923+
DrawText("[ctrl pressed]", 10, status_y, 20, YELLOW);
2924+
status_y += 20;
2925+
}
2926+
2927+
// Controls help
2928+
DrawText("Controls: SHIFT + W/S - Accelerate/Brake, SHIFT + A/D - Steer, TAB - Switch Agent", 10,
2929+
client->height - 30, 20, PUFF_WHITE);
2930+
2931+
DrawText(TextFormat("Grid Rows: %d", env->grid_map->grid_rows), 10, status_y, 20, PUFF_WHITE);
2932+
DrawText(TextFormat("Grid Cols: %d", env->grid_map->grid_cols), 10, status_y + 20, 20, PUFF_WHITE);
28792933
EndDrawing();
28802934
}
28812935

0 commit comments

Comments
 (0)