Skip to content

Commit 4718fd4

Browse files
committed
Simplify Four Rooms native integration
1 parent de05e0b commit 4718fd4

3 files changed

Lines changed: 25 additions & 69 deletions

File tree

ocean/four_rooms/binding.c

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,16 @@
55
#define ACT_SIZES {FOUR_ROOMS_NUM_ACTIONS}
66
#define OBS_TENSOR_T ByteTensor
77

8-
#define MY_VEC_STEP four_rooms_vec_step
9-
#define MY_VEC_STEP_RANGE four_rooms_vec_step_range
108
#define Env FourRooms
119
#include "vecenv.h"
1210

13-
void four_rooms_vec_step(StaticVec* vec) {
14-
memset(vec->rewards, 0, vec->total_agents * sizeof(float));
15-
memset(vec->terminals, 0, vec->total_agents * sizeof(float));
16-
FourRooms* envs = (FourRooms*)vec->envs;
17-
for (int i = 0; i < vec->size; i++) {
18-
c_step(&envs[i]);
19-
}
20-
}
21-
22-
void four_rooms_vec_step_range(StaticVec* vec, int env_start, int env_count, int num_workers) {
23-
(void)num_workers;
24-
FourRooms* envs = (FourRooms*)vec->envs;
25-
for (int i = env_start; i < env_start + env_count; i++) {
26-
c_step(&envs[i]);
27-
}
28-
}
29-
3011
void my_init(Env* env, Dict* kwargs) {
3112
env->num_agents = 1;
3213
env->size = (int)dict_get(kwargs, "size")->value;
3314
env->max_steps = (int)dict_get(kwargs, "max_steps")->value;
3415
if (env->max_steps <= 0) {
35-
env->max_steps = 4 * env->size;
16+
env->max_steps = FOUR_ROOMS_TIMEOUT_SCALE * env->size;
3617
}
37-
env->see_through_walls = 0;
3818
env->grid = (unsigned char*)calloc(env->size * env->size, sizeof(unsigned char));
3919
}
4020

ocean/four_rooms/four_rooms.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ int main() {
1919
c_render(&env);
2020
while (!WindowShouldClose()) {
2121
if (IsKeyDown(KEY_LEFT_SHIFT)) {
22-
env.actions[0] = 7; // Invalid action = no-op
22+
env.actions[0] = DONE;
2323
if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) env.actions[0] = FORWARD;
2424
if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) env.actions[0] = LEFT;
2525
if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) env.actions[0] = RIGHT;
2626
} else {
27-
env.actions[0] = four_rooms_rand(&env, 3); // Only use left, right, forward
27+
env.actions[0] = four_rooms_rand(&env, 3);
2828
}
2929
c_step(&env);
3030
c_render(&env);

ocean/four_rooms/four_rooms.h

Lines changed: 22 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#define FOUR_ROOMS_VIEW_SIZE 7
66
#define FOUR_ROOMS_OBS_CHANNELS 3
77
#define FOUR_ROOMS_NUM_ACTIONS 7
8+
#define FOUR_ROOMS_TIMEOUT_SCALE 4
89

