Skip to content

Commit fb29757

Browse files
committed
Add Vision Transformer demo for image classification
1 parent 1ba4fe1 commit fb29757

File tree

1 file changed

+14
-21
lines changed

1 file changed

+14
-21
lines changed

computer_vision/vision_transformer.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,15 @@
44
55
Classify images using a pretrained Vision Transformer (ViT)
66
from Hugging Face Transformers.
7-
8-
Can be used as a demo or imported in other scripts.
9-
10-
Source:
11-
https://huggingface.co/docs/transformers/model_doc/vit
127
"""
138

14-
try:
15-
import requests
16-
import torch
17-
from io import BytesIO
18-
from PIL import Image
19-
from transformers import ViTForImageClassification, ViTImageProcessor
20-
except ImportError as e:
21-
raise ImportError(
22-
"This module requires 'torch', 'transformers', 'PIL', and 'requests'. "
23-
"Install them with: pip install torch transformers pillow requests"
24-
) from e
9+
from io import BytesIO
10+
from typing import Optional
11+
12+
import requests
13+
import torch
14+
from PIL import Image, UnidentifiedImageError
15+
from transformers import ViTForImageClassification, ViTImageProcessor
2516

2617

2718
def classify_image(image: Image.Image) -> str:
@@ -39,21 +30,23 @@ def classify_image(image: Image.Image) -> str:
3930
return model.config.id2label[predicted_class_idx]
4031

4132

42-
def demo(url: str = None) -> None:
33+
def demo(url: Optional[str] = None) -> None:
4334
"""
4435
Run a demo using a sample image or provided URL.
45-
36+
4637
Args:
47-
url (str): URL of the image. If None, uses a default cat image.
38+
url (Optional[str]): URL of the image. If None, uses default cat image.
4839
"""
4940
if url is None:
50-
url = "https://images.unsplash.com/photo-1592194996308-7b43878e84a6" # default example image
41+
url = (
42+
"https://images.unsplash.com/photo-1592194996308-7b43878e84a6"
43+
) # default example image
5144

5245
try:
5346
response = requests.get(url, timeout=10)
5447
response.raise_for_status()
5548
image = Image.open(BytesIO(response.content))
56-
except Exception as e:
49+
except (requests.RequestException, UnidentifiedImageError) as e:
5750
print(f"Failed to load image from {url}. Error: {e}")
5851
return
5952

0 commit comments

Comments
 (0)