Skip to content

Commit d998c33

Browse files
ammended the example (#12)
1 parent 48dbbbe commit d998c33

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

Quick_Deploy/PyTorch/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ This README showcases how to deploy a simple ResNet model on Triton Inference Se
3333

3434
## Step 1: Export the model
3535

36-
Save the PyTorch model.
36+
Save the PyTorch model. This model needs to be traced/scripted to obtain a torchscript model.
3737

3838
```
3939
# <xx.xx> is the yy:mm for the publishing tag for NVIDIA's PyTorch

Quick_Deploy/PyTorch/export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
import torch
28-
import torch_tensorrt
2928
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
3029

3130
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")
32-
torch.save(model, "model.pt")
31+
traced_model = torch.jit.trace(model, torch.randn(1,3,224,224).to("cuda"))
32+
torch.jit.save(traced_model, "model.pt")

0 commit comments

Comments
 (0)