Skip to content

Handle rollout MuJoCo errors as warnings#3247

Open
devshahofficial wants to merge 4 commits into
google-deepmind:mainfrom
devshahofficial:codex/rollout-handle-errors
Open

Handle rollout MuJoCo errors as warnings#3247
devshahofficial wants to merge 4 commits into
google-deepmind:mainfrom
devshahofficial:codex/rollout-handle-errors

Conversation

@devshahofficial

Copy link
Copy Markdown
Contributor

Summary

  • add raise_on_error to rollout so existing strict error behavior remains the default
  • when raise_on_error=False, intercept MuJoCo fatal errors during a rollout, fill the failed trajectory remainder from the current state and sensor data, and continue other batches
  • document the new tolerant path and cover both single-threaded and threaded rollout behavior

Fixes #2704

Validation

  • python3 -m py_compile python/mujoco/rollout.py python/mujoco/rollout_test.py
  • built local MuJoCo, generated the Python sdist, built and installed a wheel from this checkout
  • MUJOCO_GL=disable /tmp/mujoco-rollout-venv/bin/python -m pytest -q --pyargs mujoco.rollout_test (67 passed, 1 skipped)

@devshahofficial devshahofficial changed the title [codex] Handle rollout MuJoCo errors as warnings Handle rollout MuJoCo errors as warnings Apr 28, 2026
@devshahofficial devshahofficial marked this pull request as ready for review April 30, 2026 15:12
@devshahofficial

Copy link
Copy Markdown
Contributor Author

@haroonq @yuvaltassa this is now marked ready for review. Would one of you, or another active maintainer/contributor, be able to take a look when you have bandwidth? Happy to follow up on any requested changes or adjust the approach if there is a better direction.

@yuvaltassa yuvaltassa self-assigned this May 18, 2026
@yuvaltassa

Copy link
Copy Markdown
Collaborator

@aftersomemath WDYT?

@aftersomemath aftersomemath left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#2704 makes a good point, that exceptions should trigger the same behavior as warnings. I think we can simplify this PR considerably, see detailed comments.

