1515
1616
1717class MV2Model (EagerModelBase ):
18- def __init__ (self , use_real_input = True ):
19- self .use_real_input = use_real_input
18+ def __init__ (self ):
2019 pass
2120
2221 def get_eager_model (self ) -> torch .nn .Module :
@@ -26,38 +25,7 @@ def get_eager_model(self) -> torch.nn.Module:
2625 return mv2
2726
2827 def get_example_inputs (self ):
29- tensor_size = (1 , 3 , 224 , 224 )
30- input_batch = (torch .randn (tensor_size ),)
31- if self .use_real_input :
32- logging .info ("Loaded real input image dog.jpg" )
33- import urllib
34-
35- url , filename = (
36- "https://github.com/pytorch/hub/raw/master/images/dog.jpg" ,
37- "dog.jpg" ,
38- )
39- try :
40- urllib .URLopener ().retrieve (url , filename )
41- except :
42- urllib .request .urlretrieve (url , filename )
43- from PIL import Image
44- from torchvision import transforms
45-
46- input_image = Image .open (filename )
47- preprocess = transforms .Compose (
48- [
49- transforms .Resize (256 ),
50- transforms .CenterCrop (224 ),
51- transforms .ToTensor (),
52- transforms .Normalize (
53- mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]
54- ),
55- ]
56- )
57- input_tensor = preprocess (input_image )
58- input_batch = input_tensor .unsqueeze (0 )
59- input_batch = (input_batch ,)
60- return input_batch
28+ return (torch .randn (1 , 3 , 224 , 224 ),)
6129
6230
6331class MV2UntrainedModel (EagerModelBase ):
0 commit comments