Skip to content

train_pytorch.py: wandb image logging fails with "Un-supported shape [224, 9, 224]" when images are NHWC #877

@leo038

Description

@leo038
## 环境
- 脚本: `scripts/train_pytorch.py`
- 运行方式: `torchrun --standalone --nnodes=1 --nproc_per_node=4 scripts/train_pytorch.py <config> --exp_name <name>`
- 触发位置: 首次 batch 时向 wandb 记录 sample images 的代码块

## 现象
```text
ValueError: Un-supported shape for image conversion [224, 9, 224]

发生在 wandb.Image(img_concatenated)

原因

  • 统一 dataloader 在 PyTorch 下返回的 observation 图像是 NHWC [B, H, W, C](与 JAX/LeRobot 一致)。
  • 当前实现假定图像为 NCHW,对 img[i](实际为 [H, W, C])做了 permute(1, 2, 0),得到 [224, 3, 224],再沿 dim=1 拼接 3 个视角得到 [224, 9, 224],wandb 无法识别。

建议修复

  1. 根据格式判断:若 frame.shape[0] == 3 视为 NCHW,做 permute(1, 2, 0) 转为 NHWC;否则按已是 NHWC 处理。
  2. 拼接时在 NHWC 的 width 维度上 cat,得到 [H, W_total, 3] 再传给 wandb.Image
  3. 若像素值在 [-1, 1],先线性缩放到 [0, 1] 再传入,避免 wandb 的数值范围警告。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions