Skip to content

fix(saute): coerce tensors to correct device in SauteAdapter#387

Open
96sanjay wants to merge 1 commit into
PKU-Alignment:mainfrom
96sanjay:fix/saute-adapter-device-mismatch
Open

fix(saute): coerce tensors to correct device in SauteAdapter#387
96sanjay wants to merge 1 commit into
PKU-Alignment:mainfrom
96sanjay:fix/saute-adapter-device-mismatch

Conversation

@96sanjay
Copy link
Copy Markdown

@96sanjay 96sanjay commented May 1, 2026

Hit a device mismatch crash running PPOSaute on GPU with a custom env that returns cpu tensors. _safety_obs is on cuda but env outputs stay on cpu, so the in-place ops in
_safety_step fail immediately with RuntimeError: Expected all tensors to be on the same device.

Fix coerces all env outputs to self._device at the top of step() and reset(), plus a guard in _augment_obs for final_observation. No-op if everything is already on the
right device.

Motivation and Context

Affects any custom env that doesn't move its outputs to the training device — pretty common outside of mujoco/safety-gym. Crashes on the very first step.

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have read the CONTRIBUTION guide. (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly. (required for a bug fix or a new feature)
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format. (required)
  • I have checked the code using make lint. (required)
  • I have ensured make test pass. (required)

@96sanjay 96sanjay force-pushed the fix/saute-adapter-device-mismatch branch 2 times, most recently from 3e215b6 to 467deed Compare May 1, 2026 15:48
When the environment returns CPU tensors while training on GPU,
SauteAdapter crashes on the first step because _safety_obs lives
on the training device but env outputs (obs, reward, cost, etc.)
are still on CPU.

Coerce all env outputs to self._device at the step() and reset()
boundary. Also guard _augment_obs for final_observation from the
info dict. No-op when tensors are already on the correct device.
@96sanjay 96sanjay force-pushed the fix/saute-adapter-device-mismatch branch from 467deed to 6589a1b Compare May 1, 2026 15:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant