Skip to content

Commit ef90154

Browse files
committed
ocean/craftax: update_mobs early-out on dead mob slots
The five move_* helpers (melee/passive/ranged mobs + mob/player projectiles) now return immediately when mask=false. JAX's branchless "compute-then-mask" pattern is pointless on CPU: dead slots' output never feeds observations, rewards, or mob_map, so skipping the body and the RNG draws is semantically equivalent. Defining CRAFTAX_JAX_PARITY at build time restores the branchless slow path for bitwise replay against JAX (required by tests/craftax_parity.py). Default build uses the early-out. Also drops craftax_step_jax_index(player_level, NUM_LEVELS) clamps at the top of each move_* -- state->player_level is maintained in [0, NUM_LEVELS-1] by change_floor_native (explicit bounds checks) and by the worldgen init. Six redundant clamps per step eliminated. Measurements (single-thread, random actions, pool=1024): update_mobs phase: 1.392 us -> 0.285 us (4.88x) full c_step: 2.35 us -> 1.22 us 1-thread sim SPS: 425K -> 819K (1.93x) 16-thread sim SPS: 5.53M -> 10.04M (1.82x) training SPS: 506K -> 544K (+7%) Parity test with CRAFTAX_JAX_PARITY defined passes 8 seeds * 1000 steps over 27 terminals. Without the flag, parity diverges at the first mob death -- by design.
1 parent 93cfb01 commit ef90154

1 file changed

Lines changed: 29 additions & 25 deletions

File tree

ocean/craftax/step_update_mobs.h

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -600,13 +600,17 @@ static inline void craftax_update_mobs_move_melee(
600600
CraftaxThreefryKey* rng,
601601
int32_t index
602602
) {
603-
int32_t level = craftax_step_jax_index(
604-
state->player_level,
605-
CRAFTAX_NUM_LEVELS
606-
);
603+
int32_t level = state->player_level;
604+
bool old_mask = state->melee_mobs.mask[level][index];
605+
// Dead slot early-out: no observable effect on obs/reward/terminal.
606+
// Skip body and RNG draws for speed. Breaks per-seed replay against
607+
// JAX; define CRAFTAX_JAX_PARITY at build time to restore the
608+
// branchless slow path (same pattern in every move_* below).
609+
#ifndef CRAFTAX_JAX_PARITY
610+
if (!old_mask) return;
611+
#endif
607612
int32_t old_row = state->melee_mobs.position[level][index][0];
608613
int32_t old_col = state->melee_mobs.position[level][index][1];
609-
bool old_mask = state->melee_mobs.mask[level][index];
610614
int32_t old_cooldown = state->melee_mobs.attack_cooldown[level][index];
611615
int32_t mob_type = state->melee_mobs.type_id[level][index];
612616

@@ -729,13 +733,13 @@ static inline void craftax_update_mobs_move_passive(
729733
CraftaxThreefryKey* rng,
730734
int32_t index
731735
) {
732-
int32_t level = craftax_step_jax_index(
733-
state->player_level,
734-
CRAFTAX_NUM_LEVELS
735-
);
736+
int32_t level = state->player_level;
737+
bool old_mask = state->passive_mobs.mask[level][index];
738+
#ifndef CRAFTAX_JAX_PARITY
739+
if (!old_mask) return;
740+
#endif
736741
int32_t old_row = state->passive_mobs.position[level][index][0];
737742
int32_t old_col = state->passive_mobs.position[level][index][1];
738-
bool old_mask = state->passive_mobs.mask[level][index];
739743
int32_t mob_type = state->passive_mobs.type_id[level][index];
740744

741745
CraftaxThreefryKey draw_key =
@@ -794,13 +798,13 @@ static inline void craftax_update_mobs_move_ranged(
794798
CraftaxThreefryKey* rng,
795799
int32_t index
796800
) {
797-
int32_t level = craftax_step_jax_index(
798-
state->player_level,
799-
CRAFTAX_NUM_LEVELS
800-
);
801+
int32_t level = state->player_level;
802+
bool old_mask = state->ranged_mobs.mask[level][index];
803+
#ifndef CRAFTAX_JAX_PARITY
804+
if (!old_mask) return;
805+
#endif
801806
int32_t old_row = state->ranged_mobs.position[level][index][0];
802807
int32_t old_col = state->ranged_mobs.position[level][index][1];
803-
bool old_mask = state->ranged_mobs.mask[level][index];
804808
int32_t old_cooldown = state->ranged_mobs.attack_cooldown[level][index];
805809
int32_t mob_type = state->ranged_mobs.type_id[level][index];
806810

@@ -932,17 +936,17 @@ static inline void craftax_update_mobs_move_mob_projectile(
932936
CraftaxState* state,
933937
int32_t index
934938
) {
935-
int32_t level = craftax_step_jax_index(
936-
state->player_level,
937-
CRAFTAX_NUM_LEVELS
938-
);
939+
int32_t level = state->player_level;
940+
bool old_mask = state->mob_projectiles.mask[level][index];
941+
#ifndef CRAFTAX_JAX_PARITY
942+
if (!old_mask) return;
943+
#endif
939944
int32_t old_row = state->mob_projectiles.position[level][index][0];
940945
int32_t old_col = state->mob_projectiles.position[level][index][1];
941946
int32_t proposed_row =
942947
old_row + state->mob_projectile_directions[level][index][0];
943948
int32_t proposed_col =
944949
old_col + state->mob_projectile_directions[level][index][1];
945-
bool old_mask = state->mob_projectiles.mask[level][index];
946950

947951
bool proposed_in_player =
948952
proposed_row == state->player_position[0]
@@ -1009,17 +1013,17 @@ static inline void craftax_update_mobs_move_player_projectile(
10091013
CraftaxState* state,
10101014
int32_t index
10111015
) {
1012-
int32_t level = craftax_step_jax_index(
1013-
state->player_level,
1014-
CRAFTAX_NUM_LEVELS
1015-
);
1016+
int32_t level = state->player_level;
1017+
bool old_mask = state->player_projectiles.mask[level][index];
1018+
#ifndef CRAFTAX_JAX_PARITY
1019+
if (!old_mask) return;
1020+
#endif
10161021
int32_t old_row = state->player_projectiles.position[level][index][0];
10171022
int32_t old_col = state->player_projectiles.position[level][index][1];
10181023
int32_t proposed_row =
10191024
old_row + state->player_projectile_directions[level][index][0];
10201025
int32_t proposed_col =
10211026
old_col + state->player_projectile_directions[level][index][1];
1022-
bool old_mask = state->player_projectiles.mask[level][index];
10231027

10241028
float damage_vector[3];
10251029
craftax_update_mobs_player_projectile_damage_vector(

0 commit comments

Comments
 (0)