Skip to content

Commit 0bd3336

Browse files
committed
Add support for "dodgeball" ; unify some RNG functionality
1 parent d5ab846 commit 0bd3336

4 files changed

Lines changed: 32 additions & 9 deletions

File tree

procgen/env.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,13 @@
5353
#################################################
5454

5555
ENV_NAMES_OOD = [
56-
"coinrun",
57-
"climber",
56+
"bigfish",
57+
"bossfight",
58+
"caveflyer",
59+
# "chaser",
60+
"climber",
61+
"coinrun",
62+
"dodgeball",
5863
# Add more as they are finished
5964
]
6065

procgen/src/basic-abstract-game.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,16 @@ void BasicAbstractGame::choose_random_theme_type_match(const std::shared_ptr<Ent
11211121
}
11221122
}
11231123

1124+
int BasicAbstractGame::get_random_theme_type_match(const std::shared_ptr<Entity> &ent, const std::string &var_type) {
1125+
// Get a random theme integer based on the type of variable and train/eval environment.
1126+
// Useful for setting multiple entities to the same theme.
1127+
if (type_match(var_type)) {
1128+
return randn_type_switch(asset_num_themes[ent->image_type], var_type);
1129+
} else {
1130+
return rand_gen.randn(asset_num_themes[ent->image_type]);
1131+
}
1132+
}
1133+
11241134
void BasicAbstractGame::choose_random_theme_switch(const std::shared_ptr<Entity> &ent) {
11251135
// Dispatch method for choosing a random theme
11261136
if (eval_env) {
@@ -1142,8 +1152,8 @@ void BasicAbstractGame::choose_random_theme_train(const std::shared_ptr<Entity>
11421152
return;
11431153
}
11441154

1145-
float ithf = 1 - train_holdout_frac;
1146-
float theme_frac = num_themes * ithf;
1155+
float inv_frac = 1 - train_holdout_frac;
1156+
float theme_frac = num_themes * inv_frac;
11471157
int num_train_themes = std::max((int)theme_frac, 1); // Ensure at least one theme is chosen
11481158
ent->image_theme = rand_gen.randn(num_train_themes);
11491159

@@ -1170,8 +1180,7 @@ void BasicAbstractGame::choose_random_theme_eval(const std::shared_ptr<Entity> &
11701180
}
11711181
return;
11721182
}
1173-
float ehf = eval_holdout_frac;
1174-
float theme_frac = num_themes * ehf;
1183+
float theme_frac = num_themes * eval_holdout_frac;
11751184
int num_eval_themes = std::max((int)theme_frac, 1);
11761185
int start_idx = num_themes - num_eval_themes;
11771186
ent->image_theme = start_idx + rand_gen.randn(num_eval_themes);

procgen/src/basic-abstract-game.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,11 @@ class BasicAbstractGame : public Game {
112112
bool agent_has_collision();
113113
void reposition_agent();
114114

115+
void initialize_asset_if_necessary(int img_idx);
116+
115117
// Added for generalization testing -------------------------------------------//
116118
bool type_match(std::string var_type);
119+
int get_random_theme_type_match(const std::shared_ptr<Entity> &ent, const std::string &var_type);
117120
void choose_random_theme_switch(const std::shared_ptr<Entity> &ent);
118121
void choose_random_theme_train(const std::shared_ptr<Entity> &ent);
119122
void choose_random_theme_eval(const std::shared_ptr<Entity> &ent);
@@ -180,7 +183,6 @@ class BasicAbstractGame : public Game {
180183
Grid<int> grid;
181184

182185
QImage *lookup_asset(int img_idx, bool is_reflected = false);
183-
void initialize_asset_if_necessary(int img_idx);
184186
void prepare_for_drawing(float rect_height);
185187
void draw_background(QPainter &p, const QRect &rect);
186188
void draw_entity(QPainter &p, const std::shared_ptr<Entity> &to_draw);

procgen/src/games/dodgeball.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ const int DUST_CLOUD = 8;
1818
const int OOB_WALL = 10;
1919

2020
const int ENEMY_REWARD = 2.0f;
21-
const int NUM_ENEMY_THEMES = 7;
2221

2322
const float ENEMY_VEL = 0.05f;
2423
const float BALL_V_ROT = PI * 0.23f;
@@ -348,7 +347,15 @@ class DodgeballGame : public BasicAbstractGame {
348347

349348
spawn_entities(num_enemies, enemy_r, ENEMY, 0, 0, main_width, main_height);
350349

351-
int enemy_theme = rand_gen.randn(NUM_ENEMY_THEMES);
350+
// Choose a random theme for all enemies
351+
int enemy_theme;
352+
for (auto ent : entities) {
353+
if (ent->type == ENEMY) {
354+
initialize_asset_if_necessary(ent->image_type);
355+
enemy_theme = get_random_theme_type_match(ent, "enemy");
356+
break;
357+
}
358+
}
352359

353360
for (auto ent : entities) {
354361
if (ent->type == ENEMY) {

0 commit comments

Comments
 (0)