Skip to content

Commit dce1126

Browse files
committed
Validate oneDNN scalar reduction input
1 parent 1b1b683 commit dce1126

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/layers_oneDNN/ReduceLayer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ void ReduceLayerOneDnn::validate_input(const std::vector<Tensor>& input) {
164164
}
165165

166166
const auto& shape = input[0].get_shape();
167-
if (shape.dims() == 0) {
167+
if (shape.dims() == 0 || (shape.dims() == 1 && shape.count() == 1)) {
168168
throw std::runtime_error("ReduceLayerOneDnn: Scalar input not supported");
169169
}
170170
}

0 commit comments

Comments
 (0)