Skip to content

Commit 5bc09f7

Browse files
committed
test(onnx): smoke test ONNX export with frame-stacking (k=1, k=4)
Verifies the obs_space auto-detect produces an ONNX graph with the right image-channel count for both k=1 (3 channels) and k=4 (12 channels). Without this, a regression in train_v9's export would silently produce ONNX models that can't accept k=4 stacked input on a real KS0223 — the deployment use case.
1 parent 1758e02 commit 5bc09f7

1 file changed

Lines changed: 81 additions & 0 deletions

File tree

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Smoke test for the ONNX export with frame-stacking (k>1).
2+
3+
Verifies that `export_to_onnx_discrete` reads channel count from the
4+
model's saved obs_space (not hard-coded to 3), so a k=4 model exports
5+
to ONNX with the correct (1, 84, 84, 12) image input shape.
6+
"""
7+
from __future__ import annotations
8+
9+
from pathlib import Path
10+
11+
import numpy as np
12+
import gymnasium as gym
13+
import pytest
14+
from gymnasium import spaces
15+
from stable_baselines3 import PPO
16+
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
17+
18+
from training.train_cardboard_corridor_v9 import export_to_onnx_discrete
19+
20+
21+
class _Stub(gym.Env):
    """Minimal Dict-observation env mirroring the cardboard-corridor layout.

    Observations carry an all-zero 84x84x3 uint8 ``image`` and an all-zero
    one-element float32 ``ultrasonic`` reading. There are 5 discrete actions
    and every episode terminates after a single step.
    """

    metadata = {"render_modes": []}

    def __init__(self):
        image_space = spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)
        ultrasonic_space = spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Dict(
            {"image": image_space, "ultrasonic": ultrasonic_space}
        )
        self.action_space = spaces.Discrete(5)

    def reset(self, *, seed=None, options=None):
        """Return a blank observation and an empty info dict.

        ``seed`` and ``options`` are accepted but not used.
        """
        observation = self._obs()
        return observation, {}

    def step(self, a):
        """Ignore the action and end the episode immediately with zero reward."""
        # Tuple layout: (obs, reward, terminated, truncated, info).
        return self._obs(), 0.0, True, False, {}

    def _obs(self):
        """Build a fresh all-zero observation dict."""
        image = np.zeros((84, 84, 3), dtype=np.uint8)
        ultrasonic = np.zeros((1,), dtype=np.float32)
        return {"image": image, "ultrasonic": ultrasonic}
43+
44+
45+
def _onnx_input_shape(graph, name):
    """Return the declared shape of ONNX graph input *name* as a list of ints.

    Dimensions that carry no concrete ``dim_value`` (e.g. symbolic/dynamic
    axes) are reported as -1. Fails with a readable assertion, listing the
    inputs that do exist, when *name* is not a graph input.
    """
    matches = [inp for inp in graph.input if inp.name == name]
    assert matches, (
        f"ONNX graph has no input named {name!r}; "
        f"inputs are {[inp.name for inp in graph.input]}"
    )
    return [
        dim.dim_value if dim.HasField("dim_value") else -1
        for dim in matches[0].type.tensor_type.shape.dim
    ]


@pytest.mark.parametrize("n_stack", [1, 4])
def test_onnx_export_handles_frame_stack(n_stack, tmp_path):
    """Round-trip: build a tiny PPO on a (frame-stacked) stub env, export it
    to ONNX, and verify the graph inputs reflect the stack size.

    For ``n_stack`` frames the ``image`` input must have ``n_stack * 3``
    channels (HWC layout) and the ``ultrasonic`` input must have ``n_stack``
    entries along dim 1.
    """
    venv = DummyVecEnv([lambda: _Stub()])
    if n_stack > 1:
        venv = VecFrameStack(venv, n_stack=n_stack, channels_order="last")
    out = tmp_path / f"stub-fs{n_stack}.onnx"

    # Always release the vec env, even when model construction or the
    # export itself raises.
    try:
        model = PPO("MultiInputPolicy", venv, n_steps=8, batch_size=8, device="cpu")
        export_to_onnx_discrete(model, out)
    finally:
        venv.close()

    assert out.exists(), "ONNX file not created"

    # Read the ONNX graph and check input shapes match k.
    import onnx
    graph = onnx.load(str(out)).graph

    # Expected (batch, 84, 84, n_stack*3). The leading batch dim is skipped:
    # a dim without a concrete dim_value is reported as -1 by the helper
    # (the export is presumably using a dynamic batch axis -- confirm in
    # export_to_onnx_discrete).
    image_shape = _onnx_input_shape(graph, "image")
    assert image_shape[1:] == [84, 84, n_stack * 3], (
        f"Expected ONNX image input HWC with C={n_stack * 3}; got {image_shape}"
    )

    ultra_shape = _onnx_input_shape(graph, "ultrasonic")
    # Guard the index so a rank-1 input fails with the message below rather
    # than an IndexError.
    assert len(ultra_shape) >= 2 and ultra_shape[1] == n_stack, (
        f"Expected ONNX ultrasonic input dim 1 = {n_stack}; got {ultra_shape}"
    )

0 commit comments

Comments
 (0)