Skip to content

Commit 486edca

Browse files
author
Roel Kluin
committed
fix fromfileinterrogator
1 parent 6d68c96 commit 486edca

4 files changed

Lines changed: 48 additions & 26 deletions

File tree

default/interrogators.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@
4949
"zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip"
5050
}
5151
},
52+
"FromFileInterrogator": {
53+
"[name].[hash:sha1].[output_extension]": {
54+
"format": "[name].[hash:sha1].[output_extension]",
55+
"path": "",
56+
"value": 1.0
57+
}
58+
},
5259
"MLDanbooruInterrogator": {
5360
"mld-caformer.dec-5-97527" : {
5461
"model_path" : "ml_caformer_m36_dec-5-97527.onnx",

json_schema/interrogators_v1_schema.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
"type": "object",
4141
"properties": {
4242
"path": { "type": "string" },
43-
"val": { "type": "number", "default": 1.0 }
43+
"format": { "type": "string", "default": "[name].[output_extension]" },
44+
"value": { "type": "number", "default": 1.0 }
4445
},
4546
"required": [
4647
"path"

tagger/interrogator.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
from modules import shared
2121
from tagger import settings # pylint: disable=import-error
22-
from tagger.uiset import QData, IOData # pylint: disable=import-error
22+
from tagger.uiset import QData, IOData, supported_extensions, \
23+
format_output_filename
2324
from preload import root_dir
2425
from . import dbimutils # pylint: disable=import-error # noqa
2526

@@ -487,33 +488,37 @@ def load_model(self, model_path) -> None:
487488

488489
class FromFileInterrogator(Interrogator):
489490
""" Pseudo Interrogator reading preinterrogated tags files """
490-
def __init__(self, name: str, path: os.PathLike, val=1.0) -> None:
491+
def __init__(
492+
self, name: str, path, format='[name].[output_extension]', value=1.0
493+
) -> None:
491494
super().__init__(name)
492-
self.path = path
493-
self.val = val
494-
self.tags = None
495+
self.path = Path(self.path)
496+
self.val = value
497+
self.format = format
498+
if format == Its.output_filename_format and path == '':
499+
raise ValueError(f"tagsfiles will ne overwritten with {format}")
500+
self.tags = {}
495501

496502
def load(self) -> None:
497503
print(f'Loading {self.name} from {str(self.path)}')
504+
if self.path == '':
505+
return
506+
498507
# self.path is a directory
499-
if not os.path.isdir(self.path):
508+
if not os.path.isdir(Path(self.path)):
500509
raise ValueError(f'{self.path} is not a directory')
501-
else:
502-
self.tags = {}
503-
for f in os.listdir(self.path):
504-
self.tags[f] = {}
505-
self.load_file(f)
506510

507511
def load_file(self, tags_file: str) -> None:
508-
image_name = str(tags_file).split('/')[-1].split('.')[0]
512+
basename = '.'.join(str(tags_file).split('/')[-1].split('.')[:-1])
513+
self.tags[basename] = {}
509514
with open(tags_file, 'r') as f:
510515
for line in f:
511-
for x in map(str.split, line.split(',')):
516+
for x in map(str.strip, line.split(',')):
512517
if x[0] == '(' and x[-1] == ')' and ':' in x:
513518
tag, val = x[1:-1].split(':')
514-
self.tags[image_name][tag] = float(val)
519+
self.tags[basename][tag] = float(val)
515520
else:
516-
self.tags[image_name][x] = self.val
521+
self.tags[basename][x] = self.val
517522

518523
def unload(self) -> None:
519524
self.tags = {}
@@ -525,7 +530,16 @@ def interrogate(
525530
Dict[str, float], # rating confidences
526531
Dict[str, float] # tag confidences
527532
]:
528-
return {}, self.tags[image.filename]
533+
basename = '.'.join(image.filename.split('/')[-1].split('.')[:-1])
534+
path = Path(image.filename)
535+
tags_filename = format_output_filename(path, self.format)
536+
if self.path == '':
537+
dir = path.parent
538+
else:
539+
dir = self.path
540+
tags_filename = os.path.join(dir, tags_filename)
541+
self.load_file(tags_filename)
542+
return {}, self.tags[basename]
529543

530544

531545
class WaifuDiffusionInterrogator(HFInterrogator):

tagger/uiset.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@
3737
]
3838

3939

40+
def format_output_filename(path: Path, format='[name].[output_extension]') -> str:
41+
info = tags_format.Info(path, 'txt')
42+
fmt = partial(lambda info, m: tags_format.parse(m, info), info)
43+
return tags_format.pattern.sub(fmt, format)
44+
45+
4046
class IOData:
4147
""" data class for input and output paths """
4248
last_path_mtimes = None
@@ -119,7 +125,7 @@ def update_input_glob(cls, input_glob: str) -> None:
119125
path_mtimes = []
120126
for filename in glob(input_glob, recursive=recursive):
121127
if not os.path.isdir(filename):
122-
ext = os.path.splitext(filename)[1].lower()
128+
ext = os.path.splitext(filename)[-1].lower()
123129
if ext in supported_extensions:
124130
path_mtimes.append(os.path.getmtime(filename))
125131
paths.append(filename)
@@ -163,16 +169,11 @@ def set_batch_io(cls, paths: List[str]) -> None:
163169
base_dir_last_idx = path.parts.index(cls.base_dir_last)
164170
# format output filename
165171

166-
info = tags_format.Info(path, 'txt')
167-
fmt = partial(lambda info, m: tags_format.parse(m, info), info)
168-
169172
msg = 'Invalid output format'
170173
cls.err.discard(msg)
171174
try:
172-
formatted_output_filename = tags_format.pattern.sub(
173-
fmt,
174-
Its.output_filename_format
175-
)
175+
formatted_output_filename = format_output_filename(
176+
path, format=Its.output_filename_format)
176177
except (TypeError, ValueError):
177178
cls.err.add(msg)
178179

@@ -483,7 +484,6 @@ def correct_tag(cls, tag: str) -> str:
483484
if re_match(regex, tag):
484485
tag = re_sub(regex, cls.replace_tags[i], tag)
485486
break
486-
487487
return tag
488488

489489
@classmethod

0 commit comments

Comments
 (0)