-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhandrecognition.py
More file actions
63 lines (56 loc) · 2.67 KB
/
handrecognition.py
File metadata and controls
63 lines (56 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from cv2 import imread
import os
class HandToText():
    """
    Convert an image containing one line of handwritten text to a string.

    Wraps a Microsoft TrOCR handwritten checkpoint from the Hugging Face Hub,
    e.g. https://huggingface.co/microsoft/trocr-large-handwritten
    """

    # Maps the ``model`` argument of ``__init__`` to a Hugging Face checkpoint.
    # "turbo" selects the small checkpoint. Names not listed here fall back to
    # the large checkpoint (same as the original if/elif/else chain).
    _CHECKPOINTS = {
        "turbo": "microsoft/trocr-small-handwritten",
        "base": "microsoft/trocr-base-handwritten",
        "large": "microsoft/trocr-large-handwritten",
    }

    def __init__(self, model="large") -> None:
        """
        Load the TrOCR processor and model weights.

        :param model: "turbo" (small checkpoint), "base", or "large".
            Any other value falls back to "large", which the author found to
            perform best; the smaller models do not perform as well.
        """
        print("Loading Model...")
        checkpoint = self._CHECKPOINTS.get(model, self._CHECKPOINTS["large"])
        self.processor = TrOCRProcessor.from_pretrained(checkpoint)
        self.model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

    def predict(self, image_path):
        """
        Return the text recognised in the image at ``image_path``.

        :param image_path: path to an image file readable by OpenCV.
        :returns: the decoded text of the first (only) generated sequence.
        :raises ValueError: if OpenCV cannot read the file as an image.
        """
        image = imread(image_path)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None
            # instead of raising; fail here with a message that names the file.
            raise ValueError(f"Could not read image: {image_path}")
        # imread (default IMREAD_COLOR) yields an H x W x 3 array in BGR order,
        # whereas the TrOCR processor expects RGB. Reverse the channel axis;
        # .copy() makes the reversed view a contiguous array again.
        image = image[:, :, ::-1].copy()
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
        generated_ids = self.model.generate(pixel_values)
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def predict_folder(self, Path_to_folder):
        """
        Predict the text of all images inside a folder.

        The folder must contain three sub-folders:
            Locations
            Plants
            Dates
        Each must hold the respective images, aligned so that the files at the
        same sorted position across the three folders belong together.
        Hidden files (names starting with ".") and sub-directories are skipped.

        :param Path_to_folder: path of the folder containing the three sub-folders.
        :returns: tuple ``(text_plants, text_locations, text_dates)``. Note that
            this order differs from the order in which folders are processed.
        """
        text_locations = self._predict_subfolder(Path_to_folder, "Locations")
        text_plants = self._predict_subfolder(Path_to_folder, "Plants")
        text_dates = self._predict_subfolder(Path_to_folder, "Dates")
        return text_plants, text_locations, text_dates

    def _predict_subfolder(self, root, name):
        """Return predictions for every image file in ``root/name``, in sorted filename order."""
        folder = os.path.join(root, name)
        texts = []
        for f in sorted(os.listdir(folder)):
            image_path = os.path.join(folder, f)
            # Skip entries that cannot be images (e.g. macOS .DS_Store, nested
            # directories); feeding them to the model would crash or shift the
            # alignment between the three folders.
            if f.startswith(".") or not os.path.isfile(image_path):
                continue
            print("Looking at: ", f)
            texts.append(self.predict(image_path))
        return texts