Skip to content

Commit 3705f9b

Browse files
committed
Fix rendering with non-default width/height
1 parent d63ca97 commit 3705f9b

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

crazyflow/sim/sim.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,17 +306,24 @@ def thrust_control(self, cmd: Array):
306306
self.data = self.data.replace(controls=self.data.controls.replace(thrust=controls))
307307

308308
def render(
309-
self, mode: str | None = "human", world: int = 0, default_cam_config: dict | None = None
309+
self,
310+
mode: str | None = "human",
311+
world: int = 0,
312+
default_cam_config: dict | None = None,
313+
width: int = 640,
314+
height: int = 480,
310315
) -> NDArray | None:
311316
if self.viewer is None:
312317
patch_viewer()
318+
self.mj_model.vis.global_.offwidth = width
319+
self.mj_model.vis.global_.offheight = height
313320
self.viewer = MujocoRenderer(
314321
self.mj_model,
315322
self.mj_data,
316323
max_geom=self.max_visual_geom,
317324
default_cam_config=default_cam_config,
318-
height=480,
319-
width=640,
325+
height=height,
326+
width=width,
320327
)
321328
self.mj_data.qpos[:] = self.data.mjx_data.qpos[world, :]
322329
self.mj_data.mocap_pos[:] = self.data.mjx_data.mocap_pos[world, :]

tests/unit/test_sim.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,12 @@ def test_render_human(device: str):
270270
def test_render_rgb_array(device: str):
271271
skip_unavailable_device(device)
272272
sim = Sim(n_worlds=2, device=device)
273-
img = sim.render(mode="rgb_array")
273+
img = sim.render(mode="rgb_array", width=1024, height=1024)
274274
assert isinstance(img, np.ndarray), "Image must be a numpy array"
275-
assert img.ndim == 3, "Image must be 3D"
275+
assert img.shape == (1024, 1024, 3), f"Unexpected image shape {img.shape}"
276+
# Check if mj_model.vis.global_.offwidth is set correctly
277+
assert not all(img[0, 0, :] == 0), "Image contains black patches"
278+
assert not all(img[-1, -1, :] == 0), "Image contains black patches"
276279

277280

278281
@pytest.mark.unit

0 commit comments

Comments
 (0)