Comment thread python/mujoco/rollout_test.py Outdated
model, mujoco.MjData(model), initial_state[1], ctrl[1]
)
np.testing.assert_array_equal(
state[0], np.tile(state[0, 0], (nstep, 1))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test would be stronger if the exception occurred in the middle of a rollout instead of the first step. If this can be done, I recommend it.

Comment thread python/mujoco/rollout.py
state: Optional[npt.ArrayLike] = None,
sensordata: Optional[npt.ArrayLike] = None,
chunk_size: Optional[int] = None,
raise_on_error: bool = True,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that rollout does not raise any errors when this flag is true. It just fills the buffer with the last state. So a better name is probably fill_on_true.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name is correct as is, my previous comment was incorrect.

Comment thread python/mujoco/rollout.cc Outdated
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
this->pool_.get(), chunk_size_final);
if (raise_on_error) {
InterceptMjErrors(_unsafe_rollout_threaded)(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that there is no need to use InterceptMjErrors on _unsafe_rollout_threaded because _unsafe_rollout_threaded is already catching all the MuJoCo exceptions, and the return value of InterceptMjErrors is not used.

Comment thread python/mujoco/rollout.cc Outdated
model_ptrs, data_ptrs[0], 0, nbatch, nstep, control_spec,
state0_ptr, warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
if (raise_on_error) {
InterceptMjErrors(_unsafe_rollout)(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Comment thread python/mujoco/rollout.cc Outdated
mj_setState(m[r], d, state0 + r*nstate, mjSTATE_FULLPHYSICS);
if (handle_mj_errors) {
if (!TryMjCall([&]() {
mj_setState(m[r], d, state0 + r*nstate, mjSTATE_FULLPHYSICS);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is necessary to catch exception in mj_setState. If there is an exception, it is probably our fault as the only thing the call's success depends on is the dimensions, which were already validated. Also, it makes the code verbose.

Comment thread python/mujoco/rollout.cc Outdated
mj_setState(m[r], d, control + step*ncontrol, control_spec);
if (handle_mj_errors) {
if (!TryMjCall([&]() {
mj_setState(m[r], d, control + step*ncontrol, control_spec);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Comment thread python/mujoco/rollout.cc Outdated
// step
mj_step(m[r], d);
if (handle_mj_errors) {
if (!TryMjCall([&]() { mj_step(m[r], d); })) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this call to mj_step is the only place where exceptions need to be caught. If so, then TryMjCall can be inlined, and the FillRemainingOutputs function can be removed.

Instead of FillRemainingOutputs just set a new flag (maybe nerror?) and trigger the existing block of code for filling the buffer.

Comment thread python/mujoco/rollout.cc Outdated
mju_copy(sensordata + step*nsensordata, d->sensordata, nsensordata);
}
}
FillRemainingOutputs(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend keeping this code as is, and add a flag nerror to the if statement.

@devshahofficial

Copy link
Copy Markdown
Contributor Author

@aftersomemath @yuvaltassa thanks, this was a helpful pass. I pushed a cleanup commit and the rollout code is a lot closer to the shape suggested in the review now.

Main changes:

  • errors now go through the same fill-the-rest path as warnings, using an nerror flag
  • removed the extra fill helper/template wrapper and stopped catching mj_setState
  • removed the outer InterceptMjErrors calls around the rollout wrappers
  • tightened the regression test so the error happens mid-rollout instead of on the first step, with both main-thread and threaded coverage

I kept the public kwarg as raise_on_error for now because the default path still raises, while False opts into the fill behavior. Happy to rename it if you prefer something like fill_on_error.

Validated locally with:

  • python3 -m py_compile python/mujoco/rollout.py python/mujoco/rollout_test.py
  • built the Python sdist and wheel from this branch
  • MUJOCO_GL=disable /private/tmp/mujoco-rollout-fix-venv/bin/python -m pytest -q --pyargs mujoco.rollout_test (67 passed, 1 skipped)

@aftersomemath aftersomemath left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making the changes! I have to request a few more to further simplify things, but this looks almost done.

Comment thread python/mujoco/rollout.cc Outdated
}

// if any warnings or handled errors, fill remaining outputs with current outputs, break
if (nwarning || nerror) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block of code that fills the buffer should remain at the beginning of the function.

Comment thread python/mujoco/rollout.cc Outdated

// copy out new state
if (state) {
if (state && !nwarning && !nerror) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These additional checks will not be needed when the fill code block is moved back to the beginning.

Comment thread python/mujoco/rollout.cc Outdated
if (raise_on_error) {
throw;
}
nerror = true;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the continue keyword here so that the remaining code is skipped, and the fill code will execute.

Comment thread python/mujoco/rollout.py
state: Optional[npt.ArrayLike] = None,
sensordata: Optional[npt.ArrayLike] = None,
chunk_size: Optional[int] = None,
raise_on_error: bool = True,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name is correct as is, my previous comment was incorrect.

Comment thread python/mujoco/rollout.cc
} // namespace

} // namespace mujoco::python

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: whitespace change should be removed

Comment thread python/mujoco/rollout.cc Outdated
chunk_size_final = *chunk_size;
}
InterceptMjErrors(_unsafe_rollout_threaded)(
_unsafe_rollout_threaded(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, we should still use InterceptMjErrors here. In the original PR I thought the usage was InterceptMjErrors was newly introduced, but it was already there.

Comment thread python/mujoco/rollout.cc Outdated
this->pool_.get(), chunk_size_final, raise_on_error);
} else {
InterceptMjErrors(_unsafe_rollout)(
_unsafe_rollout(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Comment thread python/mujoco/rollout.cc Outdated
int id = pool->WorkerId();
_unsafe_rollout(m, d[id], j*chunk_size, (j+1)*chunk_size,
nstep, control_spec, state0, warmstart0, control, state, sensordata);
try {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we do not need this new code if we go back to using InterceptMjErrors with _unsafe_rollout_threaded.

Comment thread python/mujoco/rollout.cc Outdated
nstep, control_spec, state0, warmstart0, control, state, sensordata);
auto task = [=, &m, &d, &rollout_error, &rollout_error_mutex](void) {
try {
_unsafe_rollout(m, d[pool->WorkerId()], nfulljobs*chunk_size,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

@devshahofficial

Copy link
Copy Markdown
Contributor Author

@aftersomemath @yuvaltassa thanks for the review pass. I pushed 8f482ac4 to address the latest rollout threads:

  • moved the fill-remaining-outputs block back to the beginning of the per-step loop
  • use continue after a handled mj_step error so the existing fill path handles the rest of the trajectory
  • removed the extra nwarning / nerror checks that became unnecessary once the fill block moved back up
  • restored InterceptMjErrors around both rollout wrapper calls
  • removed the extra threaded exception-forwarding code and unused headers
  • restored the whitespace-only EOF change

I kept raise_on_error unchanged since the follow-up said the name is correct. Local validation: Python syntax checks and git diff origin/main --check for the rollout files. I could not run the compiled rollout tests in this environment because cmake is not installed here.

@aftersomemath aftersomemath left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me except for a minor comment.

With regards to the comment about not running the new changes, it is important to run and test changes yourself before submitting them for review. I see most of the CI jobs passed, except for one on Mac OS which is complaining about the rollout.cc file. Not sure if that is related.

Comment thread python/mujoco/rollout.cc Outdated
mju_copy(sensordata + step*nsensordata, d->sensordata, nsensordata);
}
nerror = true;
continue;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that the code is simplified, it seems clear that the continue statement and the mj_getState and mju_copy calls can be removed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aftersomemath thanks, agreed. I pushed 7dffd601 to remove the redundant handled-error copy-out and continue. After nerror is set, the existing copy-out path records the current state/sensordata and the next loop iteration fills the remaining outputs. Local validation: python3 -m py_compile python/mujoco/rollout.py python/mujoco/rollout_test.py, git diff --check -- python/mujoco/rollout.cc python/mujoco/rollout.py python/mujoco/rollout_test.py doc/python.rst, and the same whitespace check against origin/main.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[rollout] Handle exceptions as warnings

3 participants