@@ -1121,6 +1121,16 @@ void BasicAbstractGame::choose_random_theme_type_match(const std::shared_ptr<Ent
11211121 }
11221122}
11231123
1124+ int BasicAbstractGame::get_random_theme_type_match (const std::shared_ptr<Entity> &ent, const std::string &var_type) {
1125+ // Get a random theme integer based on the type of variable and train/eval environment.
1126+ // Useful for setting multiple entities to the same theme.
1127+ if (type_match (var_type)) {
1128+ return randn_type_switch (asset_num_themes[ent->image_type ], var_type);
1129+ } else {
1130+ return rand_gen.randn (asset_num_themes[ent->image_type ]);
1131+ }
1132+ }
1133+
11241134void BasicAbstractGame::choose_random_theme_switch (const std::shared_ptr<Entity> &ent) {
11251135 // Dispatch method for choosing a random theme
11261136 if (eval_env) {
@@ -1142,8 +1152,8 @@ void BasicAbstractGame::choose_random_theme_train(const std::shared_ptr<Entity>
11421152 return ;
11431153 }
11441154
1145- float ithf = 1 - train_holdout_frac;
1146- float theme_frac = num_themes * ithf ;
1155+ float inv_frac = 1 - train_holdout_frac;
1156+ float theme_frac = num_themes * inv_frac ;
11471157 int num_train_themes = std::max ((int )theme_frac, 1 ); // Ensure at least one theme is chosen
11481158 ent->image_theme = rand_gen.randn (num_train_themes);
11491159
@@ -1170,8 +1180,7 @@ void BasicAbstractGame::choose_random_theme_eval(const std::shared_ptr<Entity> &
11701180 }
11711181 return ;
11721182 }
1173- float ehf = eval_holdout_frac;
1174- float theme_frac = num_themes * ehf;
1183+ float theme_frac = num_themes * eval_holdout_frac;
11751184 int num_eval_themes = std::max ((int )theme_frac, 1 );
11761185 int start_idx = num_themes - num_eval_themes;
11771186 ent->image_theme = start_idx + rand_gen.randn (num_eval_themes);
0 commit comments