1919
2020from modules import shared
2121from tagger import settings # pylint: disable=import-error
22- from tagger .uiset import QData , IOData # pylint: disable=import-error
22+ from tagger .uiset import QData , IOData , supported_extensions , \
23+ format_output_filename
2324from preload import root_dir
2425from . import dbimutils # pylint: disable=import-error # noqa
2526
@@ -487,33 +488,37 @@ def load_model(self, model_path) -> None:
487488
488489class FromFileInterrogator (Interrogator ):
489490 """ Pseudo Interrogator reading preinterrogated tags files """
490- def __init__ (self , name : str , path : os .PathLike , val = 1.0 ) -> None :
491+ def __init__ (
492+ self , name : str , path , format = '[name].[output_extension]' , value = 1.0
493+ ) -> None :
491494 super ().__init__ (name )
492- self .path = path
493- self .val = val
494- self .tags = None
495+ self .path = Path (self .path )
496+ self .val = value
497+ self .format = format
498+ if format == Its .output_filename_format and path == '' :
499+ raise ValueError (f"tagsfiles will ne overwritten with { format } " )
500+ self .tags = {}
495501
496502 def load (self ) -> None :
497503 print (f'Loading { self .name } from { str (self .path )} ' )
504+ if self .path == '' :
505+ return
506+
498507 # self.path is a directory
499- if not os .path .isdir (self .path ):
508+ if not os .path .isdir (Path ( self .path ) ):
500509 raise ValueError (f'{ self .path } is not a directory' )
501- else :
502- self .tags = {}
503- for f in os .listdir (self .path ):
504- self .tags [f ] = {}
505- self .load_file (f )
506510
507511 def load_file (self , tags_file : str ) -> None :
508- image_name = str (tags_file ).split ('/' )[- 1 ].split ('.' )[0 ]
512+ basename = '.' .join (str (tags_file ).split ('/' )[- 1 ].split ('.' )[:- 1 ])
513+ self .tags [basename ] = {}
509514 with open (tags_file , 'r' ) as f :
510515 for line in f :
511- for x in map (str .split , line .split (',' )):
516+ for x in map (str .strip , line .split (',' )):
512517 if x [0 ] == '(' and x [- 1 ] == ')' and ':' in x :
513518 tag , val = x [1 :- 1 ].split (':' )
514- self .tags [image_name ][tag ] = float (val )
519+ self .tags [basename ][tag ] = float (val )
515520 else :
516- self .tags [image_name ][x ] = self .val
521+ self .tags [basename ][x ] = self .val
517522
518523 def unload (self ) -> None :
519524 self .tags = {}
@@ -525,7 +530,16 @@ def interrogate(
525530 Dict [str , float ], # rating confidences
526531 Dict [str , float ] # tag confidences
527532 ]:
528- return {}, self .tags [image .filename ]
533+ basename = '.' .join (image .filename .split ('/' )[- 1 ].split ('.' )[:- 1 ])
534+ path = Path (image .filename )
535+ tags_filename = format_output_filename (path , self .format )
536+ if self .path == '' :
537+ dir = path .parent
538+ else :
539+ dir = self .path
540+ tags_filename = os .path .join (dir , tags_filename )
541+ self .load_file (tags_filename )
542+ return {}, self .tags [basename ]
529543
530544
531545class WaifuDiffusionInterrogator (HFInterrogator ):
0 commit comments