diff --git a/thop/vision/onnx_counter.py b/thop/vision/onnx_counter.py index 8beb39a..d36da69 100644 --- a/thop/vision/onnx_counter.py +++ b/thop/vision/onnx_counter.py @@ -304,7 +304,7 @@ def onnx_counter_globalaveragepool(diction, node): macs = calculate_zero_ops() output_name = node.output[0] input_size = diction[node.input[0]] - output_size = input_size + output_size = np.append(input_size[:2], np.ones(len(input_size) - 2, input_size.dtype)) return macs, output_size, output_name