Commit c575973

[Relax][Onnx][BatchNorm] Pass momentum and training_mode into BatchNorm Operator (#18704)
### Description
- The ONNX model has the attribute `training_mode = False`, but the Relax model produced by the converter has `training=True`.
- The momentum value in the Relax module does not match the one in the ONNX model.

### Steps to Reproduce
<img width="600" height="400" alt="BatchNorm" src="https://github.com/user-attachments/assets/2f0ca26b-e83b-4ab8-ab06-a537802af6de" />

- Relax model:
```
class Module:
    def main(X: R.Tensor((2, 3, 4, 4), dtype="float32")) -> R.Tensor((2, 3, 4, 4), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tuple(R.Tensor((2, 3, 4, 4), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.nn.batch_norm(X, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], axis=1, epsilon=9.9999997473787516e-06, center=True, scale=True, momentum=0.10000000000000001, training=True)
            lv1: R.Tensor((2, 3, 4, 4), dtype="float32") = lv[0]
            lv2: R.Tensor((3,), dtype="float32") = lv[1]
            lv3: R.Tensor((3,), dtype="float32") = lv[2]
            gv: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1
            R.output(gv)
        return gv
```

### Resolved
- Read the `momentum` and `training_mode` attributes (falling back to default values) and pass them to the BatchNorm operator.
- Fixed: #18703
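To illustrate why the mismatched `training` flag matters, here is a minimal pure-Python sketch of batch normalization over a single channel. It is not TVM or ONNX code: `batch_norm_1d` is a hypothetical helper, and the inputs are made up for illustration. It assumes the usual convention that inference mode normalizes with the stored moving statistics, while training mode normalizes with the statistics of the current batch.

```python
import math


def batch_norm_1d(xs, gamma, beta, moving_mean, moving_var, eps=1e-5, training=False):
    """Normalize one channel of values (hypothetical helper, illustration only)."""
    if training:
        # Training mode: use the mean and population variance of the current batch.
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)
    else:
        # Inference mode: use the stored moving statistics.
        mean, var = moving_mean, moving_var
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in xs]


xs = [1.0, 2.0, 3.0, 4.0]
inference_out = batch_norm_1d(xs, 1.0, 0.0, moving_mean=0.0, moving_var=1.0, training=False)
training_out = batch_norm_1d(xs, 1.0, 0.0, moving_mean=0.0, moving_var=1.0, training=True)

# The same input yields different outputs depending on the flag.
print(inference_out)
print(training_out)
```

With these inputs the inference result stays close to the original values, while the training result is centered on zero, so a converted model that silently switches the flag can produce different outputs from the source model.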
1 parent 4b2b639 commit c575973

1 file changed

Lines changed: 11 additions & 1 deletion

python/tvm/relax/frontend/onnx/onnx_frontend.py

@@ -2435,8 +2435,18 @@ def _impl_v15(cls, bb, inputs, attr, params):
         mean = inputs[3]
         var = inputs[4]
         epsilon = attr.get("epsilon", 1e-05)
+        momentum = attr.get("momentum", 0.9)
+        training_mode = attr.get("training_mode", 0)
         return relax.op.nn.batch_norm(
-            data, gamma=scale, beta=bias, moving_mean=mean, moving_var=var, epsilon=epsilon, axis=1
+            data,
+            gamma=scale,
+            beta=bias,
+            moving_mean=mean,
+            moving_var=var,
+            axis=1,
+            epsilon=epsilon,
+            momentum=momentum,
+            training=training_mode,
         )
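The attribute handling in this change can be sketched in isolation as follows. This is a standalone illustration rather than the frontend code: `read_batchnorm_attrs` is a hypothetical helper, and `attr` is assumed to behave like a plain dictionary of ONNX node attributes. The fallback values mirror the ones used in the diff. The `bool(...)` cast is an addition made in this sketch only; the commit passes the integer `training_mode` through unchanged.

```python
def read_batchnorm_attrs(attr):
    """Collect BatchNormalization attributes with fallbacks (hypothetical helper)."""
    return {
        "epsilon": attr.get("epsilon", 1e-05),
        "momentum": attr.get("momentum", 0.9),
        # Assumption of this sketch: convert the integer flag to a Python bool.
        "training": bool(attr.get("training_mode", 0)),
    }


# A node that sets no attributes falls back to inference mode and momentum 0.9.
print(read_batchnorm_attrs({}))

# Explicit attributes override the fallbacks.
print(read_batchnorm_attrs({"momentum": 0.5, "training_mode": 1}))
```

Reading every attribute with an explicit fallback keeps the converted operator aligned with the source model even when the node omits optional attributes.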
