BF16 load_weights doesn't reach inference stream — eval produces degenerate action distribution #534

@Infatoshi

Summary

After training an env in the default BF16 build and saving a checkpoint, puffer eval --load-model-path <ckpt> reports "Loaded weights from …" but the policy's action distribution is degenerate: it collapses to a single action (0/NOOP) regardless of observation. Rebuilding with --float (FP32) and loading the same on-disk checkpoint produces the expected trained behavior.

No error is raised; the failure is silent, so it's easy to mistake for "the policy didn't learn."

Repro

Reproduced on the ocean/space_invaders env I just wrote (but the bug is unrelated to the env — any env trained in BF16 should hit this):

# 1. Train in BF16 (default)
./build.sh <env>
puffer train <env> --train.total-timesteps 400000000
# wait for checkpoint

# 2. Eval in BF16 — policy appears untrained
puffer eval <env> --load-model-path latest
# Observed: ship stationary, score stays near 0, all actions = NOOP

# 3. Rebuild for FP32 and eval the SAME on-disk checkpoint
./build.sh <env> --float
puffer eval <env> --load-model-path latest
# Observed: policy plays normally (in my case score ~950, perf ~0.78)

Instrumenting c_step to log env->actions[0] confirms:

  • BF16 build after load: actions are all 0 across many envs over many steps
  • FP32 build after load: actions are varied (0, 1, 2, 3), player moves and fires
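
For reference, the instrumentation was just a print of env->actions[0] at the top of c_step, along these lines (the Env type name is a placeholder, the cast assumes discrete integer actions, and stdio.h must already be included):

void c_step(Env* env) {
    // Log the action the policy chose for this env before stepping.
    printf("env->actions[0] = %d\n", (int)env->actions[0]);
    /* ... existing step logic unchanged ... */
}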

Expected

BF16 and FP32 builds should both produce the trained policy when loading the same checkpoint (up to BF16 rounding).

Hypothesis

I haven't fully verified the root cause, but src/bindings.cu::load_weights looks suspicious:

cudaMemcpy(pufferl.master_weights.data, buf.data(), nbytes, cudaMemcpyHostToDevice);
if (USE_BF16) {
    int n = numel(pufferl.param_puf.shape);
    cast<<<grid_size(n), BLOCK_SIZE, 0, pufferl.default_stream>>>(
        pufferl.param_puf.data, pufferl.master_weights.data, n);
}

The cast is launched on pufferl.default_stream but the forward pass runs on per-buffer streams (vec->streams[buf]). CUDA streams are independent without explicit sync, so the first forward after load may still see the pre-load contents of param_puf. After that first forward writes out actions (all zeros because the model hasn't been initialized in a meaningful way on that param block), subsequent forwards may or may not pick up the cast result depending on scheduling — in practice I see the all-NOOP behavior persist across the whole eval session.

Possible fixes (the maintainers can pick; rough sketches of the first two are below the list):

  • cudaStreamSynchronize(pufferl.default_stream) at end of load_weights
  • Issue the cast on each buffer stream, or use an event the per-buffer streams wait on
  • Cast on cudaStreamPerThread / a stream that's guaranteed ordered with all consumers
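
For concreteness, rough sketches of the first two options, written against the snippet above. These are untested; num_buffers and vec->streams are guesses at the surrounding code and may need renaming.

// Option 1: simplest fix. Block the host until the cast finishes, so any
// stream that runs afterwards sees the freshly cast BF16 weights.
// load_weights is not on a hot path, so the stall should be acceptable.
cast<<<grid_size(n), BLOCK_SIZE, 0, pufferl.default_stream>>>(
    pufferl.param_puf.data, pufferl.master_weights.data, n);
cudaStreamSynchronize(pufferl.default_stream);

// Option 2: stay asynchronous. Record an event on default_stream after the
// cast and make each per-buffer stream wait on it before running any
// forward passes.
cast<<<grid_size(n), BLOCK_SIZE, 0, pufferl.default_stream>>>(
    pufferl.param_puf.data, pufferl.master_weights.data, n);
cudaEvent_t cast_done;
cudaEventCreateWithFlags(&cast_done, cudaEventDisableTiming);
cudaEventRecord(cast_done, pufferl.default_stream);
for (int i = 0; i < num_buffers; i++) {
    cudaStreamWaitEvent(vec->streams[i], cast_done, 0);
}
// Destroying the event here is fine: the enqueued waits keep it alive
// until they complete.
cudaEventDestroy(cast_done);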

Happy to test any patch you want; I don't want to PR a guess.

Environment

  • GPU: RTX PRO 6000 Blackwell (sm_120), driver 595.58.03, CUDA 13.2 (nvcc 12.8)
  • Kernel 6.17, Ubuntu 24.04
  • pufferlib on main as of this afternoon
  • Default build (BF16); compared against --float