9-
// Action space
1010
enum {
1111
LEFT = 0,
1212
RIGHT = 1,
@@ -17,7 +17,6 @@ enum {
1717
DONE = 6,
1818
};
1919

20-
// Observation: Objects
2120
enum {
2221
UNSEEN = 0,
2322
EMPTY = 1,
@@ -26,14 +25,12 @@ enum {
2625
AGENT = 10,
2726
};
2827

29-
// Observation: Colors
3028
enum {
3129
COLOR_BLACK = 0,
3230
COLOR_GREEN = 1,
3331
COLOR_GREY = 5,
3432
};
3533

36-
// PufferLib standard colors for rendering
3734
static const Color PUFF_RED = (Color){187, 0, 0, 255};
3835
static const Color PUFF_BACKGROUND = (Color){6, 24, 24, 255};
3936
static const Color PUFF_BACKGROUND2 = (Color){18, 72, 72, 255};
@@ -61,7 +58,6 @@ typedef struct {
6158
int agent_dir;
6259
int goal_x, goal_y;
6360
unsigned char* grid;
64-
int see_through_walls;
6561
unsigned int rng;
6662
int texture_loaded;
6763
Texture2D puffers;
@@ -71,7 +67,7 @@ static inline int four_rooms_rand(FourRooms* env, int n) {
7167
return rand_r(&env->rng) % n;
7268
}
7369

74-
static inline int four_rooms_grid_idx(FourRooms* env, int x, int y) {
70+
static inline int grid_idx(FourRooms* env, int x, int y) {
7571
return y * env->size + x;
7672
}
7773

@@ -83,7 +79,8 @@ void add_log(FourRooms* env) {
8379
env->log.n++;
8480
}
8581

86-
void encode_cell(unsigned char object, unsigned char* object_idx, unsigned char* color_idx, unsigned char* state) {
82+
static inline void encode_cell(unsigned char object, unsigned char* object_idx,
83+
unsigned char* color_idx, unsigned char* state) {
8784
*state = 0;
8885
if (object == WALL) {
8986
*object_idx = WALL;
@@ -97,7 +94,8 @@ void encode_cell(unsigned char object, unsigned char* object_idx, unsigned char*
9794
}
9895
}
9996

100-
void observation_to_world(FourRooms* env, int obs_x, int obs_y, int* world_x, int* world_y) {
97+
static inline void observation_to_world(FourRooms* env, int obs_x, int obs_y,
98+
int* world_x, int* world_y) {
10199
int forward_x = 0;
102100
int forward_y = 0;
103101
if (env->agent_dir == 0) forward_x = 1;
@@ -114,7 +112,7 @@ void observation_to_world(FourRooms* env, int obs_x, int obs_y, int* world_x, in
114112
*world_y = env->agent_y + forward_y * forward_offset + right_y * right_offset;
115113
}
116114

117-
void compute_visibility(unsigned char view[FOUR_ROOMS_VIEW_SIZE][FOUR_ROOMS_VIEW_SIZE],
115+
static inline void compute_visibility(unsigned char view[FOUR_ROOMS_VIEW_SIZE][FOUR_ROOMS_VIEW_SIZE],
118116
unsigned char visible[FOUR_ROOMS_VIEW_SIZE][FOUR_ROOMS_VIEW_SIZE]) {
119117
memset(visible, 0, FOUR_ROOMS_VIEW_SIZE * FOUR_ROOMS_VIEW_SIZE * sizeof(unsigned char));
120118
visible[FOUR_ROOMS_VIEW_SIZE - 1][FOUR_ROOMS_VIEW_SIZE / 2] = 1;
@@ -158,16 +156,12 @@ void generate_observation(FourRooms* env) {
158156
} else if (world_x == env->agent_x && world_y == env->agent_y) {
159157
view[y][x] = EMPTY;
160158
} else {
161-
view[y][x] = env->grid[four_rooms_grid_idx(env, world_x, world_y)];
159+
view[y][x] = env->grid[grid_idx(env, world_x, world_y)];
162160
}
163161
}
164162
}
165163

166-
if (env->see_through_walls) {
167-
memset(visible, 1, FOUR_ROOMS_VIEW_SIZE * FOUR_ROOMS_VIEW_SIZE * sizeof(unsigned char));
168-
} else {
169-
compute_visibility(view, visible);
170-
}
164+
compute_visibility(view, visible);
171165

172166
for (int y = 0; y < FOUR_ROOMS_VIEW_SIZE; y++) {
173167
for (int x = 0; x < FOUR_ROOMS_VIEW_SIZE; x++) {
@@ -192,73 +186,61 @@ void generate_observation(FourRooms* env) {
192186
void create_four_rooms_grid(FourRooms* env) {
193187
int size = env->size;
194188

195-
// Clear grid
196189
memset(env->grid, EMPTY, size * size * sizeof(unsigned char));
197190

198-
// Create outer walls
199191
for (int i = 0; i < size; i++) {
200-
env->grid[0 * size + i] = WALL; // Top
201-
env->grid[(size-1) * size + i] = WALL; // Bottom
202-
env->grid[i * size + 0] = WALL; // Left
203-
env->grid[i * size + (size-1)] = WALL; // Right
192+
env->grid[i] = WALL;
193+
env->grid[(size - 1) * size + i] = WALL;
194+
env->grid[i * size] = WALL;
195+
env->grid[i * size + size - 1] = WALL;
204196
}
205197

206198
int room_w = size / 2;
207199
int room_h = size / 2;
208200

209-
// Create vertical separating wall
210201
for (int y = 0; y < size; y++) {
211202
env->grid[y * size + room_w] = WALL;
212203
}
213204

214-
// Create horizontal separating wall
215205
for (int x = 0; x < size; x++) {
216206
env->grid[room_h * size + x] = WALL;
217207
}
218208

219-
// Create 4 gaps in the separating walls
220-
// Gap in vertical wall (top half)
209+
// MiniGrid samples doorway positions from [start + 1, end).
221210
int gap_y1 = 1 + four_rooms_rand(env, room_h - 1);
222211
env->grid[gap_y1 * size + room_w] = EMPTY;
223212

224-
// Gap in vertical wall (bottom half)
225213
int gap_y2 = room_h + 1 + four_rooms_rand(env, room_h - 1);
226214
env->grid[gap_y2 * size + room_w] = EMPTY;
227215

228-
// Gap in horizontal wall (left half)
229216
int gap_x1 = 1 + four_rooms_rand(env, room_w - 1);
230217
env->grid[room_h * size + gap_x1] = EMPTY;
231218

232-
// Gap in horizontal wall (right half)
233219
int gap_x2 = room_w + 1 + four_rooms_rand(env, room_w - 1);
234220
env->grid[room_h * size + gap_x2] = EMPTY;
235221
}
236222

237223
void c_reset(FourRooms* env) {
238224
if (env->max_steps <= 0) {
239-
env->max_steps = 4 * env->size;
225+
env->max_steps = FOUR_ROOMS_TIMEOUT_SCALE * env->size;
240226
}
241227

242228
create_four_rooms_grid(env);
243229

244-
// Place agent randomly in valid position
245230
do {
246231
env->agent_x = 1 + four_rooms_rand(env, env->size - 2);
247232
env->agent_y = 1 + four_rooms_rand(env, env->size - 2);
248-
} while (env->grid[four_rooms_grid_idx(env, env->agent_x, env->agent_y)] != EMPTY);
233+
} while (env->grid[grid_idx(env, env->agent_x, env->agent_y)] != EMPTY);
249234

250-
// Place goal randomly in valid position (different from agent)
251235
do {
252236
env->goal_x = 1 + four_rooms_rand(env, env->size - 2);
253237
env->goal_y = 1 + four_rooms_rand(env, env->size - 2);
254-
} while (env->grid[four_rooms_grid_idx(env, env->goal_x, env->goal_y)] != EMPTY ||
238+
} while (env->grid[grid_idx(env, env->goal_x, env->goal_y)] != EMPTY ||
255239
(env->goal_x == env->agent_x && env->goal_y == env->agent_y));
256240

257-
// Set agent and goal on grid
258-
env->grid[four_rooms_grid_idx(env, env->agent_x, env->agent_y)] = AGENT;
259-
env->grid[four_rooms_grid_idx(env, env->goal_x, env->goal_y)] = GOAL;
241+
env->grid[grid_idx(env, env->agent_x, env->agent_y)] = AGENT;
242+
env->grid[grid_idx(env, env->goal_x, env->goal_y)] = GOAL;
260243

261-
// Random initial direction
262244
env->agent_dir = four_rooms_rand(env, 4);
263245
env->tick = 0;
264246
env->episode_return = 0.0f;
@@ -273,8 +255,7 @@ void c_step(FourRooms* env) {
273255
env->terminals[0] = 0;
274256
env->rewards[0] = 0.0;
275257

276-
// Clear agent from current position
277-
env->grid[four_rooms_grid_idx(env, env->agent_x, env->agent_y)] = EMPTY;
258+
env->grid[grid_idx(env, env->agent_x, env->agent_y)] = EMPTY;
278259

279260
int new_x = env->agent_x;
280261
int new_y = env->agent_y;
@@ -290,17 +271,15 @@ void c_step(FourRooms* env) {
290271
else if (env->agent_dir == 2) new_x -= 1;
291272
else if (env->agent_dir == 3) new_y -= 1;
292273

293-
// Check if move is valid
294274
if (new_x >= 0 && new_x < env->size && new_y >= 0 && new_y < env->size &&
295-
env->grid[four_rooms_grid_idx(env, new_x, new_y)] != WALL) {
275+
env->grid[grid_idx(env, new_x, new_y)] != WALL) {
296276
env->agent_x = new_x;
297277
env->agent_y = new_y;
298278
}
299279
}
300280

301281
env->agent_dir = new_dir;
302282

303-
// Check if agent reached goal
304283
if (env->agent_x == env->goal_x && env->agent_y == env->goal_y) {
305284
env->terminals[0] = 1;
306285
env->rewards[0] = 1.0f - 0.9f * (float)env->tick / (float)env->max_steps;
@@ -310,10 +289,8 @@ void c_step(FourRooms* env) {
310289
return;
311290
}
312291

313-
// Place agent back on grid
314-
env->grid[four_rooms_grid_idx(env, env->agent_x, env->agent_y)] = AGENT;
292+
env->grid[grid_idx(env, env->agent_x, env->agent_y)] = AGENT;
315293

316-
// Check timeout
317294
if (env->tick >= env->max_steps) {
318295
env->terminals[0] = 1;
319296
env->rewards[0] = 0.0;
@@ -344,7 +321,6 @@ void c_render(FourRooms* env) {
344321

345322
int px = 32;
346323

347-
// Draw the main grid
348324
for (int y = 0; y < env->size; y++) {
349325
for (int x = 0; x < env->size; x++) {
350326
int cell = env->grid[y * env->size + x];

0 commit comments

Comments
 (0)