Skip to content

Commit 63d4d9e

Browse files
authored
Use torch 2.10 to export ONNX, and slightly refine the TensorRT 10 code (#349)
1 parent 3d9b2cf commit 63d4d9e

4 files changed

Lines changed: 28 additions & 46 deletions

File tree

tensorrt/plugins/kernels.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ int8_t shfl_down_sync_func(int8_t val, uint32_t delta) {
4545

4646
template<typename scalar_t>
4747
__forceinline__ __device__
48-
scalar_t max_pair_shfl_func(scalar_t& val, int32_t& ind, const uint32_t delta) {
48+
void max_pair_shfl_func(scalar_t& val, int32_t& ind, const uint32_t delta) {
4949
scalar_t other_v = shfl_down_sync_func(val, delta);
5050
int32_t other_i = shfl_down_sync_func(ind, delta);
5151

tensorrt/segment.cu

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,6 @@
1919
#include "read_img.hpp"
2020

2121

22-
// using nvinfer1::IHostMemory;
23-
// using nvinfer1::IBuilder;
24-
// using nvinfer1::INetworkDefinition;
25-
// using nvinfer1::ICudaEngine;
26-
// using nvinfer1::IInt8Calibrator;
27-
// using nvinfer1::IBuilderConfig;
28-
// using nvinfer1::IRuntime;
29-
// using nvinfer1::IExecutionContext;
30-
// using nvinfer1::ILogger;
31-
// using nvinfer1::Dims;
32-
// using Severity = nvinfer1::ILogger::Severity;
3322

3423
using std::string;
3524
using std::ios;

tensorrt/trt_dep.cu

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -102,42 +102,33 @@ void SemanticSegmentTrt::parse_to_engine(string onnx_pth,
102102
config->addOptimizationProfile(profile);
103103

104104
config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 1UL << 32);
105-
106-
if (quant == "fp16" or quant == "int8") { // fp16
107-
if (builder->platformHasFastFp16() == false) {
108-
cout << "fp16 is set, but platform does not support, so we ignore this\n";
109-
} else {
110-
config->setFlag(nvinfer1::BuilderFlag::kFP16);
111-
}
112-
}
113-
if (quant == "bf16") { // bf16
105+
config->setBuilderOptimizationLevel(5);
106+
107+
if (quant == "fp16") { // fp16
108+
config->setFlag(nvinfer1::BuilderFlag::kFP16);
109+
} else if (quant == "int8") { // int8
110+
config->setFlag(nvinfer1::BuilderFlag::kFP16);
111+
config->setFlag(nvinfer1::BuilderFlag::kINT8);
112+
} else if (quant == "bf16") { // bf16
114113
config->setFlag(nvinfer1::BuilderFlag::kBF16);
115-
}
116-
if (quant == "fp8") { // fp8
114+
} else if (quant == "fp8") { // fp8
117115
config->setFlag(nvinfer1::BuilderFlag::kFP8);
118116
}
119117

120118
std::unique_ptr<IInt8Calibrator> calibrator;
121119
if (quant == "int8") { // int8
122-
if (builder->platformHasFastInt8() == false) {
123-
cout << "int8 is set, but platform does not support, so we ignore this\n";
124-
} else {
125-
126-
int batchsize = 32;
127-
int n_cal_batches = -1;
128-
string cal_table_name = "calibrate_int8";
129-
130-
Dims indim = network->getInput(0)->getDimensions();
131-
BatchStream calibrationStream(
132-
batchsize, n_cal_batches, indim,
133-
data_root, data_file);
134-
135-
config->setFlag(nvinfer1::BuilderFlag::kINT8);
136-
137-
calibrator.reset(new Int8EntropyCalibrator2<BatchStream>(
138-
calibrationStream, 0, cal_table_name.c_str(), input_name.c_str(), false));
139-
config->setInt8Calibrator(calibrator.get());
140-
}
120+
int batchsize = 32;
121+
int n_cal_batches = -1;
122+
string cal_table_name = "calibrate_int8";
123+
124+
Dims indim = network->getInput(0)->getDimensions();
125+
BatchStream calibrationStream(
126+
batchsize, n_cal_batches, indim,
127+
data_root, data_file);
128+
129+
calibrator.reset(new Int8EntropyCalibrator2<BatchStream>(
130+
calibrationStream, 0, cal_table_name.c_str(), input_name.c_str(), false));
131+
config->setInt8Calibrator(calibrator.get());
141132
}
142133

143134
// output->setType(nvinfer1::DataType::kINT32);

tools/export_onnx.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
sys.path.insert(0, '.')
55

66
import torch
7-
from torch.onnx import OperatorExportTypes
7+
from torch.export import Dim
88

99
from lib.models import model_factory
1010
from configs import set_cfg_from_file
@@ -37,11 +37,13 @@
3737
# dummy_input = torch.randn(1, 3, 1024, 2048)
3838
input_names = ['input_image']
3939
output_names = ['preds',]
40-
dynamic_axes = {'input_image': {0: 'batch'}, 'preds': {0: 'batch'}}
40+
batch_size = Dim('batch', min=1, max=128)
41+
dynamic_shapes = ({0: batch_size},)
4142

4243
torch.onnx.export(net, dummy_input, args.out_pth,
4344
input_names=input_names, output_names=output_names,
45+
do_constant_folding=True,
4446
verbose=False, opset_version=18,
45-
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
46-
dynamic_axes=dynamic_axes)
47+
external_data=False,
48+
dynamic_shapes=dynamic_shapes)
4749

0 commit comments

Comments
 (0)