1010from tensorflow import get_logger
1111from tensorflow .keras .models import Model
1212from tensorflow .keras .layers .experimental .preprocessing import StringLookup
13- from tensorflow .keras .layers import Input , Reshape , Dense , Dropout , Bidirectional , LSTM
13+ from tensorflow .keras .layers import Input , Reshape , Dense , Dropout , Bidirectional , LSTM , Flatten
1414from tensorflow .keras .backend import ctc_decode
1515from mobilenetv3 import MobileNetV3_Small
1616from tensorflow .strings import reduce_join
@@ -33,13 +33,13 @@ class Config:
3333 subattr_4_coords = [67 , 636 , 560 , 676 ]
3434
3535class OCR :
36- def __init__ (self , model_weight = 'mn_model_weight.h5' , scale_ratio = 1 ):
36+ def __init__ (self , model_weight = 'mn_model_weight.h5' , scale_ratio = 1 , ocr_model_artnames = None ):
3737 self .scale_ratio = scale_ratio
3838 self .characters = sorted (
3939 [
4040 * set (
4141 "" .join (
42- sum (ArtsInfo .ArtNames , [])
42+ sum (ArtsInfo .ArtNames [: - 2 ] , [])
4343 + ArtsInfo .TypeNames
4444 + list (ArtsInfo .MainAttrNames .values ())
4545 + list (ArtsInfo .SubAttrNames .values ())
@@ -63,12 +63,14 @@ def __init__(self, model_weight='mn_model_weight.h5', scale_ratio=1):
6363 self .max_length = 15
6464 self .build_model (input_shape = (self .width , self .height ))
6565 self .model .load_weights (model_weight )
66+ self .ocr_model_artnames = ocr_model_artnames
6667
6768 def detect_info (self , art_img ):
6869 info = self .extract_art_info (art_img )
6970 x = np .concatenate ([self .preprocess (info [key ]).T [None , :, :, None ] for key in sorted (info .keys ())], axis = 0 )
7071 y = self .model .predict (x )
7172 y = self .decode (y )
73+ y [3 ] = self .ocr_model_artnames .reg (x [3 ][None ])
7274 return {** {key :v for key , v in zip (sorted (info .keys ()), y )}, ** {'star' :self .detect_star (art_img )}}
7375
7476 def extract_art_info (self , art_img ):
@@ -198,4 +200,51 @@ def build_model(self, input_shape):
198200 output = Dense (len (self .characters ) + 2 , activation = "softmax" , name = "dense2" )(x )
199201
200202 # Define the model
201- self .model = Model (inputs = [input_img ], outputs = output , name = "ocr_model_v1" )
203+ self .model = Model (inputs = [input_img ], outputs = output , name = "ocr_model_v1" )
204+
205+ class OCR_artnames :
206+ def __init__ (self , model_weight = 'mn_model_weight_artnames.h5' ):
207+ self .artnames = sorted (set (sum (ArtsInfo .ArtNames , [])))
208+
209+ self .model = self .build_model (input_shape = (240 , 16 ))
210+ self .model .load_weights (model_weight )
211+
212+ def build_model (self , input_shape ):
213+ input_img = Input (
214+ shape = (input_shape [0 ], input_shape [1 ], 1 ), name = "image" , dtype = "float32"
215+ )
216+ mobilenet = MobileNetV3_Small (
217+ (input_shape [0 ], input_shape [1 ], 1 ), 0 , alpha = 1.0 , include_top = False
218+ ).build ()
219+ x = mobilenet (input_img )
220+ new_shape = ((input_shape [0 ] // 8 ), (input_shape [1 ] // 8 ) * 576 )
221+ x = Reshape (target_shape = new_shape , name = "reshape" )(x )
222+ x = Dense (64 , activation = "relu" , name = "dense1" )(x )
223+ x = Dropout (0.2 )(x )
224+
225+ # RNNs
226+ x = Bidirectional (LSTM (128 , return_sequences = True , dropout = 0.25 ))(x )
227+ x = Bidirectional (LSTM (64 , return_sequences = True , dropout = 0.25 ))(x )
228+
229+ # Output layer
230+ x = Flatten (name = "flatten" )(x )
231+ x = Dense (
232+ len (self .artnames ), activation = "softmax" , name = "dense2"
233+ )(x )
234+
235+ output = x
236+
237+ # Define the model
238+ model = Model (inputs = [input_img ], outputs = output , name = "ocr_model_artnames" )
239+
240+ return model
241+
242+ def decode_single (self , pred ):
243+ i = pred [0 ].argmax ()
244+ if pred [0 ][i ] > 0.75 :
245+ return self .artnames [i ]
246+ else :
247+ return 'Unknown'
248+
249+ def reg (self , x ):
250+ return self .decode_single (self .model .predict (x ))
0 commit comments