[HIP/ROCm] BF16 inference fails: layer_norm kernel not registered for bfloat16 + conv2d fuse passes crash #78759

@oldzhu

Problem

When running PaddleOCR-VL-1.5 in BF16 mode on AMD ROCm (gfx1100 / ROCm 7.2.0), two errors occur:

Error 1: conv2d fuse passes crash on ROCm

Cannot find the kernel for [FusedConv2dAddAct] on GPU with float32

The conv2d_add_act_fuse_pass and conv2d_add_fuse_pass passes generate FusedConv2dAddActOp nodes, but this op is only compiled when PADDLE_WITH_CUDA is defined, not in HIP builds. On ROCm the rewritten graph therefore has no kernel to dispatch to. PaddleX worked around this by calling config.delete_pass() to remove these passes in every inference session.
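The PaddleX workaround amounts to removing both fuse passes from the inference config before the predictor is built. The following is a minimal sketch, assuming the Paddle Inference Python API (`paddle.inference.Config.delete_pass`). The helper name `drop_hip_unsupported_fuse_passes` is hypothetical, and this is not PaddleX's actual code.

```python
# Fuse passes that emit FusedConv2dAddActOp, which is only compiled for CUDA.
HIP_UNSUPPORTED_FUSE_PASSES = ("conv2d_add_act_fuse_pass", "conv2d_add_fuse_pass")


def drop_hip_unsupported_fuse_passes(config, is_rocm_build):
    """Hypothetical helper: delete the conv2d fuse passes on ROCm builds.

    `config` is expected to be a paddle.inference.Config; any object with a
    delete_pass(name) method works. Returns the same config for chaining.
    """
    if is_rocm_build:
        for pass_name in HIP_UNSUPPORTED_FUSE_PASSES:
            config.delete_pass(pass_name)
    return config
```

In a session, this would be called as `drop_hip_unsupported_fuse_passes(config, paddle.is_compiled_with_rocm())` before `paddle.inference.create_predictor(config)`. The workaround has to be repeated in every session, which is why the fix moves the guard into the passes themselves.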

Error 2: layer_norm kernel not registered for bfloat16 on HIP

NotFound: Cannot find the kernel for [layer_norm] on GPU with bfloat16.

The HIP PD_REGISTER_KERNEL block in paddle/phi/kernels/gpu/layer_norm_kernel.cu registers only float and float16. The LayerNormKernel implementation uses templated, CUDA-compatible intrinsics that compile and run correctly on ROCm. bfloat16 was simply never added to the HIP registration.
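The failure is a registration gap, not a missing implementation. Dispatch looks kernels up by key, so a dtype that was never registered fails even when the templated code would handle it. The toy registry below is a simplified illustration of that idea, not Paddle's phi code. The real phi kernel key also carries a data layout, and registration happens in C++ via PD_REGISTER_KERNEL. All class and function names here are hypothetical.

```python
class KernelNotFound(LookupError):
    """Raised when no kernel is registered for a (op, backend, dtype) key."""


class ToyKernelRegistry:
    """Toy kernel table keyed by (op, backend, dtype); not Paddle's implementation."""

    def __init__(self):
        self._kernels = {}

    def register(self, op, backend, dtypes, fn):
        # Like registering one templated kernel for an explicit list of dtypes:
        # only the listed dtypes get an entry.
        for dtype in dtypes:
            self._kernels[(op, backend, dtype)] = fn

    def has(self, op, backend, dtype):
        return (op, backend, dtype) in self._kernels

    def lookup(self, op, backend, dtype):
        try:
            return self._kernels[(op, backend, dtype)]
        except KeyError:
            raise KernelNotFound(
                f"Cannot find the kernel for [{op}] on {backend} with {dtype}."
            ) from None


def layer_norm_impl(*args):
    """Stand-in for the templated LayerNormKernel."""
    return "ran layer_norm"


# HIP build before the fix: bfloat16 is missing from the dtype list.
before = ToyKernelRegistry()
before.register("layer_norm", "GPU", ["float32", "float16"], layer_norm_impl)

# After the fix: the same implementation, registered for one more dtype.
after = ToyKernelRegistry()
after.register("layer_norm", "GPU", ["float32", "float16", "bfloat16"], layer_norm_impl)
```

The only difference between `before` and `after` is the dtype list. This matches the one-line nature of the layer_norm fix.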

Reproduction

Hardware: AMD Radeon RX 7900 GRE (gfx1100), ROCm 7.2.0, Python 3.12

```python
import paddle
import paddle.nn as nn

# Triggers Error 2 (no bfloat16 layer_norm kernel registered for HIP)
x = paddle.randn([4, 64]).cast(paddle.bfloat16).cuda()
ln = nn.LayerNorm(64).to(dtype=paddle.bfloat16)
out = ln(x)  # NotFound: Cannot find the kernel for [layer_norm] on GPU with bfloat16.
```

Error 1 is triggered by any PaddleX static-inference session with dtype='bfloat16' whose model contains a Conv2d layer, unless the two fuse passes are deleted from the config.

Impact

PaddleOCR-VL-1.5 cannot run in BF16 on AMD ROCm without workarounds, because its SigLIP visual encoder runs conv and layer_norm in BF16.

Fix

PR submitted: see linked PR.

  • conv2d_add_act_fuse_pass.cc / conv2d_add_fuse_pass.cc: add a #ifdef PADDLE_WITH_HIP guard in InitializePatterns() that returns an empty pattern set, so both passes become no-ops on HIP builds
  • layer_norm_kernel.cu: add phi::bfloat16 to the dtype list of the HIP PD_REGISTER_KERNEL block

Validation

After the fix, on gfx1100 / ROCm 7.2.0:

  • LayerNorm BF16 SNR vs FP32: 44 dB (excellent)
  • PaddleOCR-VL-1.5 full BF16 pipeline: load 14.6 s, inference 202.8 s, exit code 0
  • OCR output correct
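The issue does not say exactly how the 44 dB figure was measured. One common definition is SNR = 10 * log10(sum(ref^2) / sum((ref - test)^2)). The sketch below applies that definition to a pure-Python emulation of BF16 LayerNorm. It makes the following assumptions:

* Inputs and outputs are rounded to bfloat16 with round-to-nearest-even.
* Arithmetic in between is done in double precision.
* There is no affine weight or bias.
* The reference is computed in double precision rather than FP32.

The shape [4, 64] mirrors the reproduction snippet. The sketch illustrates the metric only. It does not reproduce the ROCm kernel or the reported number.

```python
import math
import random
import struct


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (round-to-nearest-even).

    The value is first rounded to float32, then the low 16 bits of its bit
    pattern are rounded away. NaN/Inf inputs are not handled specially.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def layer_norm(row, eps=1e-5):
    """Reference LayerNorm over one row (no weight/bias), in double precision."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std for v in row]


def layer_norm_bf16(row, eps=1e-5):
    """Emulated BF16 LayerNorm: bf16 input and output, wider math in between."""
    row_bf16 = [to_bf16(v) for v in row]
    return [to_bf16(v) for v in layer_norm(row_bf16, eps)]


def snr_db(reference, test):
    """Signal-to-noise ratio in dB of `test` against `reference`."""
    signal = sum(r * r for r in reference)
    noise = sum((r - t) ** 2 for r, t in zip(reference, test))
    if noise == 0.0:
        return math.inf
    return 10.0 * math.log10(signal / noise)


rng = random.Random(0)
rows = [[rng.gauss(0.0, 1.0) for _ in range(64)] for _ in range(4)]
ref = [v for row in rows for v in layer_norm(row)]
bf16 = [v for row in rows for v in layer_norm_bf16(row)]
print(f"LayerNorm BF16 vs reference SNR: {snr_db(ref, bf16):.1f} dB")
```

bfloat16 keeps only 8 significant bits, so even a perfect kernel shows a finite SNR. A kernel bug would typically push the value far below what rounding alone produces.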
