Skip to content

Commit 451b7ec

Browse files
committed
Fix Vision Transformer demo for test compatibility
1 parent fa237c3 commit 451b7ec

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

computer_vision/vision_transformer.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,6 @@
99
https://huggingface.co/docs/transformers/model_doc/vit
1010
"""
1111

12-
import requests
13-
import torch
14-
from PIL import Image
15-
from transformers import ViTForImageClassification, ViTImageProcessor
16-
1712

1813
def vision_transformer_demo() -> None:
1914
"""
@@ -23,14 +18,32 @@ def vision_transformer_demo() -> None:
2318
>>> vision_transformer_demo() # doctest: +SKIP
2419
Predicted label: tabby, tabby cat
2520
"""
26-
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cat_sample.jpeg"
21+
try:
22+
import requests
23+
import torch
24+
from PIL import Image
25+
from transformers import ViTForImageClassification, ViTImageProcessor
26+
except ImportError as e:
27+
raise ImportError(
28+
"This demo requires 'torch', 'transformers', 'PIL', and 'requests' packages. "
29+
"Install them with: pip install torch transformers pillow requests"
30+
) from e
31+
32+
# Load a sample image
33+
url = (
34+
"https://huggingface.co/datasets/huggingface/documentation-images/"
35+
"resolve/main/cat_sample.jpeg"
36+
)
2737
image = Image.open(requests.get(url, stream=True, timeout=10).raw)
2838

39+
# Load pretrained model and processor
2940
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
3041
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
3142

43+
# Preprocess the image
3244
inputs = processor(images=image, return_tensors="pt")
3345

46+
# Run inference
3447
with torch.no_grad():
3548
outputs = model(**inputs)
3649
logits = outputs.logits

0 commit comments

Comments
 (0)