Handle rollout MuJoCo errors as warnings#3247
Conversation
|
@haroonq @yuvaltassa this is now marked ready for review. Would one of you, or another active maintainer/contributor, be able to take a look when you have bandwidth? Happy to follow up on any requested changes or adjust the approach if there is a better direction. |
|
@aftersomemath WDYT? |
aftersomemath
left a comment
There was a problem hiding this comment.
#2704 makes a good point, that exceptions should trigger the same behavior as warnings. I think we can simplify this PR considerably, see detailed comments.
| model, mujoco.MjData(model), initial_state[1], ctrl[1] | ||
| ) | ||
| np.testing.assert_array_equal( | ||
| state[0], np.tile(state[0, 0], (nstep, 1)) |
There was a problem hiding this comment.
The test would be stronger if the exception occurred in the middle of a rollout instead of the first step. If this can be done, I recommend it.
| state: Optional[npt.ArrayLike] = None, | ||
| sensordata: Optional[npt.ArrayLike] = None, | ||
| chunk_size: Optional[int] = None, | ||
| raise_on_error: bool = True, |
There was a problem hiding this comment.
It seems that rollout does not raise any errors when this flag is true. It just fills the buffer with the last state. So a better name is probably fill_on_true.
There was a problem hiding this comment.
The name is correct as is, my previous comment was incorrect.
| warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr, | ||
| this->pool_.get(), chunk_size_final); | ||
| if (raise_on_error) { | ||
| InterceptMjErrors(_unsafe_rollout_threaded)( |
There was a problem hiding this comment.
It seems that there is no need to use InterceptMjErrors on _unsafe_rollout_threaded because _unsafe_rollout_threaded is already catching all the MuJoCo exceptions, and the return value of InterceptMjErrors is not used.
| model_ptrs, data_ptrs[0], 0, nbatch, nstep, control_spec, | ||
| state0_ptr, warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr); | ||
| if (raise_on_error) { | ||
| InterceptMjErrors(_unsafe_rollout)( |
| mj_setState(m[r], d, state0 + r*nstate, mjSTATE_FULLPHYSICS); | ||
| if (handle_mj_errors) { | ||
| if (!TryMjCall([&]() { | ||
| mj_setState(m[r], d, state0 + r*nstate, mjSTATE_FULLPHYSICS); |
There was a problem hiding this comment.
I don't think it is necessary to catch exception in mj_setState. If there is an exception, it is probably our fault as the only thing the call's success depends on is the dimensions, which were already validated. Also, it makes the code verbose.
| mj_setState(m[r], d, control + step*ncontrol, control_spec); | ||
| if (handle_mj_errors) { | ||
| if (!TryMjCall([&]() { | ||
| mj_setState(m[r], d, control + step*ncontrol, control_spec); |
| // step | ||
| mj_step(m[r], d); | ||
| if (handle_mj_errors) { | ||
| if (!TryMjCall([&]() { mj_step(m[r], d); })) { |
There was a problem hiding this comment.
I believe this call to mj_step is the only place where exceptions need to be caught. If so, then TryMjCall can be inlined, and the FillRemainingOutputs function can be removed.
Instead of FillRemainingOutputs just set a new flag (maybe nerror?) and trigger the existing block of code for filling the buffer.
| mju_copy(sensordata + step*nsensordata, d->sensordata, nsensordata); | ||
| } | ||
| } | ||
| FillRemainingOutputs( |
There was a problem hiding this comment.
I recommend keeping this code as is, and add a flag nerror to the if statement.
|
@aftersomemath @yuvaltassa thanks, this was a helpful pass. I pushed a cleanup commit and the rollout code is a lot closer to the shape suggested in the review now. Main changes:
I kept the public kwarg as Validated locally with:
|
aftersomemath
left a comment
There was a problem hiding this comment.
Thanks for making the changes! I have to request a few more to further simplify things, but this looks almost done.
| } | ||
|
|
||
| // if any warnings or handled errors, fill remaining outputs with current outputs, break | ||
| if (nwarning || nerror) { |
There was a problem hiding this comment.
This block of code that fills the buffer should remain at the beginning of the function.
|
|
||
| // copy out new state | ||
| if (state) { | ||
| if (state && !nwarning && !nerror) { |
There was a problem hiding this comment.
These additional checks will not be needed when the fill code block is moved back to the beginning.
| if (raise_on_error) { | ||
| throw; | ||
| } | ||
| nerror = true; |
There was a problem hiding this comment.
You can use the continue keyword here so that the remaining code is skipped, and the fill code will execute.
| state: Optional[npt.ArrayLike] = None, | ||
| sensordata: Optional[npt.ArrayLike] = None, | ||
| chunk_size: Optional[int] = None, | ||
| raise_on_error: bool = True, |
There was a problem hiding this comment.
The name is correct as is, my previous comment was incorrect.
| } // namespace | ||
|
|
||
| } // namespace mujoco::python | ||
|
|
There was a problem hiding this comment.
nit: whitespace change should be removed
| chunk_size_final = *chunk_size; | ||
| } | ||
| InterceptMjErrors(_unsafe_rollout_threaded)( | ||
| _unsafe_rollout_threaded( |
There was a problem hiding this comment.
On second thought, we should still use InterceptMjErrors here. In the original PR I thought the usage was InterceptMjErrors was newly introduced, but it was already there.
| this->pool_.get(), chunk_size_final, raise_on_error); | ||
| } else { | ||
| InterceptMjErrors(_unsafe_rollout)( | ||
| _unsafe_rollout( |
| int id = pool->WorkerId(); | ||
| _unsafe_rollout(m, d[id], j*chunk_size, (j+1)*chunk_size, | ||
| nstep, control_spec, state0, warmstart0, control, state, sensordata); | ||
| try { |
There was a problem hiding this comment.
I believe we do not need this new code if we go back to using InterceptMjErrors with _unsafe_rollout_threaded.
| nstep, control_spec, state0, warmstart0, control, state, sensordata); | ||
| auto task = [=, &m, &d, &rollout_error, &rollout_error_mutex](void) { | ||
| try { | ||
| _unsafe_rollout(m, d[pool->WorkerId()], nfulljobs*chunk_size, |
|
@aftersomemath @yuvaltassa thanks for the review pass. I pushed
I kept |
aftersomemath
left a comment
There was a problem hiding this comment.
Looks good to me except for a minor comment.
With regards to the comment about not running the new changes, it is important to run and test changes yourself before submitting them for review. I see most of the CI jobs passed, except for one on Mac OS which is complaining about the rollout.cc file. Not sure if that is related.
| mju_copy(sensordata + step*nsensordata, d->sensordata, nsensordata); | ||
| } | ||
| nerror = true; | ||
| continue; |
There was a problem hiding this comment.
Now that the code is simplified, it seems clear that the continue statement and the mj_getState and mju_copy calls can be removed.
There was a problem hiding this comment.
@aftersomemath thanks, agreed. I pushed 7dffd601 to remove the redundant handled-error copy-out and continue. After nerror is set, the existing copy-out path records the current state/sensordata and the next loop iteration fills the remaining outputs. Local validation: python3 -m py_compile python/mujoco/rollout.py python/mujoco/rollout_test.py, git diff --check -- python/mujoco/rollout.cc python/mujoco/rollout.py python/mujoco/rollout_test.py doc/python.rst, and the same whitespace check against origin/main.
Summary
raise_on_errorto rollout so existing strict error behavior remains the defaultraise_on_error=False, intercept MuJoCo fatal errors during a rollout, fill the failed trajectory remainder from the current state and sensor data, and continue other batchesFixes #2704
Validation
python3 -m py_compile python/mujoco/rollout.py python/mujoco/rollout_test.pyMUJOCO_GL=disable /tmp/mujoco-rollout-venv/bin/python -m pytest -q --pyargs mujoco.rollout_test(67 passed, 1 skipped)