|
| 1 | +"""Smoke test for the ONNX export with frame-stacking (k>1). |
| 2 | +
|
| 3 | +Verifies that `export_to_onnx_discrete` reads channel count from the |
| 4 | +model's saved obs_space (not hard-coded to 3), so a k=4 model exports |
| 5 | +to ONNX with the correct (1, 84, 84, 12) image input shape. |
| 6 | +""" |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +from pathlib import Path |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import gymnasium as gym |
| 13 | +import pytest |
| 14 | +from gymnasium import spaces |
| 15 | +from stable_baselines3 import PPO |
| 16 | +from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack |
| 17 | + |
| 18 | +from training.train_cardboard_corridor_v9 import export_to_onnx_discrete |
| 19 | + |
| 20 | + |
class _Stub(gym.Env):
    """Tiny Dict-obs env matching the cardboard-corridor obs structure.

    Every observation is all zeros: an ``(84, 84, 3)`` uint8 ``"image"`` and
    a ``(1,)`` float32 ``"ultrasonic"`` reading. Each step ends the episode
    immediately with zero reward, which is enough to build a policy against
    the observation/action spaces without any real simulation.
    """

    metadata = {"render_modes": []}

    def __init__(self):
        self.observation_space = spaces.Dict({
            "image": spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8),
            "ultrasonic": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
        })
        self.action_space = spaces.Discrete(5)

    def reset(self, *, seed=None, options=None):
        """Return the constant zero observation and an empty info dict."""
        # Forward the seed so the base class seeds ``self.np_random``; the
        # Gymnasium API expects overrides of ``reset`` to do this.
        super().reset(seed=seed)
        return self._obs(), {}

    def step(self, a):
        """Ignore the action; return zero obs, 0.0 reward, terminated=True."""
        return self._obs(), 0.0, True, False, {}

    def _obs(self):
        """Build a fresh all-zero observation matching ``observation_space``."""
        return {
            "image": np.zeros((84, 84, 3), dtype=np.uint8),
            "ultrasonic": np.zeros((1,), dtype=np.float32),
        }
| 43 | + |
| 44 | + |
def _onnx_input_shape(graph, name):
    """Return the declared shape of ONNX graph input *name* as a list of ints.

    Dimensions that carry no fixed ``dim_value`` (for example a symbolic
    batch axis) are reported as ``-1``. Fails with the list of available
    input names when *name* is not an input of *graph*.
    """
    matches = [inp for inp in graph.input if inp.name == name]
    assert matches, (
        f"ONNX graph has no input named {name!r}; "
        f"found {[inp.name for inp in graph.input]}"
    )
    return [
        dim.dim_value if dim.HasField("dim_value") else -1
        for dim in matches[0].type.tensor_type.shape.dim
    ]


@pytest.mark.parametrize("n_stack", [1, 4])
def test_onnx_export_handles_frame_stack(n_stack, tmp_path):
    """Round-trip: build a tiny (untrained) PPO on a k-frame-stacked env,
    export it to ONNX, and verify the graph inputs reflect the stack size.

    Args:
        n_stack: Number of stacked frames; ``1`` means no ``VecFrameStack``.
        tmp_path: pytest-provided temporary directory for the ONNX file.
    """
    out = tmp_path / f"stub-fs{n_stack}.onnx"

    venv = DummyVecEnv([lambda: _Stub()])
    try:
        if n_stack > 1:
            venv = VecFrameStack(venv, n_stack=n_stack, channels_order="last")
        model = PPO("MultiInputPolicy", venv, n_steps=8, batch_size=8, device="cpu")
        export_to_onnx_discrete(model, out)
    finally:
        # Always release the env, even when model construction or export fails.
        venv.close()

    assert out.exists(), "ONNX file not created"

    # Read the ONNX graph and check input shapes match k. Imported here so
    # the module can be collected even where onnx is only needed at run time.
    import onnx
    graph = onnx.load(str(out)).graph

    # Expected (batch, 84, 84, n_stack*3). The first dim is skipped because
    # the batch axis is expected to be exported as dynamic.
    image_shape = _onnx_input_shape(graph, "image")
    assert image_shape[1:] == [84, 84, n_stack * 3], (
        f"Expected ONNX image input HWC with C={n_stack * 3}; got {image_shape}"
    )

    ultra_shape = _onnx_input_shape(graph, "ultrasonic")
    # Guard the index so a rank-1 export fails with a readable message.
    assert len(ultra_shape) > 1 and ultra_shape[1] == n_stack, (
        f"Expected ONNX ultrasonic input dim 1 = {n_stack}; got {ultra_shape}"
    )
0 commit comments