Skip to content

Commit 0af4221

Browse files
Fix mobilenet input download in ci (pytorch#19001)
Downloading the image is flakey and Im not sure its actually used for anything so removing.
1 parent 1460e29 commit 0af4221

2 files changed

Lines changed: 3 additions & 35 deletions

File tree

examples/models/mobilenet_v2/model.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515

1616

1717
class MV2Model(EagerModelBase):
18-
def __init__(self, use_real_input=True):
19-
self.use_real_input = use_real_input
18+
def __init__(self):
2019
pass
2120

2221
def get_eager_model(self) -> torch.nn.Module:
@@ -26,38 +25,7 @@ def get_eager_model(self) -> torch.nn.Module:
2625
return mv2
2726

2827
def get_example_inputs(self):
29-
tensor_size = (1, 3, 224, 224)
30-
input_batch = (torch.randn(tensor_size),)
31-
if self.use_real_input:
32-
logging.info("Loaded real input image dog.jpg")
33-
import urllib
34-
35-
url, filename = (
36-
"https://github.com/pytorch/hub/raw/master/images/dog.jpg",
37-
"dog.jpg",
38-
)
39-
try:
40-
urllib.URLopener().retrieve(url, filename)
41-
except:
42-
urllib.request.urlretrieve(url, filename)
43-
from PIL import Image
44-
from torchvision import transforms
45-
46-
input_image = Image.open(filename)
47-
preprocess = transforms.Compose(
48-
[
49-
transforms.Resize(256),
50-
transforms.CenterCrop(224),
51-
transforms.ToTensor(),
52-
transforms.Normalize(
53-
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
54-
),
55-
]
56-
)
57-
input_tensor = preprocess(input_image)
58-
input_batch = input_tensor.unsqueeze(0)
59-
input_batch = (input_batch,)
60-
return input_batch
28+
return (torch.randn(1, 3, 224, 224),)
6129

6230

6331
class MV2UntrainedModel(EagerModelBase):

examples/samsung/scripts/mobilenet_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def get_data_loader():
128128

129129
# build pte
130130
pte_filename = "mobilenetV2_enn"
131-
instance = MV2Model(False)
131+
instance = MV2Model()
132132
model = MV2Model().get_eager_model().eval()
133133
assert args.calibration_number
134134
if args.dataset:

0 commit comments

Comments
 (0)