From d252c2a297d5673adec78950a56546140f7c0425 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 23 Jul 2023 21:14:27 +0200 Subject: [PATCH 01/78] Make classes WaifuDiffusionInterrogator and MLDanbooruInterrogator a subclass of HFInterrogator. Allow more HuggingFace parameters for who can use them. The user can set HF_HUB_OFFLINE, then, or if the connection cannot be made, the fallback is the local directory. If that does not exist either, just stop the interrogation empty handed. --- tagger/interrogator.py | 189 +++++++++++++++++++++-------------------- tagger/settings.py | 16 ++-- tagger/utils.py | 5 +- 3 files changed, 107 insertions(+), 103 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index f438991..32fbf80 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -3,7 +3,8 @@ from pathlib import Path import io import json -from platform import system, uname +import inspect +from platform import uname from typing import Tuple, List, Dict, Callable from pandas import read_csv from PIL import Image, UnidentifiedImageError @@ -28,12 +29,10 @@ onnxrt_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_cpu: - #import gc TF_DEVICE_NAME = '/cpu:0' onnxrt_providers.pop(0) print(f'== WD14 tagger: cpu, {uname()} ==') else: - #from numba import cuda TF_DEVICE_NAME = '/gpu:0' print(f'== WD14 tagger gpu, {uname()} ==') @@ -112,7 +111,7 @@ def __init__(self, name: str) -> None: # run_mode 0 is dry run, 1 means run (alternating), 2 means disabled self.run_mode = 0 if hasattr(self, "large_batch_interrogate") else 2 - def load(self): + def load(self) -> bool: raise NotImplementedError() def large_batch_interrogate(self, images: List, dry_run=False) -> str: @@ -303,7 +302,8 @@ def interrogate( ]: # init model if self.model is None: - self.load() + if not self.load(): + return {}, {} import deepdanbooru.data as ddd @@ -337,34 +337,69 @@ def large_batch_interrogate(self, images: List, dry_run=False) -> str: raise NotImplementedError() -# FIXME this is silly, in what scenario would the env change from MacOS to -# another OS? TODO: remove if the author does not respond. -def get_onnxrt(): - try: - import onnxruntime - return onnxruntime - except ImportError: - # only one of these packages should be installed at one time in an env - # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime - # TODO: remove old package when the environment changes? - from launch import is_installed, run_pip - if not is_installed('onnxruntime'): - if system() == "Darwin": - package_name = "onnxruntime-silicon" +class HFInterrogator(Interrogator): + """ Interrogator for HuggingFace models """ + def __init__( + self, + name: str, + repo_id: str, + model_path: str, + tags_path: str, + ) -> None: + super().__init__(name) + self.repo_id = repo_id + self.model_path = model_path + self.tags_path = tags_path + self.model = None + self.local_model = None + self.local_tags = None + # tagger_hf_hub_down_opts contains args to hf_hub_download(). Parse + # and pass only the supported args. + + attrs = getattr(shared.opts, 'tagger_hf_hub_down_opts', + f'cache_dir="{Its.hf_cache}"') + attrs = [attr.split('=') for attr in map(str.strip, attrs.split(','))] + + signature = inspect.signature(hf_hub_download) + self.params = {} + for arg, val in attrs: + if arg in signature.parameters: + try: + tp = signature.parameters[arg].annotation(val) + self.params[arg] = tp(val) + except TypeError: + # unions, used for str of PathLike + if val[0] == val[-1] and val[0] in "'\"": + val = val[1:-1] + self.params[arg] = str(val) else: - package_name = "onnxruntime-gpu" - package = os.environ.get( - 'ONNXRUNTIME_PACKAGE', - package_name - ) + print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " + "Invalid for hf_hub_download() => ignored.") - run_pip(f'install {package}', 'onnxruntime') + def download(self) -> Tuple[str, str]: + print(f"Loading {self.name} model file from {self.repo_id}") + self.params['repo_id'] = self.repo_id - import onnxruntime - return onnxruntime + paths = [self.local_model, self.local_tags] + try: + for i, filename in enumerate([self.model_path, self.tags_path]): + self.params['filename'] = filename + paths[i] = hf_hub_download(**self.params) + except Exception as err: + if str(err)[:25] != "Offline mode is enabled.": + print(f"hf_hub_download({self.params}: {err}") -class WaifuDiffusionInterrogator(Interrogator): + return paths + + def load_model(self, model_path) -> None: + import onnxruntime + self.model = onnxruntime.InferenceSession(model_path, + providers=onnxrt_providers) + print(f'Loaded {self.name} model from {model_path}') + + +class WaifuDiffusionInterrogator(HFInterrogator): """ Interrogator for Waifu Diffusion models """ def __init__( self, @@ -372,43 +407,17 @@ def __init__( model_path='model.onnx', tags_path='selected_tags.csv', repo_id=None, - is_hf=True, ) -> None: - super().__init__(name) - self.repo_id = repo_id - self.model_path = model_path - self.tags_path = tags_path - self.tags = None - self.model = None + super().__init__(name, repo_id, model_path, tags_path) self.tags = None - self.local_model = None - self.local_tags = None - self.is_hf = is_hf - - def download(self) -> None: - mdir = Path(shared.models_path, 'interrogators') - if self.is_hf: - cache = getattr(shared.opts, 'tagger_hf_cache_dir', Its.hf_cache) - print(f"Loading {self.name} model file from {self.repo_id}, " - f"{self.model_path}") - - model_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.model_path, - cache_dir=cache) - tags_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.tags_path, - cache_dir=cache) - else: - model_path = self.local_model - tags_path = self.local_tags + def update_model_json(self, model_path, tags_path): download_model = { 'name': self.name, 'model_path': model_path, 'tags_path': tags_path, } + mdir = Path(shared.models_path, 'interrogators') mpath = Path(mdir, 'model.json') data = [download_model] @@ -429,16 +438,22 @@ def download(self) -> None: with io.open(mpath, 'w', encoding='utf-8') as filename: json.dump(data, filename) - return model_path, tags_path - def load(self) -> None: + def load(self) -> bool: model_path, tags_path = self.download() - ort = get_onnxrt() - self.model = ort.InferenceSession(model_path, - providers=onnxrt_providers) - print(f'Loaded {self.name} model from {self.repo_id}') + if not os.path.exists(model_path): + print(f'Model path {model_path} not found.') + return False + + if not os.path.exists(tags_path): + print(f'Tags path {tags_path} not found.') + return False + + self.load_model(model_path) + self.update_model_json(model_path, tags_path) self.tags = read_csv(tags_path) + return True def interrogate( self, @@ -449,7 +464,8 @@ def interrogate( ]: # init model if self.model is None: - self.load() + if not self.load(): + return {}, {} # code for converting the image and running the model is taken from the # link below. thanks, SmilingWolf! @@ -541,7 +557,8 @@ def large_batch_interrogate(self, images, dry_run=True) -> None: # init model if not hasattr(self, 'model') or self.model is None: - self.load() + if not self.load(): + return os.environ["TF_XLA_FLAGS"] = '--tf_xla_auto_jit=2 '\ '--tf_xla_cpu_global_jit' @@ -587,7 +604,7 @@ def pred_model(model): del os.environ["TF_XLA_FLAGS"] -class MLDanbooruInterrogator(Interrogator): +class MLDanbooruInterrogator(HFInterrogator): """ Interrogator for the MLDanbooru model. """ def __init__( self, @@ -596,40 +613,25 @@ def __init__( model_path: str, tags_path='classes.json', ) -> None: - super().__init__(name) - self.model_path = model_path - self.tags_path = tags_path - self.repo_id = repo_id + super().__init__(name, repo_id, model_path, tags_path) self.tags = None - self.model = None - - def download(self) -> Tuple[str, str]: - print(f"Loading {self.name} model file from {self.repo_id}") - cache = getattr(shared.opts, 'tagger_hf_cache_dir', Its.hf_cache) - model_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.model_path, - cache_dir=cache - ) - tags_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.tags_path, - cache_dir=cache - ) - return model_path, tags_path - - def load(self) -> None: + def load(self) -> bool: model_path, tags_path = self.download() - ort = get_onnxrt() - self.model = ort.InferenceSession(model_path, - providers=onnxrt_providers) - print(f'Loaded {self.name} model from {model_path}') + if not os.path.exists(model_path): + print(f'Model path {model_path} not found.') + return False + + if not os.path.exists(tags_path): + print(f'Tags path {tags_path} not found.') + return False with open(tags_path, 'r', encoding='utf-8') as filen: self.tags = json.load(filen) + return True + def interrogate( self, image: Image @@ -639,7 +641,8 @@ def interrogate( ]: # init model if self.model is None: - self.load() + if not self.load(): + return {}, {} image = dbimutils.fill_transparent(image) image = dbimutils.resize(image, 448) # TODO CUSTOMIZE diff --git a/tagger/settings.py b/tagger/settings.py index 8510468..07b58fe 100644 --- a/tagger/settings.py +++ b/tagger/settings.py @@ -3,14 +3,17 @@ from typing import List from modules import shared # pylint: disable=import-error from gradio import inputs as gr +from huggingface_hub import hf_hub_download # kaomoji from WD 1.4 tagger csv. thanks, Meow-San#5400! DEFAULT_KAMOJIS = '0_0, (o)_(o), +_+, +_-, ._., _, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||' # pylint: disable=line-too-long # noqa: E501 DEFAULT_OFF = '[name].[output_extension]' -HF_CACHE = os.environ.get('HF_HOME', os.environ.get('HUGGINGFACE_HUB_CACHE', - str(os.path.join(shared.models_path, 'interrogators')))) +HF_CACHE = os.environ.get( + 'HUGGINGFACE_HUB_CACHE', # defaults to "$HF_HOME/hub" + str(os.path.join(shared.models_path, 'interrogators'))) + def slider_wrapper(value, elem_id, **kwargs): # required or else gradio will throw errors @@ -121,13 +124,12 @@ def on_ui_settings(): section=section, ), ) - # see huggingface_hub guides/manage-cache shared.opts.add_option( - key='tagger_hf_cache_dir', + key='tagger_hf_hub_down_opts', info=shared.OptionInfo( - HF_CACHE, - label='HuggingFace cache directory, ' - 'see huggingface_hub guides/manage-cache', + str(f'cache_dir="{HF_CACHE}"'), + label='HuggingFace parameters, Comma delimited: arg=value, ' + 'see huggingface_hub docs for available or leave alone.', section=section, ), ) diff --git a/tagger/utils.py b/tagger/utils.py index 9b05625..be2d673 100644 --- a/tagger/utils.py +++ b/tagger/utils.py @@ -105,12 +105,11 @@ def tag_select_csvs_up_front(k): if path.name == 'wd-v1-4-convnextv2-tagger-v2': interrogators[path.name] = WaifuDiffusionInterrogator( path.name, - repo_id='SmilingWolf/SW-CV-ModelZoo', - is_hf=False + repo_id='SmilingWolf/SW-CV-ModelZoo' ) elif path.name == 'Z3D-E621-Convnext': interrogators[path.name] = WaifuDiffusionInterrogator( - 'Z3D-E621-Convnext', is_hf=False) + 'Z3D-E621-Convnext') else: raise NotImplementedError(f"Add {path.name} resolution similar" "to above here") From 714c335fd17a0be3391a7eb7f42a1e83e28145fc Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 23 Jul 2023 21:55:10 +0200 Subject: [PATCH 02/78] a little more work is required --- tagger/interrogator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 32fbf80..c53b5f5 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -368,10 +368,15 @@ def __init__( tp = signature.parameters[arg].annotation(val) self.params[arg] = tp(val) except TypeError: - # unions, used for str of PathLike - if val[0] == val[-1] and val[0] in "'\"": - val = val[1:-1] - self.params[arg] = str(val) + if val == 'None': + self.params[arg] = None + elif arg == 'token' and val in {'True', 'False'}: + self.params[arg] = val == 'True' + else: + # unions, used for str or PathLike + if val[0] == val[-1] and val[0] in "'\"": + val = val[1:-1] + self.params[arg] = str(val) else: print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " "Invalid for hf_hub_download() => ignored.") From e4c056a975c646fe26439dd0b4d8e111baa29c64 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 23 Jul 2023 23:11:00 +0200 Subject: [PATCH 03/78] add interrogators.json move refresh to interrogator as a static, and pick up the configured interrogators there. presets in tagger/presets.py and tagger/utils.py can go. write info alongside model so we can check its up to date status --- .gitignore | 4 +- interrogators.json | 69 +++++++++++++++ tagger/api.py | 10 +-- tagger/interrogator.py | 188 +++++++++++++++++++++++++++++++++++------ tagger/preset.py | 7 ++ tagger/ui.py | 68 ++++++++------- tagger/utils.py | 124 --------------------------- 7 files changed, 278 insertions(+), 192 deletions(-) create mode 100644 interrogators.json delete mode 100644 tagger/utils.py diff --git a/.gitignore b/.gitignore index fd6106c..8ea680e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,5 @@ __pycache__/ .vscode/ .venv/ .env - -presets/ \ No newline at end of file +presets/ +addons/ diff --git a/interrogators.json b/interrogators.json new file mode 100644 index 0000000..ce55031 --- /dev/null +++ b/interrogators.json @@ -0,0 +1,69 @@ +{ + "mld-caformer.dec-5-97527" : { + "class" : "MLDanbooruInterrogator", + "repo_specs" : { + "model_path" : "ml_caformer_m36_dec-5-97527.onnx", + "name" : "ML-Danbooru Caformer dec-5-97527", + "repo_id" : "deepghs/ml-danbooru-onnx" + } + }, + "mld-tresnetd.6-30000" : { + "class" : "MLDanbooruInterrogator", + "repo_specs" : { + "model_path" : "TResnet-D-FLq_ema_6-30000.onnx", + "name" : "ML-Danbooru TResNet-D 6-30000", + "repo_id" : "deepghs/ml-danbooru-onnx" + } + }, + "wd-v1-4-moat-tagger.v2" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 moat tagger v2", + "repo_id" : "SmilingWolf/wd-v1-4-moat-tagger-v2" + } + }, + "wd14-convnext.v1" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ConvNeXT v1", + "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger" + } + }, + "wd14-convnext.v2" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ConvNeXT v2", + "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger-v2" + } + }, + "wd14-convnextv2.v1" : { + "remark" : "the repo_id name is misleading, but it's v1", + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ConvNeXTV2 v1", + "repo_id" : "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" + } + }, + "wd14-swinv2-v1" : { + "remark" : "the repo_id name is misleading, but it's v1", + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 SwinV2 v1", + "repo_id" : "SmilingWolf/wd-v1-4-swinv2-tagger-v2" + } + }, + "wd14-vit.v1" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ViT v1", + "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger" + } + }, + "wd14-vit.v2" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ViT v2", + "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger-v2" + } + } +} diff --git a/tagger/api.py b/tagger/api.py index c4f6e6b..be3c2cc 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -9,9 +9,9 @@ from fastapi import FastAPI, Depends, HTTPException from fastapi.security import HTTPBasic, HTTPBasicCredentials -from tagger import utils # pylint: disable=import-error from tagger import api_models as models # pylint: disable=import-error from tagger.uiset import QData # pylint: disable=import-error +from tagger.interrogator import Interrogator class Api: @@ -78,11 +78,11 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if req.image is None: raise HTTPException(404, 'Image not found') - if req.model not in utils.interrogators.keys(): + if req.model not in Interrogator.entries.keys(): raise HTTPException(404, 'Model not found') image = decode_base64_to_image(req.image) - interrogator = utils.interrogators[req.model] + interrogator = Interrogator.entries[req.model] with self.queue_lock: QData.tags.clear() @@ -101,13 +101,13 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): def endpoint_interrogators(self): return models.InterrogatorsResponse( - models=list(utils.interrogators.keys()) + models=list(Interrogator.entries.keys()) ) def endpoint_unload_interrogators(self): unloaded_models = 0 - for i in utils.interrogators.values(): + for i in Interrogator.entries.values(): if i.unload(): unloaded_models = unloaded_models + 1 diff --git a/tagger/interrogator.py b/tagger/interrogator.py index c53b5f5..624db6b 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -10,10 +10,11 @@ from PIL import Image, UnidentifiedImageError from numpy import asarray, float32, expand_dims, exp from tqdm import tqdm - from huggingface_hub import hf_hub_download -from modules import shared # pylint: disable=import-error +from modules.paths import extensions_dir +from modules import shared +from preload import default_ddp_path, default_onnx_path from tagger import settings # pylint: disable=import-error from tagger.uiset import QData, IOData # pylint: disable=import-error from . import dbimutils # pylint: disable=import-error # noqa @@ -59,6 +60,7 @@ class Interrogator: } output = None odd_increment = 0 + entries = {} @classmethod def flip(cls, key): @@ -91,6 +93,98 @@ def setter(val) -> Tuple[str, str]: return setter + @classmethod + def refresh(cls) -> List[str]: + """Refreshes the interrogator entries""" + if len(cls.entries) == 0: + it_path = Path(os.path.join( + extensions_dir, + 'stable-diffusion-webui-wd14-tagger/interrogators.json' + )) + if not it_path.exists(): + raise FileNotFoundError(f'{it_path} not found.') + + with open(it_path) as filename: + raw = json.load(filename) + + for name, it in raw.items(): + if it["class"] == "DeepDanbooruInterrogator": + It_type = DeepDanbooruInterrogator + elif it["class"] == "WaifuDiffusionInterrogator": + It_type = WaifuDiffusionInterrogator + elif it["class"] == "MLDanbooruInterrogator": + It_type = MLDanbooruInterrogator + else: + raise ValueError(f'Unimplemented: {it["class"]}') + + cls.entries[name] = It_type(**it["repo_specs"]) + + # load deepdanbooru project + ddp_path = getattr(shared.cmd_opts, 'deepdanbooru_projects_path', + default_ddp_path) + onnx_path = getattr(shared.cmd_opts, 'onnxtagger_path', + default_onnx_path) + os.makedirs(ddp_path, exist_ok=True) + os.makedirs(onnx_path, exist_ok=True) + + for path in os.scandir(ddp_path): + print(f"Scanning {path} as deepdanbooru project") + if not path.is_dir(): + print(f"Warning: {path} is not a directory, skipped") + continue + + if not Path(path, 'project.json').is_file(): + print(f"Warning: {path} has no project.json, skipped") + continue + + cls.entries[path.name] = DeepDanbooruInterrogator(path.name, path) + # scan for onnx models as well + for path in os.scandir(onnx_path): + print(f"Scanning {path} as onnx model") + if not path.is_dir(): + print(f"Warning: {path} is not a directory, skipped") + continue + + onnx_files = [] + for file_name in os.scandir(path): + if file_name.name.endswith('.onnx'): + onnx_files.append(file_name) + + if len(onnx_files) != 1: + print(f"Warning: {path}: multiple .onnx models => skipped") + continue + local_path = Path(path, onnx_files[0].name) + + csv = [x for x in os.scandir(path) if x.name.endswith('.csv')] + if len(csv) == 0: + print(f"Warning: {path}: no selected tags .csv file, skipped") + continue + + def tag_select_csvs_up_front(k): + k = k.name.lower() + return -1 if "tag" in k or "select" in k else 1 + + csv.sort(key=tag_select_csvs_up_front) + tags_path = Path(path, csv[0]) + + if path.name not in cls.entries: + if path.name == 'wd-v1-4-convnextv2-tagger-v2': + cls.entries[path.name] = WaifuDiffusionInterrogator( + path.name, + repo_id='SmilingWolf/SW-CV-ModelZoo' + ) + elif path.name == 'Z3D-E621-Convnext': + cls.entries[path.name] = WaifuDiffusionInterrogator( + 'Z3D-E621-Convnext') + else: + raise NotImplementedError(f"Add {path.name} resolution " + "similar to above here") + + cls.entries[path.name].local_model = str(local_path) + cls.entries[path.name].local_tags = str(tags_path) + + return sorted(i.name for i in cls.entries.values()) + @staticmethod def load_image(path: str) -> Image: try: @@ -110,6 +204,9 @@ def __init__(self, name: str) -> None: self.tags = None # run_mode 0 is dry run, 1 means run (alternating), 2 means disabled self.run_mode = 0 if hasattr(self, "large_batch_interrogate") else 2 + # default path if not overridden by download + self.local_model = None + self.local_tags = None def load(self) -> bool: raise NotImplementedError() @@ -342,59 +439,96 @@ class HFInterrogator(Interrogator): def __init__( self, name: str, - repo_id: str, model_path: str, tags_path: str, + **kwargs, ) -> None: super().__init__(name) - self.repo_id = repo_id self.model_path = model_path self.tags_path = tags_path self.model = None - self.local_model = None - self.local_tags = None # tagger_hf_hub_down_opts contains args to hf_hub_download(). Parse # and pass only the supported args. + self.repo_specs = {'repo_id', 'revision', 'library_name', + 'library_version'} + self.hf_params = {} + for k in self.repo_specs: + if k in kwargs: + self.hf_params[k] = kwargs[k] + + if 'repo_id' not in self.hf_params: + print(f"Error: interrogatos.json: HuggingFace model {self.name} " + "lacks a repo_id. If not already local, download may fail.") + attrs = getattr(shared.opts, 'tagger_hf_hub_down_opts', f'cache_dir="{Its.hf_cache}"') attrs = [attr.split('=') for attr in map(str.strip, attrs.split(','))] signature = inspect.signature(hf_hub_download) - self.params = {} for arg, val in attrs: - if arg in signature.parameters: + if arg == 'filename' or arg in self.repo_specs: + + print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " + "Specific options need to go in the interrogators.json.") + + elif arg in signature.parameters: try: tp = signature.parameters[arg].annotation(val) - self.params[arg] = tp(val) + self.hf_params[arg] = tp(val) + except TypeError: + # unions, used for str or PathLike and a few. if val == 'None': - self.params[arg] = None + self.hf_params[arg] = None elif arg == 'token' and val in {'True', 'False'}: - self.params[arg] = val == 'True' + self.hf_params[arg] = val == 'True' else: - # unions, used for str or PathLike if val[0] == val[-1] and val[0] in "'\"": val = val[1:-1] - self.params[arg] = str(val) + self.hf_params[arg] = str(val) else: print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " "Invalid for hf_hub_download() => ignored.") def download(self) -> Tuple[str, str]: - print(f"Loading {self.name} model file from {self.repo_id}") - self.params['repo_id'] = self.repo_id + repo_id = self.hf_params.get('repo_id', '(?)') + print(f"Loading {self.name} model file from {repo_id}") paths = [self.local_model, self.local_tags] - try: - for i, filename in enumerate([self.model_path, self.tags_path]): - self.params['filename'] = filename - paths[i] = hf_hub_download(**self.params) - except Exception as err: - if str(err)[:25] != "Offline mode is enabled.": - print(f"hf_hub_download({self.params}: {err}") + data = {} + for k in self.repo_specs: + if k in self.hf_params: + data[k] = self.hf_params[k] + + # check if the model is up to date + info_path = Path(self.local_model).with_suffix('.info') + if info_path.exists(): + + if all(os.path.exists(p) for p in paths): + with open(info_path, 'r') as filen: + try: + old_data = json.load(filen) + if old_data == data: + print(f"Model {self.name} is up to date.") + return paths + except json.decoder.JSONDecodeError: + pass + try: + for i, filen in enumerate([self.model_path, self.tags_path]): + self.hf_params['filename'] = filen + paths[i] = hf_hub_download(**self.hf_params) + except Exception as err: + if str(err)[:25] != "Offline mode is enabled.": + print(f"hf_hub_download({self.hf_params}: {err}") + return paths + + # write the repo_specs to a json alongside the model so we can + # check if the model is up to date + with open(info_path, 'w') as filen: + json.dump(data, filen) return paths def load_model(self, model_path) -> None: @@ -411,9 +545,9 @@ def __init__( name: str, model_path='model.onnx', tags_path='selected_tags.csv', - repo_id=None, + **kwargs, ) -> None: - super().__init__(name, repo_id, model_path, tags_path) + super().__init__(name, model_path, tags_path, **kwargs) self.tags = None def update_model_json(self, model_path, tags_path): @@ -614,11 +748,11 @@ class MLDanbooruInterrogator(HFInterrogator): def __init__( self, name: str, - repo_id: str, model_path: str, tags_path='classes.json', + **kwargs ) -> None: - super().__init__(name, repo_id, model_path, tags_path) + super().__init__(name, model_path, tags_path, **kwargs) self.tags = None def load(self) -> bool: @@ -632,6 +766,8 @@ def load(self) -> bool: print(f'Tags path {tags_path} not found.') return False + self.load_model(model_path) + with open(tags_path, 'r', encoding='utf-8') as filen: self.tags = json.load(filen) diff --git a/tagger/preset.py b/tagger/preset.py index 9189535..714aa13 100644 --- a/tagger/preset.py +++ b/tagger/preset.py @@ -6,6 +6,8 @@ from pathlib import Path from gradio.context import Context from modules.images import sanitize_filename_part # pylint: disable=E0401 +from modules.paths import extensions_dir +from modules import scripts PresetDict = Dict[str, Dict[str, any]] @@ -106,3 +108,8 @@ def list(self) -> List[str]: presets.append(self.default_filename) return presets + + +preset = Preset(Path( + os.path.join(extensions_dir, 'stable-diffusion-webui-wd14-tagger/presets') +)) diff --git a/tagger/ui.py b/tagger/ui.py index 49a657c..ab47091 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -11,9 +11,9 @@ from modules import generation_parameters_copypaste as parameters_copypaste # pylint: disable=import-error # noqa from webui import wrap_gradio_gpu_call # pylint: disable=import-error -from tagger import utils # pylint: disable=import-error from tagger.interrogator import Interrogator as It # pylint: disable=E0401 from tagger.uiset import IOData, QData # pylint: disable=import-error +from tagger.preset import preset TAG_INPUTS = ["add", "keep", "exclude", "search", "replace"] COMMON_OUTPUT = Tuple[ @@ -30,7 +30,7 @@ def unload_interrogators() -> List[str]: unloaded_models = 0 remaining_models = '' - for i in utils.interrogators.values(): + for i in It.entries.values(): if i.unload(): unloaded_models = unloaded_models + 1 elif i.model is not None: @@ -64,7 +64,7 @@ def on_interrogate( getattr(QData, "update_" + part)(val) It.input[part] = val - interrogator: It = next((i for i in utils.interrogators.values() if + interrogator: It = next((i for i in It.entries.values() if i.name == name), None) if interrogator is None: return None, None, None, None, None, f"'{name}': invalid interrogator" @@ -81,9 +81,10 @@ def on_interrogate_image(*args) -> COMMON_OUTPUT: # hack brcause image interrogaion occurs twice It.odd_increment = It.odd_increment + 1 if It.odd_increment & 1 == 1: - return (None, None, None, None, None, '') + return (None, None, None, None, None, '') return on_interrogate_image_submit(*args) + def on_interrogate_image_submit( image: Image, name: str, filt: str, *args ) -> COMMON_OUTPUT: @@ -95,7 +96,7 @@ def on_interrogate_image_submit( if image is None: return None, None, None, None, None, 'No image selected' - interrogator: It = next((i for i in utils.interrogators.values() if + interrogator: It = next((i for i in It.entries.values() if i.name == name), None) if interrogator is None: return None, None, None, None, None, f"'{name}': invalid interrogator" @@ -179,13 +180,13 @@ def on_ui_tabs(): ) with gr.TabItem(label='Batch from directory'): - input_glob = utils.preset.component( + input_glob = preset.component( gr.Textbox, value='', label='Input directory - See also settings tab.', placeholder='/path/to/images or to/images/**/*' ) - output_dir = utils.preset.component( + output_dir = preset.component( gr.Textbox, value=It.input["output_dir"], label='Output directory', @@ -199,7 +200,7 @@ def on_ui_tabs(): ) with gr.Row(variant='compact'): with gr.Column(variant='panel'): - large_query = utils.preset.component( + large_query = preset.component( gr.Checkbox, label='huge batch query (TF 2.10, ' 'experimental)', @@ -208,7 +209,7 @@ def on_ui_tabs(): version.parse('2.10') ) with gr.Column(variant='panel'): - save_tags = utils.preset.component( + save_tags = preset.component( gr.Checkbox, label='Save to tags files', value=True @@ -222,7 +223,7 @@ def on_ui_tabs(): # preset selector with gr.Row(variant='compact'): - available_presets = utils.preset.list() + available_presets = preset.list() selected_preset = gr.Dropdown( label='Preset', choices=available_presets, @@ -236,19 +237,15 @@ def on_ui_tabs(): ui.create_refresh_button( selected_preset, lambda: None, - lambda: {'choices': utils.preset.list()}, + lambda: {'choices': preset.list()}, 'refresh_preset' ) # interrogator selector with gr.Column(): with gr.Row(variant='compact'): - def refresh(): - utils.refresh_interrogators() - return sorted(x.name for x in utils.interrogators - .values()) - interrogator_names = refresh() - interrogator = utils.preset.component( + interrogator_names = It.refresh() + interrogator = preset.component( gr.Dropdown, label='Interrogator', choices=interrogator_names, @@ -262,44 +259,44 @@ def refresh(): ui.create_refresh_button( interrogator, lambda: None, - lambda: {'choices': refresh()}, + lambda: {'choices': It.refresh()}, 'refresh_interrogator' ) unload_all_models = gr.Button( value='Unload all interrogate models' ) - tag_input["add"] = utils.preset.component( + tag_input["add"] = preset.component( gr.Textbox, label='Additional tags (comma split)', elem_id='additional-tags' ) with gr.Row(variant='compact'): with gr.Column(variant='compact'): - threshold = utils.preset.component( + threshold = preset.component( gr.Slider, label='Weight threshold', minimum=0, maximum=1, value=QData.threshold ) - cumulative = utils.preset.component( + cumulative = preset.component( gr.Checkbox, label='Combine interrogations', value=False ) - tag_input["search"] = utils.preset.component( + tag_input["search"] = preset.component( gr.Textbox, label='Search tag, .. ->', elem_id='search-tags' ) - tag_input["keep"] = utils.preset.component( + tag_input["keep"] = preset.component( gr.Textbox, label='Kept tag, ..', elem_id='keep-tags' ) with gr.Column(variant='compact'): - tag_frac_threshold = utils.preset.component( + tag_frac_threshold = preset.component( gr.Slider, label='Min tag fraction in batch and ' 'interrogations', @@ -307,17 +304,17 @@ def refresh(): maximum=1, value=QData.tag_frac_threshold, ) - unload_after = utils.preset.component( + unload_after = preset.component( gr.Checkbox, label='Unload model after running', value=False ) - tag_input["replace"] = utils.preset.component( + tag_input["replace"] = preset.component( gr.Textbox, label='-> Replace tag, ..', elem_id='replace-tags' ) - tag_input["exclude"] = utils.preset.component( + tag_input["exclude"] = preset.component( gr.Textbox, label='Exclude tag, ..', elem_id='exclude-tags' @@ -336,7 +333,7 @@ def refresh(): variant='secondary' ) with gr.Column(variant='compact'): - tag_search_selection = utils.preset.component( + tag_search_selection = preset.component( gr.Textbox, label='Multi string search: part1, part2.. ' '(Enter key to update)', @@ -396,11 +393,11 @@ def refresh(): save_tags.input(fn=IOData.flip_save_tags(), inputs=[], outputs=[]) # Preset and unload buttons - selected_preset.change(fn=utils.preset.apply, inputs=[selected_preset], - outputs=[*utils.preset.components, info]) + selected_preset.change(fn=preset.apply, inputs=[selected_preset], + outputs=[*preset.components, info]) - save_preset_button.click(fn=utils.preset.save, inputs=[selected_preset, - *utils.preset.components], outputs=[info]) + save_preset_button.click(fn=preset.save, inputs=[selected_preset, + *preset.components], outputs=[info]) unload_all_models.click(fn=unload_interrogators, outputs=[info]) @@ -451,11 +448,12 @@ def refresh(): [tag_input[tag] for tag in TAG_INPUTS] # interrogation events - image_submit.click(fn=wrap_gradio_gpu_call(on_interrogate_image_submit), - inputs=[image] + common_input, outputs=common_output) + image_submit.click( + fn=wrap_gradio_gpu_call(on_interrogate_image_submit), + inputs=[image] + common_input, outputs=common_output) image.change(fn=wrap_gradio_gpu_call(on_interrogate_image), - inputs=[image] + common_input, outputs=common_output) + inputs=[image] + common_input, outputs=common_output) batch_submit.click(fn=wrap_gradio_gpu_call(on_interrogate), inputs=[input_glob, output_dir] + common_input, diff --git a/tagger/utils.py b/tagger/utils.py deleted file mode 100644 index be2d673..0000000 --- a/tagger/utils.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Utility functions for the tagger module""" -import os - -from typing import List, Dict -from pathlib import Path - -from modules import shared, scripts # pylint: disable=import-error -from preload import default_ddp_path, default_onnx_path # pylint: disable=E0401 # noqa: E501 -from tagger.preset import Preset # pylint: disable=import-error -from tagger.interrogator import Interrogator, DeepDanbooruInterrogator, \ - MLDanbooruInterrogator # pylint: disable=E0401 # noqa: E501 -from tagger.interrogator import WaifuDiffusionInterrogator # pylint: disable=E0401 # noqa: E501 - -preset = Preset(Path(scripts.basedir(), 'presets')) - -interrogators: Dict[str, Interrogator] = { - 'wd14-vit.v1': WaifuDiffusionInterrogator( - 'WD14 ViT v1', - repo_id='SmilingWolf/wd-v1-4-vit-tagger' - ), - 'wd14-vit.v2': WaifuDiffusionInterrogator( - 'WD14 ViT v2', - repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2', - ), - 'wd14-convnext.v1': WaifuDiffusionInterrogator( - 'WD14 ConvNeXT v1', - repo_id='SmilingWolf/wd-v1-4-convnext-tagger' - ), - 'wd14-convnext.v2': WaifuDiffusionInterrogator( - 'WD14 ConvNeXT v2', - repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2', - ), - 'wd14-convnextv2.v1': WaifuDiffusionInterrogator( - 'WD14 ConvNeXTV2 v1', - # the name is misleading, but it's v1 - repo_id='SmilingWolf/wd-v1-4-convnextv2-tagger-v2', - ), - 'wd14-swinv2-v1': WaifuDiffusionInterrogator( - 'WD14 SwinV2 v1', - # again misleading name - repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2', - ), - 'wd-v1-4-moat-tagger.v2': WaifuDiffusionInterrogator( - 'WD14 moat tagger v2', - repo_id='SmilingWolf/wd-v1-4-moat-tagger-v2' - ), - 'mld-caformer.dec-5-97527': MLDanbooruInterrogator( - 'ML-Danbooru Caformer dec-5-97527', - repo_id='deepghs/ml-danbooru-onnx', - model_path='ml_caformer_m36_dec-5-97527.onnx' - ), - 'mld-tresnetd.6-30000': MLDanbooruInterrogator( - 'ML-Danbooru TResNet-D 6-30000', - repo_id='deepghs/ml-danbooru-onnx', - model_path='TResnet-D-FLq_ema_6-30000.onnx' - ), -} - - -def refresh_interrogators() -> List[str]: - """Refreshes the interrogators list""" - # load deepdanbooru project - ddp_path = getattr(shared.cmd_opts, 'deepdanbooru_projects_path', - default_ddp_path) - onnx_path = getattr(shared.cmd_opts, 'onnxtagger_path', default_onnx_path) - os.makedirs(ddp_path, exist_ok=True) - os.makedirs(onnx_path, exist_ok=True) - - for path in os.scandir(ddp_path): - print(f"Scanning {path} as deepdanbooru project") - if not path.is_dir(): - print(f"Warning: {path} is not a directory, skipped") - continue - - if not Path(path, 'project.json').is_file(): - print(f"Warning: {path} has no project.json, skipped") - continue - - interrogators[path.name] = DeepDanbooruInterrogator(path.name, path) - # scan for onnx models as well - for path in os.scandir(onnx_path): - print(f"Scanning {path} as onnx model") - if not path.is_dir(): - print(f"Warning: {path} is not a directory, skipped") - continue - - onnx_files = [x for x in os.scandir(path) if x.name.endswith('.onnx')] - if len(onnx_files) != 1: - print(f"Warning: {path} requires exactly one .onnx model, skipped") - continue - local_path = Path(path, onnx_files[0].name) - - csv = [x for x in os.scandir(path) if x.name.endswith('.csv')] - if len(csv) == 0: - print(f"Warning: {path} has no selected tags .csv file, skipped") - continue - - def tag_select_csvs_up_front(k): - sum(-1 if t in k.name.lower() else 1 for t in ["tag", "select"]) - - csv.sort(key=tag_select_csvs_up_front) - tags_path = Path(path, csv[0]) - - if path.name not in interrogators: - if path.name == 'wd-v1-4-convnextv2-tagger-v2': - interrogators[path.name] = WaifuDiffusionInterrogator( - path.name, - repo_id='SmilingWolf/SW-CV-ModelZoo' - ) - elif path.name == 'Z3D-E621-Convnext': - interrogators[path.name] = WaifuDiffusionInterrogator( - 'Z3D-E621-Convnext') - else: - raise NotImplementedError(f"Add {path.name} resolution similar" - "to above here") - - interrogators[path.name].local_model = str(local_path) - interrogators[path.name].local_tags = str(tags_path) - - return sorted(interrogators.keys()) - - -def split_str(string: str, separator=',') -> List[str]: - return [x.strip() for x in string.split(separator) if x] From 520915bf1edf8a0dc30b85ab4c85bb4b1af11be5 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 25 Jul 2023 16:53:13 +0200 Subject: [PATCH 04/78] cleanup --- tagger/interrogator.py | 7 ++----- tagger/uiset.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 624db6b..a48e363 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -207,6 +207,7 @@ def __init__(self, name: str) -> None: # default path if not overridden by download self.local_model = None self.local_tags = None + # XXX don't Interrogator.refresh()-ception here def load(self) -> bool: raise NotImplementedError() @@ -300,11 +301,7 @@ def batch_interrogate_image(self, index: int) -> None: def batch_interrogate(self) -> None: """ Interrogate all images in the input list """ - QData.tags.clear() - QData.ratings.clear() - QData.image_dups.clear() - if not Interrogator.input["cumulative"]: - QData.in_db.clear() + QData.clear(Interrogator.input["cumulative"]) if Interrogator.input["large_query"] is True and self.run_mode < 2: # TODO: write specified tags files instead of simple .txt diff --git a/tagger/uiset.py b/tagger/uiset.py index 3aa4052..dc37003 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -255,8 +255,8 @@ def set_attr(cls, current: str, tag: str) -> None: def clear(cls, mode: int) -> None: """ clear tags and ratings """ cls.tags.clear() - cls.discarded_tags.clear() cls.ratings.clear() + cls.discarded_tags.clear() cls.for_tags_file.clear() if mode > 0: cls.in_db.clear() From 1d71a8e04e13fa9653aa6f6e5feda84439019b52 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 25 Jul 2023 22:15:22 +0200 Subject: [PATCH 05/78] As suggested by WSH032 --- tagger/preset.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tagger/preset.py b/tagger/preset.py index 714aa13..7a132d9 100644 --- a/tagger/preset.py +++ b/tagger/preset.py @@ -6,8 +6,6 @@ from pathlib import Path from gradio.context import Context from modules.images import sanitize_filename_part # pylint: disable=E0401 -from modules.paths import extensions_dir -from modules import scripts PresetDict = Dict[str, Dict[str, any]] @@ -110,6 +108,4 @@ def list(self) -> List[str]: return presets -preset = Preset(Path( - os.path.join(extensions_dir, 'stable-diffusion-webui-wd14-tagger/presets') -)) +preset = Preset(Path(__file__).parent.parent.joinpath('presets')) From 50203aa15ebe4af4545f3e3a750ab4cc6bb06591 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 25 Jul 2023 22:20:10 +0200 Subject: [PATCH 06/78] As WSH032 mentioned, this is already caught. --- tagger/interrogator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index a48e363..c7a9634 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -518,8 +518,7 @@ def download(self) -> Tuple[str, str]: self.hf_params['filename'] = filen paths[i] = hf_hub_download(**self.hf_params) except Exception as err: - if str(err)[:25] != "Offline mode is enabled.": - print(f"hf_hub_download({self.hf_params}: {err}") + print(f"hf_hub_download({self.hf_params}: {err}") return paths # write the repo_specs to a json alongside the model so we can From 543fd6e2d3add6a5574650d6c00d88f33cd67dfc Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 25 Jul 2023 22:40:27 +0200 Subject: [PATCH 07/78] This was a bug --- tagger/interrogator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index c7a9634..690fd5a 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -471,7 +471,7 @@ def __init__( elif arg in signature.parameters: try: - tp = signature.parameters[arg].annotation(val) + tp = signature.parameters[arg].annotation self.hf_params[arg] = tp(val) except TypeError: From 3122bfd85560981e4e8cddeb2a00d79a03eb169d Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 25 Jul 2023 22:46:08 +0200 Subject: [PATCH 08/78] allow json to override settings --- tagger/interrogator.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 690fd5a..d87b21e 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -447,22 +447,27 @@ def __init__( # tagger_hf_hub_down_opts contains args to hf_hub_download(). Parse # and pass only the supported args. + signature = inspect.signature(hf_hub_download) self.repo_specs = {'repo_id', 'revision', 'library_name', 'library_version'} self.hf_params = {} - for k in self.repo_specs: - if k in kwargs: - self.hf_params[k] = kwargs[k] + for k in kwargs: + if k in signature.parameters: + tp = signature.parameters[k].annotation + if isinstance(kwargs[k], tp): + self.hf_params[k] = kwargs[k] + continue + print(f"Warning: interrogators.json: model {self.name}: " + f"parameter {k} unsupported or or wrong type.") if 'repo_id' not in self.hf_params: - print(f"Error: interrogatos.json: HuggingFace model {self.name} " + print(f"Error: interrogators.json: HuggingFace model {self.name} " "lacks a repo_id. If not already local, download may fail.") attrs = getattr(shared.opts, 'tagger_hf_hub_down_opts', f'cache_dir="{Its.hf_cache}"') attrs = [attr.split('=') for attr in map(str.strip, attrs.split(','))] - signature = inspect.signature(hf_hub_download) for arg, val in attrs: if arg == 'filename' or arg in self.repo_specs: From 9c6df86b910a7c788c93cded6c31dbe20f3aec0e Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 5 Aug 2023 13:26:56 +0200 Subject: [PATCH 09/78] broken --- defaults/interrogators.json | 67 +++++++++++++++++++++++++++++ interrogators.json | 69 ------------------------------ json_schema/db_json_v1_schema.json | 40 ++++++++--------- preload.py | 2 + tagger/interrogator.py | 45 +++++++++++-------- tagger/uiset.py | 8 ++-- 6 files changed, 120 insertions(+), 111 deletions(-) create mode 100644 defaults/interrogators.json delete mode 100644 interrogators.json diff --git a/defaults/interrogators.json b/defaults/interrogators.json new file mode 100644 index 0000000..cdd1bdd --- /dev/null +++ b/defaults/interrogators.json @@ -0,0 +1,67 @@ +{ + "MLDanbooruInterrogator": { + "mld-caformer.dec-5-97527" : { + "model_path" : "ml_caformer_m36_dec-5-97527.onnx", + "name" : "ML-Danbooru Caformer dec-5-97527", + "repo_id" : "deepghs/ml-danbooru-onnx" + }, + "mld-tresnetd.6-30000" : { + "model_path" : "TResnet-D-FLq_ema_6-30000.onnx", + "name" : "ML-Danbooru TResNet-D 6-30000", + "repo_id" : "deepghs/ml-danbooru-onnx" + } + }, + "WaifuDiffusionInterrogator": { + "wd-v1-4-moat-tagger.v2" : { + "name" : "WD14 moat tagger v2", + "repo_id" : "SmilingWolf/wd-v1-4-moat-tagger-v2" + }, + "wd14-convnext.v1" : { + "name" : "WD14 ConvNeXT v1", + "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger" + }, + "wd14-convnext.v2" : { + "name" : "WD14 ConvNeXT v2", + "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger-v2" + }, + "wd14-convnextv2.v1" : { + "name" : "WD14 ConvNeXTV2 v1", + "repo_id" : "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" + }, + "wd14-swinv2-v1" : { + "name" : "WD14 SwinV2 v1", + "repo_id" : "SmilingWolf/wd-v1-4-swinv2-tagger-v2" + }, + "wd14-vit.v1" : { + "name" : "WD14 ViT v1", + "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger" + }, + "wd14-vit.v2" : { + "name" : "WD14 ViT v2", + "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger-v2" + } + }, + "DeepDanbooruInterrogator": { + "deepdanbooru-v3-20211112-sgd-e28": { + "name": "DeepDanbooru v3 20211112 sgd e28", + "repo_id": "KichangKim/DeepDanbooru", + "zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip" + }, + "deepdanbooru-v4-20200814-sgd-e30": { + "name": "DeepDanbooru v4 20200814 sgd e30", + "repo_id": "KichangKim/DeepDanbooru", + "zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip" + }, + "MLDanbooruInterrogator": { + "mld-caformer.dec-5-97527" : { + "model_path" : "ml_caformer_m36_dec-5-97527.onnx", + "name" : "ML-Danbooru Caformer dec-5-97527", + "repo_id" : "deepghs/ml-danbooru-onnx" + }, + "mld-tresnetd.6-30000" : { + "model_path" : "TResnet-D-FLq_ema_6-30000.onnx", + "name" : "ML-Danbooru TResNet-D 6-30000", + "repo_id" : "deepghs/ml-danbooru-onnx" + } + } +} diff --git a/interrogators.json b/interrogators.json deleted file mode 100644 index ce55031..0000000 --- a/interrogators.json +++ /dev/null @@ -1,69 +0,0 @@ -{ - "mld-caformer.dec-5-97527" : { - "class" : "MLDanbooruInterrogator", - "repo_specs" : { - "model_path" : "ml_caformer_m36_dec-5-97527.onnx", - "name" : "ML-Danbooru Caformer dec-5-97527", - "repo_id" : "deepghs/ml-danbooru-onnx" - } - }, - "mld-tresnetd.6-30000" : { - "class" : "MLDanbooruInterrogator", - "repo_specs" : { - "model_path" : "TResnet-D-FLq_ema_6-30000.onnx", - "name" : "ML-Danbooru TResNet-D 6-30000", - "repo_id" : "deepghs/ml-danbooru-onnx" - } - }, - "wd-v1-4-moat-tagger.v2" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 moat tagger v2", - "repo_id" : "SmilingWolf/wd-v1-4-moat-tagger-v2" - } - }, - "wd14-convnext.v1" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ConvNeXT v1", - "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger" - } - }, - "wd14-convnext.v2" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ConvNeXT v2", - "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger-v2" - } - }, - "wd14-convnextv2.v1" : { - "remark" : "the repo_id name is misleading, but it's v1", - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ConvNeXTV2 v1", - "repo_id" : "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" - } - }, - "wd14-swinv2-v1" : { - "remark" : "the repo_id name is misleading, but it's v1", - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 SwinV2 v1", - "repo_id" : "SmilingWolf/wd-v1-4-swinv2-tagger-v2" - } - }, - "wd14-vit.v1" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ViT v1", - "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger" - } - }, - "wd14-vit.v2" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ViT v2", - "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger-v2" - } - } -} diff --git a/json_schema/db_json_v1_schema.json b/json_schema/db_json_v1_schema.json index 31e2696..c36aee4 100644 --- a/json_schema/db_json_v1_schema.json +++ b/json_schema/db_json_v1_schema.json @@ -10,28 +10,28 @@ "type": "array", "prefixItems": [ {"type": "string" }, - {"type": "number", "minimum": 0} - ], - "minContains": 2, - "maxContains": 2 - } - } - }, - "meta": { - "type": "object", - "properties": { - "index_shift": { - "type": "integer", - "minimum": 0, - "maximum": 16 + {"type": "number", "minimum": 0} + ], + "minContains": 2, + "maxContains": 2 } + } + }, + "meta": { + "type": "object", + "properties": { + "index_shift": { + "type": "integer", + "minimum": 0, + "maximum": 16 } - }, - "add": { "type": "string" }, - "exclude": { "type": "string" }, - "keep": { "type": "string" }, - "repl": { "type": "string" }, - "search": { "type": "string" } + } + }, + "add": { "type": "string" }, + "exclude": { "type": "string" }, + "keep": { "type": "string" }, + "repl": { "type": "string" }, + "search": { "type": "string" } }, "required": ["rating", "tag", "query"], "additionalProperties": false, diff --git a/preload.py b/preload.py index dbf6ad0..b21cba0 100644 --- a/preload.py +++ b/preload.py @@ -4,6 +4,8 @@ from modules.shared import models_path # pylint: disable=import-error +root_dir = Path(__file__).parent.parent + default_ddp_path = Path(models_path, 'deepdanbooru') default_onnx_path = Path(models_path, 'TaggerOnnx') diff --git a/tagger/interrogator.py b/tagger/interrogator.py index d87b21e..455d0f0 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -3,6 +3,7 @@ from pathlib import Path import io import json +from jsonschema import validate, ValidationError import inspect from platform import uname from typing import Tuple, List, Dict, Callable @@ -14,7 +15,7 @@ from modules.paths import extensions_dir from modules import shared -from preload import default_ddp_path, default_onnx_path +from preload import default_ddp_path, default_onnx_path, root_dir from tagger import settings # pylint: disable=import-error from tagger.uiset import QData, IOData # pylint: disable=import-error from . import dbimutils # pylint: disable=import-error # noqa @@ -97,29 +98,34 @@ def setter(val) -> Tuple[str, str]: def refresh(cls) -> List[str]: """Refreshes the interrogator entries""" if len(cls.entries) == 0: - it_path = Path(os.path.join( - extensions_dir, - 'stable-diffusion-webui-wd14-tagger/interrogators.json' - )) + it_path = root_dir.joinpath("interrogators.json") if not it_path.exists(): - raise FileNotFoundError(f'{it_path} not found.') + it_path = root_dir.joinpath("default/interrogators.json") + if not it_path.exists(): + raise FileNotFoundError(f'{it_path} not found.') - with open(it_path) as filename: - raw = json.load(filename) + raw = json.loads(it_path) + schema = root_dir.joinpath('json_schema', + 'interrogators_v1_schema.json') + validate(raw, json.loads(schema.read_text())) - for name, it in raw.items(): - if it["class"] == "DeepDanbooruInterrogator": + for class_name, it in raw.items(): + if class_name == "DeepDanbooruInterrogator": It_type = DeepDanbooruInterrogator - elif it["class"] == "WaifuDiffusionInterrogator": + elif class_name == "WaifuDiffusionInterrogator": It_type = WaifuDiffusionInterrogator - elif it["class"] == "MLDanbooruInterrogator": + elif class_name == "MLDanbooruInterrogator": It_type = MLDanbooruInterrogator else: raise ValueError(f'Unimplemented: {it["class"]}') + for name, obj in it.items(): + if name not in obj: + obj[name] = name + cls.entries[name] = It_type(**obj) cls.entries[name] = It_type(**it["repo_specs"]) - # load deepdanbooru project + # load deepdanbooru project ddp_path = getattr(shared.cmd_opts, 'deepdanbooru_projects_path', default_ddp_path) onnx_path = getattr(shared.cmd_opts, 'onnxtagger_path', @@ -138,6 +144,7 @@ def refresh(cls) -> List[str]: continue cls.entries[path.name] = DeepDanbooruInterrogator(path.name, path) + # XXX: local_path is not set, bug. # scan for onnx models as well for path in os.scandir(onnx_path): print(f"Scanning {path} as onnx model") @@ -180,6 +187,7 @@ def tag_select_csvs_up_front(k): raise NotImplementedError(f"Add {path.name} resolution " "similar to above here") + print(f"Found {path.name} onnx model {local_path} with tags ") cls.entries[path.name].local_model = str(local_path) cls.entries[path.name].local_tags = str(tags_path) @@ -205,8 +213,8 @@ def __init__(self, name: str) -> None: # run_mode 0 is dry run, 1 means run (alternating), 2 means disabled self.run_mode = 0 if hasattr(self, "large_batch_interrogate") else 2 # default path if not overridden by download - self.local_model = None - self.local_tags = None + self.local_model = '' + self.local_tags = '' # XXX don't Interrogator.refresh()-ception here def load(self) -> bool: @@ -461,8 +469,8 @@ def __init__( f"parameter {k} unsupported or or wrong type.") if 'repo_id' not in self.hf_params: - print(f"Error: interrogators.json: HuggingFace model {self.name} " - "lacks a repo_id. If not already local, download may fail.") + print(f"Warning: interrogators.json: HuggingFace model {self.name}" + " lacks a repo_id. If not already local, download may fail.") attrs = getattr(shared.opts, 'tagger_hf_hub_down_opts', f'cache_dir="{Its.hf_cache}"') @@ -496,7 +504,8 @@ def __init__( def download(self) -> Tuple[str, str]: repo_id = self.hf_params.get('repo_id', '(?)') print(f"Loading {self.name} model file from {repo_id}") - + if self.local_model == '': + Interrogator.refresh() paths = [self.local_model, self.local_tags] data = {} diff --git a/tagger/uiset.py b/tagger/uiset.py index ddb448c..89a400b 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -7,7 +7,7 @@ from math import ceil from hashlib import sha256 from re import compile as re_comp, sub as re_sub, match as re_match, IGNORECASE -from json import dumps, loads, JSONDecodeError +from json import dumps, loads from jsonschema import validate, ValidationError from functools import partial from collections import defaultdict @@ -17,6 +17,7 @@ from modules.deepbooru import re_special # pylint: disable=import-error from tagger import format as tags_format # pylint: disable=import-error from tagger import settings # pylint: disable=import-error +from preload import root_dir Its = settings.InterrogatorSettings @@ -416,9 +417,8 @@ def read_json(cls, outdir) -> None: # validate json using either json_schema/db_jon_v1_schema.json # or json_schema/db_jon_v2_schema.json - schema = Path(__file__).parent.parent.joinpath( - 'json_schema', 'db_json_v1_schema.json' - ) + schema = root_dir.joinpath('json_schema', + 'db_json_v1_schema.json') try: data = loads(cls.json_db.read_text()) validate(data, loads(schema.read_text())) From fe9f3fa5188f18f439cf30ee0b3aae299e8c5b9d Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 6 Aug 2023 10:00:32 +0200 Subject: [PATCH 10/78] Style is deprecated --- tagger/ui.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tagger/ui.py b/tagger/ui.py index b4c353d..43d79e7 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -168,7 +168,7 @@ def on_ui_tabs(): tag_input = {} with gr.Blocks(analytics_enabled=False) as tagger_interface: - with gr.Row().style(equal_height=False): + with gr.Row(equal_height=False): with gr.Column(variant='panel'): # input components @@ -383,7 +383,6 @@ def on_ui_tabs(): gallery = gr.Gallery( label='Gallery', elem_id='gallery', - ).style( columns=[2], rows=[8], object_fit="contain", From 1845c72510427fb774c909d23ac24cd659dd61c2 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 6 Aug 2023 13:17:49 +0200 Subject: [PATCH 11/78] improve wording, this is just a warning --- tagger/interrogator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index a94fe30..f28be1f 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -76,7 +76,7 @@ def get_errors() -> str: # write errors in html pointer list, every error in a
  • tag errors = IOData.error_msg() if len(QData.err) > 0: - errors += 'Fix to write correct output:
    • ' + \ + errors += 'Possible issues:
      • ' + \ '
      • '.join(QData.err) + '
      ' return errors From 1f7ef93557b58101820f3ccb7df7dcec98b13139 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 6 Aug 2023 13:18:51 +0200 Subject: [PATCH 12/78] cleanups --- defaults/interrogators.json | 4 ++-- install.py | 7 ++++--- tagger/interrogator.py | 3 +-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/defaults/interrogators.json b/defaults/interrogators.json index cdd1bdd..a6b71af 100644 --- a/defaults/interrogators.json +++ b/defaults/interrogators.json @@ -45,12 +45,12 @@ "deepdanbooru-v3-20211112-sgd-e28": { "name": "DeepDanbooru v3 20211112 sgd e28", "repo_id": "KichangKim/DeepDanbooru", - "zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip" + "filename": "https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip" }, "deepdanbooru-v4-20200814-sgd-e30": { "name": "DeepDanbooru v4 20200814 sgd e30", "repo_id": "KichangKim/DeepDanbooru", - "zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip" + "filename": "https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip" }, "MLDanbooruInterrogator": { "mld-caformer.dec-5-97527" : { diff --git a/install.py b/install.py index 37408bf..411df64 100644 --- a/install.py +++ b/install.py @@ -1,12 +1,13 @@ """Install requirements for WD14-tagger.""" import os import sys - +import json from launch import run # pylint: disable=import-error +local_dir = os.path.dirname(os.path.realpath(__file__)) + NAME = "WD14-tagger" -req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), - "requirements.txt") +req_file = os.path.join(local_dir, "requirements.txt") print(f"loading {NAME} reqs from {req_file}") run(f'"{sys.executable}" -m pip install -q -r "{req_file}"', f"Checking {NAME} requirements.", diff --git a/tagger/interrogator.py b/tagger/interrogator.py index f28be1f..edf67ac 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -3,7 +3,7 @@ from pathlib import Path import io import json -from jsonschema import validate, ValidationError +from jsonschema import validate import inspect from platform import uname from typing import Tuple, List, Dict, Callable @@ -13,7 +13,6 @@ from tqdm import tqdm from huggingface_hub import hf_hub_download -from modules.paths import extensions_dir from modules import shared from preload import default_ddp_path, default_onnx_path, root_dir from tagger import settings # pylint: disable=import-error From 07cd9bdd302c0f857c4defeee3f305258fb579eb Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 12 Aug 2023 19:46:54 +0200 Subject: [PATCH 13/78] gvi --- tagger/interrogator.py | 187 +++++++++++++++++++++-------------------- tagger/settings.py | 16 ++-- tagger/utils.py | 5 +- 3 files changed, 106 insertions(+), 102 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 9e56a2d..692e3ac 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -5,7 +5,7 @@ import json import inspect from re import match as re_match -from platform import system, uname +from platform import uname from typing import Tuple, List, Dict, Callable from pandas import read_csv from PIL import Image, UnidentifiedImageError @@ -13,7 +13,6 @@ from tqdm import tqdm from huggingface_hub import hf_hub_download -from modules.paths import extensions_dir from modules import shared from tagger import settings # pylint: disable=import-error from tagger.uiset import QData, IOData # pylint: disable=import-error @@ -113,7 +112,7 @@ def __init__(self, name: str) -> None: # run_mode 0 is dry run, 1 means run (alternating), 2 means disabled self.run_mode = 0 if hasattr(self, "large_batch_interrogate") else 2 - def load(self): + def load(self) -> bool: raise NotImplementedError() def large_batch_interrogate(self, images: List, dry_run=False) -> str: @@ -299,7 +298,8 @@ def interrogate( ]: # init model if self.model is None: - self.load() + if not self.load(): + return {}, {} import deepdanbooru.data as ddd @@ -333,34 +333,69 @@ def large_batch_interrogate(self, images: List, dry_run=False) -> str: raise NotImplementedError() -# FIXME this is silly, in what scenario would the env change from MacOS to -# another OS? TODO: remove if the author does not respond. -def get_onnxrt(): - try: - import onnxruntime - return onnxruntime - except ImportError: - # only one of these packages should be installed at one time in an env - # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime - # TODO: remove old package when the environment changes? - from launch import is_installed, run_pip - if not is_installed('onnxruntime'): - if system() == "Darwin": - package_name = "onnxruntime-silicon" +class HFInterrogator(Interrogator): + """ Interrogator for HuggingFace models """ + def __init__( + self, + name: str, + repo_id: str, + model_path: str, + tags_path: str, + ) -> None: + super().__init__(name) + self.repo_id = repo_id + self.model_path = model_path + self.tags_path = tags_path + self.model = None + self.local_model = None + self.local_tags = None + # tagger_hf_hub_down_opts contains args to hf_hub_download(). Parse + # and pass only the supported args. + + attrs = getattr(shared.opts, 'tagger_hf_hub_down_opts', + f'cache_dir="{Its.hf_cache}"') + attrs = [attr.split('=') for attr in map(str.strip, attrs.split(','))] + + signature = inspect.signature(hf_hub_download) + self.params = {} + for arg, val in attrs: + if arg in signature.parameters: + try: + tp = signature.parameters[arg].annotation(val) + self.params[arg] = tp(val) + except TypeError: + # unions, used for str of PathLike + if val[0] == val[-1] and val[0] in "'\"": + val = val[1:-1] + self.params[arg] = str(val) else: - package_name = "onnxruntime-gpu" - package = os.environ.get( - 'ONNXRUNTIME_PACKAGE', - package_name - ) + print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " + "Invalid for hf_hub_download() => ignored.") - run_pip(f'install {package}', 'onnxruntime') + def download(self) -> Tuple[str, str]: + print(f"Loading {self.name} model file from {self.repo_id}") + self.params['repo_id'] = self.repo_id - import onnxruntime - return onnxruntime + paths = [self.local_model, self.local_tags] + try: + for i, filename in enumerate([self.model_path, self.tags_path]): + self.params['filename'] = filename + paths[i] = hf_hub_download(**self.params) + except Exception as err: + if str(err)[:25] != "Offline mode is enabled.": + print(f"hf_hub_download({self.params}: {err}") -class WaifuDiffusionInterrogator(Interrogator): + return paths + + def load_model(self, model_path) -> None: + import onnxruntime + self.model = onnxruntime.InferenceSession(model_path, + providers=onnxrt_providers) + print(f'Loaded {self.name} model from {model_path}') + + +class WaifuDiffusionInterrogator(HFInterrogator): """ Interrogator for Waifu Diffusion models """ def __init__( self, @@ -368,43 +403,17 @@ def __init__( model_path='model.onnx', tags_path='selected_tags.csv', repo_id=None, - is_hf=True, ) -> None: - super().__init__(name) - self.repo_id = repo_id - self.model_path = model_path - self.tags_path = tags_path - self.tags = None - self.model = None + super().__init__(name, repo_id, model_path, tags_path) self.tags = None - self.local_model = None - self.local_tags = None - self.is_hf = is_hf - - def download(self) -> None: - mdir = Path(shared.models_path, 'interrogators') - if self.is_hf: - cache = getattr(shared.opts, 'tagger_hf_cache_dir', Its.hf_cache) - print(f"Loading {self.name} model file from {self.repo_id}, " - f"{self.model_path}") - - model_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.model_path, - cache_dir=cache) - tags_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.tags_path, - cache_dir=cache) - else: - model_path = self.local_model - tags_path = self.local_tags + def update_model_json(self, model_path, tags_path): download_model = { 'name': self.name, 'model_path': model_path, 'tags_path': tags_path, } + mdir = Path(shared.models_path, 'interrogators') mpath = Path(mdir, 'model.json') data = [download_model] @@ -425,16 +434,22 @@ def download(self) -> None: with io.open(mpath, 'w', encoding='utf-8') as filename: json.dump(data, filename) - return model_path, tags_path - def load(self) -> None: + def load(self) -> bool: model_path, tags_path = self.download() - ort = get_onnxrt() - self.model = ort.InferenceSession(model_path, - providers=onnxrt_providers) - print(f'Loaded {self.name} model from {self.repo_id}') + if not os.path.exists(model_path): + print(f'Model path {model_path} not found.') + return False + + if not os.path.exists(tags_path): + print(f'Tags path {tags_path} not found.') + return False + + self.load_model(model_path) + self.update_model_json(model_path, tags_path) self.tags = read_csv(tags_path) + return True def interrogate( self, @@ -445,7 +460,8 @@ def interrogate( ]: # init model if self.model is None: - self.load() + if not self.load(): + return {}, {} # code for converting the image and running the model is taken from the # link below. thanks, SmilingWolf! @@ -537,7 +553,8 @@ def large_batch_interrogate(self, images, dry_run=True) -> None: # init model if not hasattr(self, 'model') or self.model is None: - self.load() + if not self.load(): + return os.environ["TF_XLA_FLAGS"] = '--tf_xla_auto_jit=2 '\ '--tf_xla_cpu_global_jit' @@ -583,7 +600,7 @@ def pred_model(model): del os.environ["TF_XLA_FLAGS"] -class MLDanbooruInterrogator(Interrogator): +class MLDanbooruInterrogator(HFInterrogator): """ Interrogator for the MLDanbooru model. """ def __init__( self, @@ -592,40 +609,25 @@ def __init__( model_path: str, tags_path='classes.json', ) -> None: - super().__init__(name) - self.model_path = model_path - self.tags_path = tags_path - self.repo_id = repo_id + super().__init__(name, repo_id, model_path, tags_path) self.tags = None - self.model = None - - def download(self) -> Tuple[str, str]: - print(f"Loading {self.name} model file from {self.repo_id}") - cache = getattr(shared.opts, 'tagger_hf_cache_dir', Its.hf_cache) - model_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.model_path, - cache_dir=cache - ) - tags_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.tags_path, - cache_dir=cache - ) - return model_path, tags_path - - def load(self) -> None: + def load(self) -> bool: model_path, tags_path = self.download() - ort = get_onnxrt() - self.model = ort.InferenceSession(model_path, - providers=onnxrt_providers) - print(f'Loaded {self.name} model from {model_path}') + if not os.path.exists(model_path): + print(f'Model path {model_path} not found.') + return False + + if not os.path.exists(tags_path): + print(f'Tags path {tags_path} not found.') + return False with open(tags_path, 'r', encoding='utf-8') as filen: self.tags = json.load(filen) + return True + def interrogate( self, image: Image @@ -635,7 +637,8 @@ def interrogate( ]: # init model if self.model is None: - self.load() + if not self.load(): + return {}, {} image = dbimutils.fill_transparent(image) image = dbimutils.resize(image, 448) # TODO CUSTOMIZE diff --git a/tagger/settings.py b/tagger/settings.py index 8510468..07b58fe 100644 --- a/tagger/settings.py +++ b/tagger/settings.py @@ -3,14 +3,17 @@ from typing import List from modules import shared # pylint: disable=import-error from gradio import inputs as gr +from huggingface_hub import hf_hub_download # kaomoji from WD 1.4 tagger csv. thanks, Meow-San#5400! DEFAULT_KAMOJIS = '0_0, (o)_(o), +_+, +_-, ._., _, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||' # pylint: disable=line-too-long # noqa: E501 DEFAULT_OFF = '[name].[output_extension]' -HF_CACHE = os.environ.get('HF_HOME', os.environ.get('HUGGINGFACE_HUB_CACHE', - str(os.path.join(shared.models_path, 'interrogators')))) +HF_CACHE = os.environ.get( + 'HUGGINGFACE_HUB_CACHE', # defaults to "$HF_HOME/hub" + str(os.path.join(shared.models_path, 'interrogators'))) + def slider_wrapper(value, elem_id, **kwargs): # required or else gradio will throw errors @@ -121,13 +124,12 @@ def on_ui_settings(): section=section, ), ) - # see huggingface_hub guides/manage-cache shared.opts.add_option( - key='tagger_hf_cache_dir', + key='tagger_hf_hub_down_opts', info=shared.OptionInfo( - HF_CACHE, - label='HuggingFace cache directory, ' - 'see huggingface_hub guides/manage-cache', + str(f'cache_dir="{HF_CACHE}"'), + label='HuggingFace parameters, Comma delimited: arg=value, ' + 'see huggingface_hub docs for available or leave alone.', section=section, ), ) diff --git a/tagger/utils.py b/tagger/utils.py index e30101c..4c70afc 100644 --- a/tagger/utils.py +++ b/tagger/utils.py @@ -111,12 +111,11 @@ def tag_select_csvs_up_front(k): if path.name == 'wd-v1-4-convnextv2-tagger-v2': interrogators[path.name] = WaifuDiffusionInterrogator( path.name, - repo_id='SmilingWolf/SW-CV-ModelZoo', - is_hf=False + repo_id='SmilingWolf/SW-CV-ModelZoo' ) elif path.name == 'Z3D-E621-Convnext': interrogators[path.name] = WaifuDiffusionInterrogator( - 'Z3D-E621-Convnext', is_hf=False) + 'Z3D-E621-Convnext') else: raise NotImplementedError(f"Add {path.name} resolution similar" "to above here") From 77797f1bc0da531245942d65026f813b8a46e229 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 23 Jul 2023 21:55:10 +0200 Subject: [PATCH 14/78] a little more work is required --- tagger/interrogator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 692e3ac..31e07ac 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -364,10 +364,15 @@ def __init__( tp = signature.parameters[arg].annotation(val) self.params[arg] = tp(val) except TypeError: - # unions, used for str of PathLike - if val[0] == val[-1] and val[0] in "'\"": - val = val[1:-1] - self.params[arg] = str(val) + if val == 'None': + self.params[arg] = None + elif arg == 'token' and val in {'True', 'False'}: + self.params[arg] = val == 'True' + else: + # unions, used for str or PathLike + if val[0] == val[-1] and val[0] in "'\"": + val = val[1:-1] + self.params[arg] = str(val) else: print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " "Invalid for hf_hub_download() => ignored.") From 6d81c4c2f4f7956fecd7b8439d86665b813198e9 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 12 Aug 2023 19:53:57 +0200 Subject: [PATCH 15/78] add interrogators.json move refresh to interrogator as a static, and pick up the configured interrogators there. presets in tagger/presets.py and tagger/utils.py can go. write info alongside model so we can check its up to date status # Conflicts: # tagger/ui.py # tagger/utils.py # Conflicts: # tagger/interrogator.py # tagger/ui.py --- .gitignore | 4 +- interrogators.json | 69 ++++++++++++++++ tagger/api.py | 10 +-- tagger/interrogator.py | 183 +++++++++++++++++++++++++++++++++++------ tagger/preset.py | 7 ++ tagger/ui.py | 36 ++++---- tagger/utils.py | 130 ----------------------------- 7 files changed, 258 insertions(+), 181 deletions(-) create mode 100644 interrogators.json delete mode 100644 tagger/utils.py diff --git a/.gitignore b/.gitignore index fd6106c..8ea680e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,5 @@ __pycache__/ .vscode/ .venv/ .env - -presets/ \ No newline at end of file +presets/ +addons/ diff --git a/interrogators.json b/interrogators.json new file mode 100644 index 0000000..ce55031 --- /dev/null +++ b/interrogators.json @@ -0,0 +1,69 @@ +{ + "mld-caformer.dec-5-97527" : { + "class" : "MLDanbooruInterrogator", + "repo_specs" : { + "model_path" : "ml_caformer_m36_dec-5-97527.onnx", + "name" : "ML-Danbooru Caformer dec-5-97527", + "repo_id" : "deepghs/ml-danbooru-onnx" + } + }, + "mld-tresnetd.6-30000" : { + "class" : "MLDanbooruInterrogator", + "repo_specs" : { + "model_path" : "TResnet-D-FLq_ema_6-30000.onnx", + "name" : "ML-Danbooru TResNet-D 6-30000", + "repo_id" : "deepghs/ml-danbooru-onnx" + } + }, + "wd-v1-4-moat-tagger.v2" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 moat tagger v2", + "repo_id" : "SmilingWolf/wd-v1-4-moat-tagger-v2" + } + }, + "wd14-convnext.v1" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ConvNeXT v1", + "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger" + } + }, + "wd14-convnext.v2" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ConvNeXT v2", + "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger-v2" + } + }, + "wd14-convnextv2.v1" : { + "remark" : "the repo_id name is misleading, but it's v1", + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ConvNeXTV2 v1", + "repo_id" : "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" + } + }, + "wd14-swinv2-v1" : { + "remark" : "the repo_id name is misleading, but it's v1", + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 SwinV2 v1", + "repo_id" : "SmilingWolf/wd-v1-4-swinv2-tagger-v2" + } + }, + "wd14-vit.v1" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ViT v1", + "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger" + } + }, + "wd14-vit.v2" : { + "class" : "WaifuDiffusionInterrogator", + "repo_specs" : { + "name" : "WD14 ViT v2", + "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger-v2" + } + } +} diff --git a/tagger/api.py b/tagger/api.py index c50de70..e64ee30 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -9,9 +9,9 @@ from fastapi import FastAPI, Depends, HTTPException from fastapi.security import HTTPBasic, HTTPBasicCredentials -from tagger import utils # pylint: disable=import-error from tagger import api_models as models # pylint: disable=import-error from tagger.uiset import QData # pylint: disable=import-error +from tagger.interrogator import Interrogator class Api: @@ -78,11 +78,11 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if req.image is None: raise HTTPException(404, 'Image not found') - if req.model not in utils.interrogators.keys(): + if req.model not in Interrogator.entries.keys(): raise HTTPException(404, 'Model not found') image = decode_base64_to_image(req.image) - interrogator = utils.interrogators[req.model] + interrogator = Interrogator.entries[req.model] with self.queue_lock: QData.tags.clear() @@ -102,13 +102,13 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): def endpoint_interrogators(self): return models.InterrogatorsResponse( - models=list(utils.interrogators.keys()) + models=list(Interrogator.entries.keys()) ) def endpoint_unload_interrogators(self): unloaded_models = 0 - for i in utils.interrogators.values(): + for i in Interrogator.entries.values(): if i.unload(): unloaded_models = unloaded_models + 1 diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 31e07ac..383a372 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -60,6 +60,7 @@ class Interrogator: } output = None odd_increment = 0 + entries = {} @classmethod def flip(cls, key): @@ -92,6 +93,98 @@ def setter(val) -> Tuple[str, str]: return setter + @classmethod + def refresh(cls) -> List[str]: + """Refreshes the interrogator entries""" + if len(cls.entries) == 0: + it_path = Path(os.path.join( + extensions_dir, + 'stable-diffusion-webui-wd14-tagger/interrogators.json' + )) + if not it_path.exists(): + raise FileNotFoundError(f'{it_path} not found.') + + with open(it_path) as filename: + raw = json.load(filename) + + for name, it in raw.items(): + if it["class"] == "DeepDanbooruInterrogator": + It_type = DeepDanbooruInterrogator + elif it["class"] == "WaifuDiffusionInterrogator": + It_type = WaifuDiffusionInterrogator + elif it["class"] == "MLDanbooruInterrogator": + It_type = MLDanbooruInterrogator + else: + raise ValueError(f'Unimplemented: {it["class"]}') + + cls.entries[name] = It_type(**it["repo_specs"]) + + # load deepdanbooru project + ddp_path = getattr(shared.cmd_opts, 'deepdanbooru_projects_path', + default_ddp_path) + onnx_path = getattr(shared.cmd_opts, 'onnxtagger_path', + default_onnx_path) + os.makedirs(ddp_path, exist_ok=True) + os.makedirs(onnx_path, exist_ok=True) + + for path in os.scandir(ddp_path): + print(f"Scanning {path} as deepdanbooru project") + if not path.is_dir(): + print(f"Warning: {path} is not a directory, skipped") + continue + + if not Path(path, 'project.json').is_file(): + print(f"Warning: {path} has no project.json, skipped") + continue + + cls.entries[path.name] = DeepDanbooruInterrogator(path.name, path) + # scan for onnx models as well + for path in os.scandir(onnx_path): + print(f"Scanning {path} as onnx model") + if not path.is_dir(): + print(f"Warning: {path} is not a directory, skipped") + continue + + onnx_files = [] + for file_name in os.scandir(path): + if file_name.name.endswith('.onnx'): + onnx_files.append(file_name) + + if len(onnx_files) != 1: + print(f"Warning: {path}: multiple .onnx models => skipped") + continue + local_path = Path(path, onnx_files[0].name) + + csv = [x for x in os.scandir(path) if x.name.endswith('.csv')] + if len(csv) == 0: + print(f"Warning: {path}: no selected tags .csv file, skipped") + continue + + def tag_select_csvs_up_front(k): + k = k.name.lower() + return -1 if "tag" in k or "select" in k else 1 + + csv.sort(key=tag_select_csvs_up_front) + tags_path = Path(path, csv[0]) + + if path.name not in cls.entries: + if path.name == 'wd-v1-4-convnextv2-tagger-v2': + cls.entries[path.name] = WaifuDiffusionInterrogator( + path.name, + repo_id='SmilingWolf/SW-CV-ModelZoo' + ) + elif path.name == 'Z3D-E621-Convnext': + cls.entries[path.name] = WaifuDiffusionInterrogator( + 'Z3D-E621-Convnext') + else: + raise NotImplementedError(f"Add {path.name} resolution " + "similar to above here") + + cls.entries[path.name].local_model = str(local_path) + cls.entries[path.name].local_tags = str(tags_path) + + return sorted(i.name for i in cls.entries.values()) + @staticmethod def load_image(path: str) -> Image: try: @@ -111,6 +204,9 @@ def __init__(self, name: str) -> None: self.tags = None # run_mode 0 is dry run, 1 means run (alternating), 2 means disabled self.run_mode = 0 if hasattr(self, "large_batch_interrogate") else 2 + # default path if not overridden by download + self.local_model = None + self.local_tags = None def load(self) -> bool: raise NotImplementedError() @@ -338,59 +434,96 @@ class HFInterrogator(Interrogator): def __init__( self, name: str, - repo_id: str, model_path: str, tags_path: str, + **kwargs, ) -> None: super().__init__(name) - self.repo_id = repo_id self.model_path = model_path self.tags_path = tags_path self.model = None - self.local_model = None - self.local_tags = None # tagger_hf_hub_down_opts contains args to hf_hub_download(). Parse # and pass only the supported args. + self.repo_specs = {'repo_id', 'revision', 'library_name', + 'library_version'} + self.hf_params = {} + for k in self.repo_specs: + if k in kwargs: + self.hf_params[k] = kwargs[k] + + if 'repo_id' not in self.hf_params: + print(f"Error: interrogatos.json: HuggingFace model {self.name} " + "lacks a repo_id. If not already local, download may fail.") + attrs = getattr(shared.opts, 'tagger_hf_hub_down_opts', f'cache_dir="{Its.hf_cache}"') attrs = [attr.split('=') for attr in map(str.strip, attrs.split(','))] signature = inspect.signature(hf_hub_download) - self.params = {} for arg, val in attrs: - if arg in signature.parameters: + if arg == 'filename' or arg in self.repo_specs: + + print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " + "Specific options need to go in the interrogators.json.") + + elif arg in signature.parameters: try: tp = signature.parameters[arg].annotation(val) - self.params[arg] = tp(val) + self.hf_params[arg] = tp(val) + except TypeError: + # unions, used for str or PathLike and a few. if val == 'None': - self.params[arg] = None + self.hf_params[arg] = None elif arg == 'token' and val in {'True', 'False'}: - self.params[arg] = val == 'True' + self.hf_params[arg] = val == 'True' else: - # unions, used for str or PathLike if val[0] == val[-1] and val[0] in "'\"": val = val[1:-1] - self.params[arg] = str(val) + self.hf_params[arg] = str(val) else: print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " "Invalid for hf_hub_download() => ignored.") def download(self) -> Tuple[str, str]: - print(f"Loading {self.name} model file from {self.repo_id}") - self.params['repo_id'] = self.repo_id + repo_id = self.hf_params.get('repo_id', '(?)') + print(f"Loading {self.name} model file from {repo_id}") paths = [self.local_model, self.local_tags] - try: - for i, filename in enumerate([self.model_path, self.tags_path]): - self.params['filename'] = filename - paths[i] = hf_hub_download(**self.params) - except Exception as err: - if str(err)[:25] != "Offline mode is enabled.": - print(f"hf_hub_download({self.params}: {err}") + data = {} + for k in self.repo_specs: + if k in self.hf_params: + data[k] = self.hf_params[k] + + # check if the model is up to date + info_path = Path(self.local_model).with_suffix('.info') + if info_path.exists(): + + if all(os.path.exists(p) for p in paths): + with open(info_path, 'r') as filen: + try: + old_data = json.load(filen) + if old_data == data: + print(f"Model {self.name} is up to date.") + return paths + except json.decoder.JSONDecodeError: + pass + try: + for i, filen in enumerate([self.model_path, self.tags_path]): + self.hf_params['filename'] = filen + paths[i] = hf_hub_download(**self.hf_params) + except Exception as err: + if str(err)[:25] != "Offline mode is enabled.": + print(f"hf_hub_download({self.hf_params}: {err}") + return paths + + # write the repo_specs to a json alongside the model so we can + # check if the model is up to date + with open(info_path, 'w') as filen: + json.dump(data, filen) return paths def load_model(self, model_path) -> None: @@ -407,9 +540,9 @@ def __init__( name: str, model_path='model.onnx', tags_path='selected_tags.csv', - repo_id=None, + **kwargs, ) -> None: - super().__init__(name, repo_id, model_path, tags_path) + super().__init__(name, model_path, tags_path, **kwargs) self.tags = None def update_model_json(self, model_path, tags_path): @@ -610,11 +743,11 @@ class MLDanbooruInterrogator(HFInterrogator): def __init__( self, name: str, - repo_id: str, model_path: str, tags_path='classes.json', + **kwargs ) -> None: - super().__init__(name, repo_id, model_path, tags_path) + super().__init__(name, model_path, tags_path, **kwargs) self.tags = None def load(self) -> bool: @@ -628,6 +761,8 @@ def load(self) -> bool: print(f'Tags path {tags_path} not found.') return False + self.load_model(model_path) + with open(tags_path, 'r', encoding='utf-8') as filen: self.tags = json.load(filen) diff --git a/tagger/preset.py b/tagger/preset.py index 9189535..714aa13 100644 --- a/tagger/preset.py +++ b/tagger/preset.py @@ -6,6 +6,8 @@ from pathlib import Path from gradio.context import Context from modules.images import sanitize_filename_part # pylint: disable=E0401 +from modules.paths import extensions_dir +from modules import scripts PresetDict = Dict[str, Dict[str, any]] @@ -106,3 +108,8 @@ def list(self) -> List[str]: presets.append(self.default_filename) return presets + + +preset = Preset(Path( + os.path.join(extensions_dir, 'stable-diffusion-webui-wd14-tagger/presets') +)) diff --git a/tagger/ui.py b/tagger/ui.py index 8102dc4..cdd2190 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -20,9 +20,9 @@ def tf_version(): from modules.call_queue import wrap_gradio_gpu_call except ImportError: from webui import wrap_gradio_gpu_call # pylint: disable=import-error -from tagger import utils # pylint: disable=import-error from tagger.interrogator import Interrogator as It # pylint: disable=E0401 from tagger.uiset import IOData, QData # pylint: disable=import-error +from tagger.preset import preset TAG_INPUTS = ["add", "keep", "exclude", "search", "replace"] COMMON_OUTPUT = Tuple[ @@ -39,7 +39,7 @@ def unload_interrogators() -> List[str]: unloaded_models = 0 remaining_models = '' - for i in utils.interrogators.values(): + for i in It.entries.values(): if i.unload(): unloaded_models = unloaded_models + 1 elif i.model is not None: @@ -73,7 +73,7 @@ def on_interrogate( getattr(QData, "update_" + part)(val) It.input[part] = val - interrogator: It = next((i for i in utils.interrogators.values() if + interrogator: It = next((i for i in It.entries.values() if i.name == name), None) if interrogator is None: return None, None, None, None, None, f"'{name}': invalid interrogator" @@ -105,7 +105,7 @@ def on_interrogate_image_submit( if image is None: return None, None, None, None, None, 'No image selected' - interrogator: It = next((i for i in utils.interrogators.values() if + interrogator: It = next((i for i in It.entries.values() if i.name == name), None) if interrogator is None: return None, None, None, None, None, f"'{name}': invalid interrogator" @@ -193,13 +193,13 @@ def on_ui_tabs(): ) with gr.TabItem(label='Batch from directory'): - input_glob = utils.preset.component( + input_glob = preset.component( gr.Textbox, value='', label='Input directory - See also settings tab.', placeholder='/path/to/images or to/images/**/*' ) - output_dir = utils.preset.component( + output_dir = preset.component( gr.Textbox, value=It.input["output_dir"], label='Output directory', @@ -213,7 +213,7 @@ def on_ui_tabs(): ) with gr.Row(variant='compact'): with gr.Column(variant='panel'): - large_query = utils.preset.component( + large_query = preset.component( gr.Checkbox, label='huge batch query (TF 2.10, ' 'experimental)', @@ -222,7 +222,7 @@ def on_ui_tabs(): version.parse('2.10') ) with gr.Column(variant='panel'): - save_tags = utils.preset.component( + save_tags = preset.component( gr.Checkbox, label='Save to tags files', value=True @@ -257,12 +257,8 @@ def on_ui_tabs(): ) with gr.Row(variant='compact'): - def refresh(): - utils.refresh_interrogators() - return sorted(x.name for x in utils.interrogators - .values()) - interrogator_names = refresh() - interrogator = utils.preset.component( + interrogator_names = It.refresh() + interrogator = preset.component( gr.Dropdown, label='Interrogator', choices=interrogator_names, @@ -276,7 +272,7 @@ def refresh(): ui.create_refresh_button( interrogator, lambda: None, - lambda: {'choices': refresh()}, + lambda: {'choices': It.refresh()}, 'refresh_interrogator' ) @@ -352,7 +348,7 @@ def refresh(): variant='secondary' ) with gr.Column(variant='compact'): - tag_search_selection = utils.preset.component( + tag_search_selection = preset.component( gr.Textbox, label='Multi string search: part1, part2.. ' '(Enter key to update)', @@ -411,11 +407,11 @@ def refresh(): save_tags.input(fn=IOData.flip_save_tags(), inputs=[], outputs=[]) # Preset and unload buttons - selected_preset.change(fn=utils.preset.apply, inputs=[selected_preset], - outputs=[*utils.preset.components, info]) + selected_preset.change(fn=preset.apply, inputs=[selected_preset], + outputs=[*preset.components, info]) - save_preset_button.click(fn=utils.preset.save, inputs=[selected_preset, - *utils.preset.components], outputs=[info]) + save_preset_button.click(fn=preset.save, inputs=[selected_preset, + *preset.components], outputs=[info]) unload_all_models.click(fn=unload_interrogators, outputs=[info]) diff --git a/tagger/utils.py b/tagger/utils.py deleted file mode 100644 index 4c70afc..0000000 --- a/tagger/utils.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Utility functions for the tagger module""" -import os - -from typing import List, Dict -from pathlib import Path - -from modules import shared, scripts # pylint: disable=import-error -from modules.shared import models_path # pylint: disable=import-error - -default_ddp_path = Path(models_path, 'deepdanbooru') -default_onnx_path = Path(models_path, 'TaggerOnnx') -from tagger.preset import Preset # pylint: disable=import-error -from tagger.interrogator import Interrogator, DeepDanbooruInterrogator, \ - MLDanbooruInterrogator # pylint: disable=E0401 # noqa: E501 -from tagger.interrogator import WaifuDiffusionInterrogator # pylint: disable=E0401 # noqa: E501 - -preset = Preset(Path(scripts.basedir(), 'presets')) - -interrogators: Dict[str, Interrogator] = { - 'wd14-vit.v1': WaifuDiffusionInterrogator( - 'WD14 ViT v1', - repo_id='SmilingWolf/wd-v1-4-vit-tagger' - ), - 'wd14-vit.v2': WaifuDiffusionInterrogator( - 'WD14 ViT v2', - repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2', - ), - 'wd14-convnext.v1': WaifuDiffusionInterrogator( - 'WD14 ConvNeXT v1', - repo_id='SmilingWolf/wd-v1-4-convnext-tagger' - ), - 'wd14-convnext.v2': WaifuDiffusionInterrogator( - 'WD14 ConvNeXT v2', - repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2', - ), - 'wd14-convnextv2.v1': WaifuDiffusionInterrogator( - 'WD14 ConvNeXTV2 v1', - # the name is misleading, but it's v1 - repo_id='SmilingWolf/wd-v1-4-convnextv2-tagger-v2', - ), - 'wd14-swinv2-v1': WaifuDiffusionInterrogator( - 'WD14 SwinV2 v1', - # again misleading name - repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2', - ), - 'wd-v1-4-moat-tagger.v2': WaifuDiffusionInterrogator( - 'WD14 moat tagger v2', - repo_id='SmilingWolf/wd-v1-4-moat-tagger-v2' - ), - 'mld-caformer.dec-5-97527': MLDanbooruInterrogator( - 'ML-Danbooru Caformer dec-5-97527', - repo_id='deepghs/ml-danbooru-onnx', - model_path='ml_caformer_m36_dec-5-97527.onnx' - ), - 'mld-tresnetd.6-30000': MLDanbooruInterrogator( - 'ML-Danbooru TResNet-D 6-30000', - repo_id='deepghs/ml-danbooru-onnx', - model_path='TResnet-D-FLq_ema_6-30000.onnx' - ), -} - - -def refresh_interrogators() -> List[str]: - """Refreshes the interrogators list""" - # load deepdanbooru project - ddp_path = shared.cmd_opts.deepdanbooru_projects_path - if ddp_path is None: - ddp_path = default_ddp_path - onnx_path = shared.cmd_opts.onnxtagger_path - if onnx_path is None: - onnx_path = default_onnx_path - os.makedirs(ddp_path, exist_ok=True) - os.makedirs(onnx_path, exist_ok=True) - - for path in os.scandir(ddp_path): - print(f"Scanning {path} as deepdanbooru project") - if not path.is_dir(): - print(f"Warning: {path} is not a directory, skipped") - continue - - if not Path(path, 'project.json').is_file(): - print(f"Warning: {path} has no project.json, skipped") - continue - - interrogators[path.name] = DeepDanbooruInterrogator(path.name, path) - # scan for onnx models as well - for path in os.scandir(onnx_path): - print(f"Scanning {path} as onnx model") - if not path.is_dir(): - print(f"Warning: {path} is not a directory, skipped") - continue - - onnx_files = [x for x in os.scandir(path) if x.name.endswith('.onnx')] - if len(onnx_files) != 1: - print(f"Warning: {path} requires exactly one .onnx model, skipped") - continue - local_path = Path(path, onnx_files[0].name) - - csv = [x for x in os.scandir(path) if x.name.endswith('.csv')] - if len(csv) == 0: - print(f"Warning: {path} has no selected tags .csv file, skipped") - continue - - def tag_select_csvs_up_front(k): - sum(-1 if t in k.name.lower() else 1 for t in ["tag", "select"]) - - csv.sort(key=tag_select_csvs_up_front) - tags_path = Path(path, csv[0]) - - if path.name not in interrogators: - if path.name == 'wd-v1-4-convnextv2-tagger-v2': - interrogators[path.name] = WaifuDiffusionInterrogator( - path.name, - repo_id='SmilingWolf/SW-CV-ModelZoo' - ) - elif path.name == 'Z3D-E621-Convnext': - interrogators[path.name] = WaifuDiffusionInterrogator( - 'Z3D-E621-Convnext') - else: - raise NotImplementedError(f"Add {path.name} resolution similar" - "to above here") - - interrogators[path.name].local_model = str(local_path) - interrogators[path.name].local_tags = str(tags_path) - - return sorted(interrogators.keys()) - - -def split_str(string: str, separator=',') -> List[str]: - return [x.strip() for x in string.split(separator) if x] From f7f8f1871936d596ba5bb98138b66447991db69f Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Thu, 10 Aug 2023 20:00:54 +0200 Subject: [PATCH 16/78] fix again --- tagger/interrogator.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 383a372..3282977 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -120,10 +120,12 @@ def refresh(cls) -> List[str]: cls.entries[name] = It_type(**it["repo_specs"]) # load deepdanbooru project - ddp_path = getattr(shared.cmd_opts, 'deepdanbooru_projects_path', - default_ddp_path) - onnx_path = getattr(shared.cmd_opts, 'onnxtagger_path', - default_onnx_path) + ddp_path = shared.cmd_opts.deepdanbooru_projects_path + if ddp_path is None: + ddp_path = Path(shared.models_path, 'deepdanbooru') + onnx_path = shared.cmd_opts.onnx_path + if onnx_path is None: + onnx_path = Path(shared.models_path, 'TaggerOnnx') os.makedirs(ddp_path, exist_ok=True) os.makedirs(onnx_path, exist_ok=True) From 39d4fd2a91bca59de369747e69bc13edea22e79e Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 12 Aug 2023 14:54:08 -0400 Subject: [PATCH 17/78] Actually use `--additional-device-ids` arg --- tagger/interrogator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index df32a56..1ae5cc2 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -30,13 +30,13 @@ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958 onnxrt_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] -if shared.cmd_opts.device_id is not None: - m = re_match(r'([cg])pu:\d+$', shared.cmd_opts.device_id) +if shared.cmd_opts.additional_device_ids is not None: + m = re_match(r'([cg])pu:\d+$', shared.cmd_opts.additional_device_ids) if m is None: raise ValueError('--device-id is not cpu: or gpu:') if m.group(1) == 'c': onnxrt_providers.pop(0) - TF_DEVICE_NAME = f'/{shared.cmd_opts.device_id}' + TF_DEVICE_NAME = f'/{shared.cmd_opts.additional_device_ids}' elif use_cpu: TF_DEVICE_NAME = '/cpu:0' onnxrt_providers.pop(0) From ae8345bf87372b62baa812d757100f01b1128f23 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 12 Aug 2023 21:18:30 +0200 Subject: [PATCH 18/78] currently only one device is supported --- preload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/preload.py b/preload.py index ecd0bdb..92ff73f 100644 --- a/preload.py +++ b/preload.py @@ -1,5 +1,4 @@ """ Preload module for DeepDanbooru or onnxtagger. """ -from pathlib import Path from argparse import ArgumentParser root_dir = Path(__file__).parent.parent @@ -20,8 +19,9 @@ def preload(parser: ArgumentParser): type=str, help='Path to directory with Onnyx project(s).' ) + # TODO allow using devices in parallel, specified as comma separed list parser.add_argument( '--additional-device-ids', type=str, - help='Extra device ID to use. cpu:0,gpu:1..', + help='Device ID to use. cpu:0, gpu:0 or gpu:1, etc.', ) From aa92dbacfb778f3249b654e640b31c206cf74adb Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 12 Aug 2023 23:30:48 +0200 Subject: [PATCH 19/78] Revert "improve types and other cleanups" This reverts commit 0c1fd970d96518a7ab4a5ebac38b942b5473d01d. --- tagger/ui.py | 37 +++++++---------- tagger/uiset.py | 107 ++++++++++++++++++++++-------------------------- 2 files changed, 65 insertions(+), 79 deletions(-) diff --git a/tagger/ui.py b/tagger/ui.py index cdd2190..34c4689 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -1,5 +1,5 @@ """ This module contains the ui for the tagger tab. """ -from typing import Dict, Tuple, List, Optional +from typing import Dict, Tuple, List import gradio as gr import re from PIL import Image @@ -26,11 +26,11 @@ def tf_version(): TAG_INPUTS = ["add", "keep", "exclude", "search", "replace"] COMMON_OUTPUT = Tuple[ - Optional[str], # tags as string - Optional[str], # discarded tags as string - Optional[Dict[str, float]], # rating confidences - Optional[Dict[str, float]], # tag confidences - Optional[Dict[str, float]], # excluded tag confidences + str, # tags as string + str, # discarded tags as string + Dict[str, float], # rating confidences + Dict[str, float], # tag confidences + Dict[str, float], # excluded tag confidences str, # error message ] @@ -52,7 +52,7 @@ def unload_interrogators() -> List[str]: "not be unloaded, a known issue." QData.clear(1) - return [f'{unloaded_models} model(s) unloaded{remaining_models}'] + return (f'{unloaded_models} model(s) unloaded{remaining_models}',) def on_interrogate( @@ -90,10 +90,9 @@ def on_interrogate_image(*args) -> COMMON_OUTPUT: # hack brcause image interrogaion occurs twice It.odd_increment = It.odd_increment + 1 if It.odd_increment & 1 == 1: - return (None, None, None, None, None, '') + return (None, None, None, None, None, '') return on_interrogate_image_submit(*args) - def on_interrogate_image_submit( image: Image, name: str, filt: str, *args ) -> COMMON_OUTPUT: @@ -116,7 +115,7 @@ def on_interrogate_image_submit( def move_selection_to_input( filt: str, field: str -) -> Tuple[Optional[str], Optional[str], str]: +) -> Tuple[str, str, str]: """ moves the selected to the input field """ if It.output is None: return (None, None, '') @@ -139,15 +138,11 @@ def move_selection_to_input( return ('', data, info) -def move_selection_to_keep( - tag_search_filter: str -) -> Tuple[Optional[str], Optional[str], str]: +def move_selection_to_keep(tag_search_filter: str) -> Tuple[str, str, str]: return move_selection_to_input(tag_search_filter, "keep") -def move_selection_to_exclude( - tag_search_filter: str -) -> Tuple[Optional[str], Optional[str], str]: +def move_selection_to_exclude(tag_search_filter: str) -> Tuple[str, str, str]: return move_selection_to_input(tag_search_filter, "exclude") @@ -462,13 +457,11 @@ def on_ui_tabs(): [tag_input[tag] for tag in TAG_INPUTS] # interrogation events - image_submit.click( - fn=wrap_gradio_gpu_call(on_interrogate_image_submit), - inputs=[image] + common_input, outputs=common_output) + image_submit.click(fn=wrap_gradio_gpu_call(on_interrogate_image_submit), + inputs=[image] + common_input, outputs=common_output) - image.change( - fn=wrap_gradio_gpu_call(on_interrogate_image), - inputs=[image] + common_input, outputs=common_output) + image.change(fn=wrap_gradio_gpu_call(on_interrogate_image), + inputs=[image] + common_input, outputs=common_output) batch_submit.click(fn=wrap_gradio_gpu_call(on_interrogate), inputs=[input_glob, output_dir] + common_input, diff --git a/tagger/uiset.py b/tagger/uiset.py index e0f2c6d..c9adde6 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -1,6 +1,6 @@ """ for handling ui settings """ -from typing import List, Dict, Tuple, Callable, Set, Optional, Pattern +from typing import List, Dict, Tuple, Callable, Set, Optional import os from pathlib import Path from glob import glob @@ -30,10 +30,10 @@ # interrogator return type ItRetTP = Tuple[ - Optional[Dict[str, float]], # rating confidences - Optional[Dict[str, float]], # tag confidences - Optional[Dict[str, float]], # excluded tag confidences - str, # error message + Dict[str, float], # rating confidences + Dict[str, float], # tag confidences + Dict[str, float], # excluded tag confidences + str, # error message ] @@ -42,12 +42,9 @@ class IOData: last_path_mtimes = None base_dir = None output_root = None - paths: List[Tuple[ - Path, Optional[Path], Optional[Path], Optional[str] - ]] = [] + paths = [] save_tags = True - err: Set[str] = set() - base_dir_last = None + err = set() @classmethod def error_msg(cls) -> str: @@ -55,7 +52,7 @@ def error_msg(cls) -> str: "
    " @classmethod - def flip_save_tags(cls) -> Callable: + def flip_save_tags(cls) -> callable: def toggle(): cls.save_tags = not cls.save_tags return toggle @@ -69,7 +66,7 @@ def update_output_dir(cls, output_dir: str) -> None: """ update output directory, and set input and output paths """ pout = Path(output_dir) if pout != cls.output_root: - paths = [str(x[0]) for x in cls.paths] + paths = [x[0] for x in cls.paths] cls.paths = [] cls.output_root = pout cls.set_batch_io(paths) @@ -85,13 +82,13 @@ def get_hashes(cls) -> Set[str]: """ get hashes of all files """ ret = set() for entries in cls.paths: - if entries[3] is not None: + if len(entries) == 4: ret.add(entries[3]) else: # if there is no checksum, calculate it - image = Image.open(Path(entries[0])) + image = Image.open(entries[0]) checksum = cls.get_bytes_hash(image.tobytes()) - entries = (entries[0], entries[1], entries[2], checksum) + entries.append(checksum) ret.add(checksum) return ret @@ -117,7 +114,7 @@ def update_input_glob(cls, input_glob: str) -> None: return cls.err.discard(msg) - recursive = getattr(shared.opts, 'tagger_batch_recursive', True) + recursive = getattr(shared.opts, 'tagger_batch_recursive', '') path_mtimes = [] for filename in glob(input_glob, recursive=recursive): ext = os.path.splitext(filename)[1].lower() @@ -154,10 +151,10 @@ def set_batch_io(cls, paths: List[str]) -> None: """ set input and output paths for batch mode """ checked_dirs = set() cls.paths = [] - for filename in paths: - path = Path(filename) + for path in paths: + path = Path(path) if not cls.save_tags: - cls.paths.append((path, None, None, None)) + cls.paths.append([path, '', '']) continue # guess the output path @@ -177,61 +174,57 @@ def set_batch_io(cls, paths: List[str]) -> None: except (TypeError, ValueError): cls.err.add(msg) - if cls.output_root is None: - raise ValueError output_dir = cls.output_root.joinpath( *path.parts[base_dir_last_idx + 1:]).parent tags_out = output_dir.joinpath(formatted_output_filename) if output_dir in checked_dirs: - cls.paths.append((path, tags_out, None, None)) + cls.paths.append([path, tags_out, '']) else: checked_dirs.add(output_dir) if os.path.exists(output_dir): msg = 'output_dir: not a directory.' if os.path.isdir(output_dir): - cls.paths.append((path, tags_out, None, None)) + cls.paths.append([path, tags_out, '']) cls.err.discard(msg) else: cls.err.add(msg) else: - cls.paths.append((path, tags_out, output_dir, None)) + cls.paths.append([path, tags_out, output_dir]) class QData: """ Query data: contains parameters for the query """ - add_tags: List[str] = [] - keep_tags: Set[str] = set() - exclude_tags: List[str] = [] - search_tags: Dict[int, Pattern[str]] = {} - replace_tags: List[str] = [] + add_tags = [] + keep_tags = set() + exclude_tags = [] + search_tags = {} + replace_tags = [] threshold = 0.35 tag_frac_threshold = 0.05 # read from db.json, update with what should be written to db.json: json_db = None - weighed: Tuple[ - Dict[str, List[float]], - Dict[str, List[float]] - ] = (defaultdict(list), defaultdict(list)) - query: Dict[str, Tuple[str, int]] = {} + weighed = (defaultdict(list), defaultdict(list)) + query = {} # representing the (cumulative) current interrogations - ratings: Dict[str, float] = defaultdict(float) - tags: Dict[str, List[float]] = defaultdict(list) - discarded_tags: Dict[str, List[float]] = defaultdict(list) - in_db: Dict[ - int, - Tuple[str, str, str, Dict[str, float], Dict[str, float]] - ] = {} - for_tags_file: Dict[ - str, Dict[str, float] - ] = defaultdict(lambda: defaultdict(float)) + ratings = defaultdict(float) + tags = defaultdict(list) + discarded_tags = defaultdict(list) + in_db = {} + for_tags_file = defaultdict(lambda: defaultdict(float)) had_new = False - err: Set[str] = set() - image_dups: Dict[str, Set[str]] = defaultdict(set) + err = set() + image_dups = defaultdict(set) + + @classmethod + def set(cls, key: str) -> Callable[[str], Tuple[str]]: + def setter(val) -> Tuple[str]: + setattr(cls, key, val) + return setter @classmethod def set(cls, key: str) -> Callable[[str], Tuple[str]]: @@ -329,7 +322,7 @@ def update_add(cls, add: str) -> None: shared.opts.tagger_count_threshold = len(cls.add_tags) @staticmethod - def compile_rex(rex: str) -> Optional[Pattern[str]]: + def compile_rex(rex: str) -> Optional: if rex in {'', '^', '$', '^$'}: return None if rex[0] == '^': @@ -382,7 +375,7 @@ def update_replace(cls, replace: str) -> None: cls.err.discard(msg) @classmethod - def get_i_wt(cls, stored: float) -> Tuple[int, float]: + def get_i_wt(cls, stored: int) -> Tuple[int, float]: """ in db.json or QData.weighed, the weights & increment in the list are encoded. Each filestamp-interrogation corresponds to an incrementing @@ -455,8 +448,8 @@ def get_index(cls, fi_key: str, path='') -> int: @classmethod def single_data(cls, fi_key: str) -> None: """ get tags and ratings for filestamp-interrogator """ - index = cls.query[fi_key][1] - data: Tuple[Dict[str, float], Dict[str, float]] = ({}, {}) + index = cls.query.get(fi_key)[1] + data = ({}, {}) for j in range(2): for ent, lst in cls.weighed[j].items(): for i, val in map(cls.get_i_wt, lst): @@ -480,11 +473,11 @@ def correct_tag(cls, tag: str) -> str: if getattr(shared.opts, 'tagger_escape', False): tag = re_special.sub(r'\\\1', tag) # tag_escape_pattern - if len(cls.search_tags) != len(cls.replace_tags): + if len(cls.search_tags) == len(cls.replace_tags): for i, regex in cls.search_tags.items(): - m = re_match(regex, tag) - if m: - return re_sub(regex, cls.replace_tags[i], tag) + if re_match(regex, tag): + tag = re_sub(regex, cls.replace_tags[i], tag) + break return tag @@ -625,10 +618,10 @@ def finalize(cls, count: int) -> ItRetTP: for file, remaining_tags in cls.for_tags_file.items(): sorted_tags = cls.sort_tags(remaining_tags) if weighted_tags_files: - joinable = [f'({k}:{v})' for k, v in sorted_tags] + sorted_tags = [f'({k}:{v})' for k, v in sorted_tags] else: - joinable = [k for k, v in sorted_tags] - Path(file).write_text(', '.join(joinable), encoding='utf-8') + sorted_tags = [k for k, v in sorted_tags] + file.write_text(', '.join(sorted_tags), encoding='utf-8') warn = "" if len(QData.err) > 0: From 63892a5b82da1465bffd50a2dc7953ec5d41c2aa Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 13 Aug 2023 13:19:03 +0200 Subject: [PATCH 20/78] Use a state variable to be able to send the tags via txt2img and friends --- tagger/ui.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tagger/ui.py b/tagger/ui.py index 34c4689..dea92d5 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -27,6 +27,7 @@ def tf_version(): TAG_INPUTS = ["add", "keep", "exclude", "search", "replace"] COMMON_OUTPUT = Tuple[ str, # tags as string + str, # html tags as string str, # discarded tags as string Dict[str, float], # rating confidences Dict[str, float], # tag confidences @@ -156,12 +157,12 @@ def search_filter(filt: str) -> COMMON_OUTPUT: tags = {k: v for k, v in tags.items() if re_part.search(k)} lost = {k: v for k, v in lost.items() if re_part.search(k)} - tags_str = ', '.join(f'{k}' for k, v in tags.items()) - lost_str = ', '.join(f'{k}' for k, v in lost.items()) + h_tags = ', '.join(f'{k}' for k in tags.keys()) + h_lost = ', '.join(f'{k}' for k in lost.keys()) - return (tags_str, lost_str, ratings, tags, lost, info) + return (', '.join(tags.keys()), h_tags, h_lost, ratings, tags, lost, info) def on_ui_tabs(): @@ -351,7 +352,8 @@ def on_ui_tabs(): with gr.Tabs(): with gr.TabItem(label='Ratings and included tags'): # clickable tags to populate excluded tags - tags = gr.HTML( + tags = gr.State(value="") + html_tags = gr.HTML( label='Tags', elem_id='tags', ) @@ -434,7 +436,7 @@ def on_ui_tabs(): tab_gallery.select(fn=on_gallery, inputs=[], outputs=[gallery]) - common_output = [tags, discarded_tags, rating_confidences, + common_output = [tags, html_tags, discarded_tags, rating_confidences, tag_confidences, excluded_tag_confidences, info] # search input textbox From b5d5e6b38c01201b564c93f7447df18ccd1084f9 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 26 Aug 2023 11:47:57 +0200 Subject: [PATCH 21/78] Correct nr of args in error returns # Conflicts: # tagger/ui.py --- tagger/ui.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/tagger/ui.py b/tagger/ui.py index dea92d5..c2c114d 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -1,5 +1,5 @@ """ This module contains the ui for the tagger tab. """ -from typing import Dict, Tuple, List +from typing import Dict, Tuple, List, Optional import gradio as gr import re from PIL import Image @@ -26,17 +26,17 @@ def tf_version(): TAG_INPUTS = ["add", "keep", "exclude", "search", "replace"] COMMON_OUTPUT = Tuple[ - str, # tags as string - str, # html tags as string - str, # discarded tags as string - Dict[str, float], # rating confidences - Dict[str, float], # tag confidences - Dict[str, float], # excluded tag confidences + Optional[str], # tags as string + Optional[str], # html tags as string + Optional[str], # discarded tags as string + Optional[Dict[str, float]], # rating confidences + Optional[Dict[str, float]], # tag confidences + Optional[Dict[str, float]], # excluded tag confidences str, # error message ] -def unload_interrogators() -> List[str]: +def unload_interrogators() -> Tuple[str]: unloaded_models = 0 remaining_models = '' @@ -66,7 +66,7 @@ def on_interrogate( It.input["output_dir"] = output_dir if len(IOData.err) > 0: - return None, None, None, None, None, IOData.error_msg() + return (None,) * 6 + (IOData.error_msg(),) for i, val in enumerate(args): part = TAG_INPUTS[i] @@ -77,7 +77,7 @@ def on_interrogate( interrogator: It = next((i for i in It.entries.values() if i.name == name), None) if interrogator is None: - return None, None, None, None, None, f"'{name}': invalid interrogator" + return (None,) * 6 + (f"'{name}': invalid interrogator",) interrogator.batch_interrogate() return search_filter(filt) @@ -91,9 +91,10 @@ def on_interrogate_image(*args) -> COMMON_OUTPUT: # hack brcause image interrogaion occurs twice It.odd_increment = It.odd_increment + 1 if It.odd_increment & 1 == 1: - return (None, None, None, None, None, '') + return (None,) * 6 + ('',) return on_interrogate_image_submit(*args) + def on_interrogate_image_submit( image: Image, name: str, filt: str, *args ) -> COMMON_OUTPUT: @@ -104,11 +105,11 @@ def on_interrogate_image_submit( It.input[part] = val if image is None: - return None, None, None, None, None, 'No image selected' + return (None,) * 6 + ('No image selected',) interrogator: It = next((i for i in It.entries.values() if i.name == name), None) if interrogator is None: - return None, None, None, None, None, f"'{name}': invalid interrogator" + return (None,) * 6 + (f"'{name}': invalid interrogator",) interrogator.interrogate_image(image) return search_filter(filt) @@ -116,7 +117,7 @@ def on_interrogate_image_submit( def move_selection_to_input( filt: str, field: str -) -> Tuple[str, str, str]: +) -> Tuple[Optional[str], Optional[str], str]: """ moves the selected to the input field """ if It.output is None: return (None, None, '') @@ -139,11 +140,15 @@ def move_selection_to_input( return ('', data, info) -def move_selection_to_keep(tag_search_filter: str) -> Tuple[str, str, str]: +def move_selection_to_keep( + tag_search_filter: str +) -> Tuple[Optional[str], Optional[str], str]: return move_selection_to_input(tag_search_filter, "keep") -def move_selection_to_exclude(tag_search_filter: str) -> Tuple[str, str, str]: +def move_selection_to_exclude( + tag_search_filter: str +) -> Tuple[Optional[str], Optional[str], str]: return move_selection_to_input(tag_search_filter, "exclude") @@ -151,7 +156,7 @@ def search_filter(filt: str) -> COMMON_OUTPUT: """ filters the tags and lost tags for the search field """ ratings, tags, lost, info = It.output if ratings is None: - return (None, None, None, None, None, info) + return (None,) * 6 + (info,) if filt: re_part = re.compile('(' + re.sub(', ?', '|', filt) + ')') tags = {k: v for k, v in tags.items() if re_part.search(k)} From fcbc59b11ab39e88afb2c60280d3e380129df47c Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 19 Aug 2023 08:59:17 +0200 Subject: [PATCH 22/78] recursive glob was failing for directories --- tagger/uiset.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tagger/uiset.py b/tagger/uiset.py index c9adde6..b44abec 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -114,15 +114,16 @@ def update_input_glob(cls, input_glob: str) -> None: return cls.err.discard(msg) - recursive = getattr(shared.opts, 'tagger_batch_recursive', '') + recursive = getattr(shared.opts, 'tagger_batch_recursive', False) path_mtimes = [] for filename in glob(input_glob, recursive=recursive): - ext = os.path.splitext(filename)[1].lower() - if ext in supported_extensions: - path_mtimes.append(os.path.getmtime(filename)) - paths.append(filename) - elif ext != '.txt' and 'db.json' not in filename: - print(f'{filename}: not an image extension: "{ext}"') + if not os.path.isdir(filename): + ext = os.path.splitext(filename)[1].lower() + if ext in supported_extensions: + path_mtimes.append(os.path.getmtime(filename)) + paths.append(filename) + elif ext != '.txt' and 'db.json' not in filename: + print(f'{filename}: not an image extension: "{ext}"') # interrogating in a directory with no pics, still flush the cache if len(path_mtimes) > 0 and cls.last_path_mtimes == path_mtimes: From aa98f057377fef2f5de6bf166ca7a1d0afdf8634 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 19 Aug 2023 09:03:41 +0200 Subject: [PATCH 23/78] actually better to default to True here. --- tagger/uiset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tagger/uiset.py b/tagger/uiset.py index b44abec..d44ebbb 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -114,7 +114,7 @@ def update_input_glob(cls, input_glob: str) -> None: return cls.err.discard(msg) - recursive = getattr(shared.opts, 'tagger_batch_recursive', False) + recursive = getattr(shared.opts, 'tagger_batch_recursive', True) path_mtimes = [] for filename in glob(input_glob, recursive=recursive): if not os.path.isdir(filename): From 3e7e6d16d5017f2ba8d4a6fc55e9aabd487acbc6 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 19 Aug 2023 09:17:19 +0200 Subject: [PATCH 24/78] only the double asterisk will recurse --- tagger/uiset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tagger/uiset.py b/tagger/uiset.py index d44ebbb..34cf8cf 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -100,14 +100,14 @@ def update_input_glob(cls, input_glob: str) -> None: paths = [] # if there is no glob pattern, insert it automatically - if not input_glob.endswith('*'): + if not input_glob.endswith('**'): if not input_glob.endswith(os.sep): input_glob += os.sep - input_glob += '*' + input_glob += '**' # get root directory of input glob pattern base_dir = input_glob.replace('?', '*') - base_dir = base_dir.split(os.sep + '*').pop(0) + base_dir = base_dir.split(os.sep + '**').pop(0) msg = 'Invalid input directory' if not os.path.isdir(base_dir): cls.err.add(msg) From 83ed754eac16ba0c984e6c9e688ee007a9bbe504 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 19 Aug 2023 09:51:10 +0200 Subject: [PATCH 25/78] revert back to previous behaviour, but make the recursion more clear in the ui label --- tagger/ui.py | 3 ++- tagger/uiset.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tagger/ui.py b/tagger/ui.py index c2c114d..ed12e3e 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -197,7 +197,8 @@ def on_ui_tabs(): input_glob = preset.component( gr.Textbox, value='', - label='Input directory - See also settings tab.', + label='Input directory - To recurse use ** or */* ' + 'in your glob; also the check settings tab.', placeholder='/path/to/images or to/images/**/*' ) output_dir = preset.component( diff --git a/tagger/uiset.py b/tagger/uiset.py index 34cf8cf..d44ebbb 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -100,14 +100,14 @@ def update_input_glob(cls, input_glob: str) -> None: paths = [] # if there is no glob pattern, insert it automatically - if not input_glob.endswith('**'): + if not input_glob.endswith('*'): if not input_glob.endswith(os.sep): input_glob += os.sep - input_glob += '**' + input_glob += '*' # get root directory of input glob pattern base_dir = input_glob.replace('?', '*') - base_dir = base_dir.split(os.sep + '**').pop(0) + base_dir = base_dir.split(os.sep + '*').pop(0) msg = 'Invalid input directory' if not os.path.isdir(base_dir): cls.err.add(msg) From 0d46fefe55a02c4532fa0ccf39034dca6c0b3c9a Mon Sep 17 00:00:00 2001 From: R Date: Tue, 22 Aug 2023 13:13:37 +0200 Subject: [PATCH 26/78] ui.py typo --- tagger/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tagger/ui.py b/tagger/ui.py index ed12e3e..859a3d8 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -198,7 +198,7 @@ def on_ui_tabs(): gr.Textbox, value='', label='Input directory - To recurse use ** or */* ' - 'in your glob; also the check settings tab.', + 'in your glob; also check the settings tab.', placeholder='/path/to/images or to/images/**/*' ) output_dir = preset.component( From 8693e1363fcda04fbe15c55eff05ab5be6a5596e Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 26 Aug 2023 10:15:44 +0200 Subject: [PATCH 27/78] update changelog --- CHANGELOG.md | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c3f6b2..2c6a928 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,30 @@ -# v1.1.1 +# v1.1.2 (2023-08-26) + +explain recuisive path usage better in ui +Fix sending tags via buttons to txt2img and img2img +type additions, inadverteably pushed, later retouched. +allow setting gpu device via flag +Fix inverted cumulative checkbox +wrap_gradio_gpu_call fallback +Fix for preload shared access +preload update +A few ui changes +Fix not clearing the tags after writing them to files +Fix: Tags were still added, beyond count treshold +fix search/replace bug +(here int based weitghts were reverted) +circumvent when unable to load tensorflow +fix for too many exclude_tags +add db.json validation schema, add schema validation +return fix for fastapi +pick up huggingface cache dir from env, with default, configurable also via settings. +leave tensorflow requirements to the user. +Fix for Reappearance of gradio bug: duplicate image edit +(index based weights, but later reverted) +Instead of cache_dir use local_dir, leave cache to the user via env vars. +requirements fix for MacOS + +# v1.1.1 eada050 (2023-07-20) Internal cleanup, no separate interrogation for inverse Fix issues with search and sending selection to keep/exclude @@ -11,7 +37,9 @@ fix some hf download issues fixes for fastapi added ML-Danbooru support, thanks to [CCRcmcpe](github.com/CCRcmcpe) -# v1.1.0 + +# v1.1.0 87706b7 (2023-07-16) + fix: failed to install onnxruntime package on MacOS thanks to heady713 fastapi: remote unload model, picked up from [here](https://github.com/toriato/stable-diffusion-webui-wd14-tagger/pull/109) attribute error fix from aria1th also reported by yjunej @@ -39,7 +67,7 @@ changed internal error handling, It is a bit quirky, which I intend to fix, stil If you find it keeps complaining about an input field without reason, just try editing that one again (e.g. add a space there and remove it). -# v1.0.0 +# v1.0.0 a1b59d6 (2023-07-10) You may have to remove the presets/default.json and save a new one.witth your desired defaults. Otherwise checkboxes may not have the right default values. From 28cc72dec8bbc96558516891cbcdb83f362b0a99 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 26 Aug 2023 10:33:31 +0200 Subject: [PATCH 28/78] typos --- CHANGELOG.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c6a928..5e701c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,8 @@ # v1.1.2 (2023-08-26) -explain recuisive path usage better in ui +Explain recursive path usage better in ui Fix sending tags via buttons to txt2img and img2img -type additions, inadverteably pushed, later retouched. +type additions, inadvertently pushed, later retouched. allow setting gpu device via flag Fix inverted cumulative checkbox wrap_gradio_gpu_call fallback @@ -10,9 +10,9 @@ Fix for preload shared access preload update A few ui changes Fix not clearing the tags after writing them to files -Fix: Tags were still added, beyond count treshold +Fix: Tags were still added, beyond count threshold fix search/replace bug -(here int based weitghts were reverted) +(here int based weights were reverted) circumvent when unable to load tensorflow fix for too many exclude_tags add db.json validation schema, add schema validation @@ -21,8 +21,8 @@ pick up huggingface cache dir from env, with default, configurable also via sett leave tensorflow requirements to the user. Fix for Reappearance of gradio bug: duplicate image edit (index based weights, but later reverted) -Instead of cache_dir use local_dir, leave cache to the user via env vars. -requirements fix for MacOS +Instead of cache_dir use local_dir, leav + # v1.1.1 eada050 (2023-07-20) From e871aa7cfaecb76384e5590a134384fd625b6528 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 2 Sep 2023 22:57:25 +0200 Subject: [PATCH 29/78] Three scripts, a bash script to create per safetensors file the fraction of images contained to build the model that was marked with a particular tag. Secondly, a python script to compare the interrogation results (read from db.json) and find the top -c safetensors files that contain similar weights (or at least, that was the intention, there may be better algorithms to compare, but it seems to do the job). --- shell_scripts/compare_weighted_frequencies.py | 127 ++++++++++++++++++ shell_scripts/create_safetensors_db.sh | 102 ++++++++++++++ shell_scripts/model_grep.sh | 44 ++++++ 3 files changed, 273 insertions(+) create mode 100644 shell_scripts/compare_weighted_frequencies.py create mode 100644 shell_scripts/create_safetensors_db.sh create mode 100644 shell_scripts/model_grep.sh diff --git a/shell_scripts/compare_weighted_frequencies.py b/shell_scripts/compare_weighted_frequencies.py new file mode 100644 index 0000000..07d930a --- /dev/null +++ b/shell_scripts/compare_weighted_frequencies.py @@ -0,0 +1,127 @@ +from typing import Dict +from math import ceil +from json import load +import argparse +import re +from collections import defaultdict + +# read two json files, compare the weighted frequencies of the tags in the two +# files the first file is json and contains all safetensor files, with major +# sections and weighted tags + +# the second file is the result of an interrogations of images, with weighted +# tags. tags may be substrings of the tags in the first file + +# next argument is an interrogator id, a string +# optinally there are comma delimited images as arguments + +# in the end print the top ten safetensors and major sections that containss +# the tags that are most similar to the tags in the second file + +# all weights are between 0 and 1, higher is more important + + +# example usage: +# first run shell_scripts/create_safetensors_db.sh +# then interrogate an image in a subdirectory test/ + + +# cd stable-diffusion-webui/extensions/stable-diffusion-webui-wd14-tagger/ +# +# python shell_scripts/compare_weighted_frequencies.py safetensors_db.json \ +# test/db.json + +# # .. lists used interrogation models + + +# python shell_scripts/compare_weighted_frequencies.py safetensors_db.json \ +# -c 20 test/db.json + + +desc = 'Compare weighted frequencies of tags in two json file' +parser = argparse.ArgumentParser(description=desc) +hlp = 'number of results to print' +parser.add_argument('-c', '--count', default=10, type=int, help=hlp) +parser.add_argument('file1', help='all safetensors json file') +parser.add_argument('file2', help='image interrogation json file') +parser.add_argument('id', help='interrogator id', nargs='?', default="") +parser.add_argument('images', nargs='*', help='images', default=[]) +args = parser.parse_args() + + +with open(args.file1) as f: + all_sftns = load(f) + +with open(args.file2) as f: + data = load(f) + +query = data["query"] + +indices = set() +if args.id == "": + print("Missing interrogator id, contained are:") + uniq = set() + for k in data["query"]: + if k not in uniq: + print(k[64:]) + uniq.add(k[64:]) + exit(1) +else: + for k, t in data["query"].items(): + img_fn, idx = t + if k[64:] == args.id: + if len(args.images) > 0: + for i in args.images: + if img_fn[-len(i):] == i: + break + else: + continue + indices.add(int(idx)) + +interrogation_result = {} +for t, lst in data["tag"].items(): + wt = 0.0 + for stored in lst: + i = ceil(stored) - 1 + if i in indices: + wt += stored - i + if wt > 0.0: + interrogation_result[t] = wt / len(indices) + +scores: Dict[str, float] = defaultdict(float) + +for safetensor in all_sftns: + for major in all_sftns[safetensor]: + ct = len(all_sftns[safetensor][major]) + if ct == 0: + continue + + for tag, wt in interrogation_result.items(): + if tag in all_sftns[safetensor][major]: + sftns_wt = all_sftns[safetensor][major][tag] + n = (1.0 - abs(sftns_wt - wt)) + scores[safetensor + "\t" + major] += n / ct + else: + rex = re.compile(r'\b{}\b'.format(tag)) + t_len = len(tag) + # the tag may be a substring of a tag in the safetensor + # however only entire words are considered and a penalty if the + # string lenghts are close to each other + highest = 0.0 + for sftns_tag in all_sftns[safetensor][major]: + if rex.search(sftns_tag): + sftns_tag_len = len(sftns_tag) + sftns_wt = all_sftns[safetensor][major][sftns_tag] + n = (sftns_tag_len - t_len) / sftns_tag_len + n -= abs(sftns_wt - wt) + highest = max(highest, n) + scores[safetensor + "\t" + major] += highest / ct + +# sort the scores +sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) + +# print the top ten safetensors and major sections +for i in range(args.count): + if i >= len(sorted_scores): + break + print(sorted_scores[i][0] + "\t" + str(sorted_scores[i][1])) diff --git a/shell_scripts/create_safetensors_db.sh b/shell_scripts/create_safetensors_db.sh new file mode 100644 index 0000000..2538f49 --- /dev/null +++ b/shell_scripts/create_safetensors_db.sh @@ -0,0 +1,102 @@ +#!/bin/bash +# +# Create a database for safetensors wherein for the +# models each major tag, the occurrence frequency of +# each associated subtag is listed. +# +# requires https://github.com/by321/safetensors_util.git +# gnu parallel, jq, sed, awk +# + +# To build the safetensors_db.json database with +# "file.safetensors" { "major tag": { "tag1": , "tag2": .. } }: + + +# cd stable-diffusion-webui/extensions/stable-diffusion-webui-wd14-tagger/ +# git clone https://github.com/by321/safetensors_util.git +# +# bash shell_scripts/create_safetensors_db.sh -f -p ../.. -u safetensors_util/ -o safetensors_db.json +# +## now you can compare interrogation weights with the safetensors_db.json using +## shell_scripts/compare_weighted_frequencies.py, see there for usage. + + +# number of cpus to use by default or use -j to specify +ncpu=$(nproc --all) +[ $ncpu -gt 8 ] && ncpu=8 + +utilpath=./ +path=./ +out=safetensors_db.json +force=0 +while [ $# -gt 0 ]; do + case "$1" in + -h|--help) + echo "Usage: $0 [-f] [-p path] [-u utilpath] [-o out]" + echo " -p path path to stable-diffusion-webui" + echo " -u utilpath path to safetensors_util.py" + echo " -o out output file (default: safetensors_db.json)" + echo " -f force overwrite of output file" + echo " -j ncpu number of cpus to use (default: $ncpu)" + exit 0 + ;; + -p) path="$2/"; shift 2;; + -u) utilpath="$2/"; shift 2;; + -o) out="$2"; shift 2;; + -f) force=1; shift 1;; + -j) ncpu="$2"; shift 2;; + esac +done + +if [ ! -d "${path}models/Lora/" ]; then + echo "Error: ${path}models/Lora/ does not exist (use -p to specify path)" + exit 1 +fi + +if [ ! -e "${utilpath}/safetensors_util.py" ]; then + echo "Error: ${utilpath}/safetensors_util.py does not exist (use -u to specify path)" + exit 1 +fi + +if [ -e "${out}" -a $force -eq 0 ]; then + echo "Error: ${out} already exists (use -f to overwrite)" + exit 1 +fi + +ls -1 ${path}models/Lora/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | +sed -n '1b;p' | jq -r 'select(.__metadata__ != null) | .__metadata__ | .ss_tag_frequency | select( . != null )' 2>/dev/null | sed 's/\" /\"/' | +awk -v FS=': ' '{ + if (index(\$2, \"null\") > 0) next + o = index(\$0, \"{\") + if (o == 1) printf \"\\\"'{}'\\\": \" + if (o > 0) { + print \$0 + m = 0 + } else { + c = index(\$0, \"}\") + if (c > 0) { + L=\"\" + for (i in a) { + if (L != \"\") print \",\" + printf \"%s: %.6f\", i, a[i] / m + L = "x" + } + delete a + if (c == 1) print \$0\",\" + else print \"\n\"\$0 + } else { + x = index(\$2, \",\") + v = int(x != 0 ? substr(\$2, 1, x - 1) : \$2) + if (v > m) m = v + a[\$1] = v + } + } +}'" | sed -r ' +s/^/ /; +1s/^/{\n/; +s/\\"//g +s/^([ \t]+"[^"]+):*(: [01]+(\.[0-9]+)?,?)$/\1"\2/ +$s/,?$/\n}/ +' > "${out}" + + diff --git a/shell_scripts/model_grep.sh b/shell_scripts/model_grep.sh new file mode 100644 index 0000000..69cb59d --- /dev/null +++ b/shell_scripts/model_grep.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# requires jq, grep +# +# Usage: ./model_grep.sh [-p path] [-u utilpath] + +# number of cpus to use by default or use -j to specify +ncpu=$(nproc --all) +[ $ncpu -gt 8 ] && ncpu=8 + +utilpath=./ +path=./ +out=safetensors_db.json +while [ $# -gt 0 ]; do + case "$1" in + -h|--help) + echo "Usage: $0 [-p path] [-u utilpath]" + echo " -p path path to stable-diffusion-webui" + echo " -u utilpath path to safetensors_util.py" + echo " -j ncpu number of cpus to use (default: $ncpu)" + exit 0 + ;; + -p) path="$2/"; shift 2;; + -u) utilpath="$2/"; shift 2;; + -j) ncpu="$2"; shift 2;; + esac +done + +if [ ! -d "${path}models/Lora/" ]; then + echo "Error: ${path}models/Lora/ does not exist (use -p to specify path)" + exit 1 +fi + +if [ ! -e "${utilpath}/safetensors_util.py" ]; then + echo "Error: ${utilpath}/safetensors_util.py does not exist (use -u to specify path)" + exit 1 +fi + +if [ -z "$1" ]; then + echo "Usage: $0 [-p /path/to/stable-diffusion-webui (default: ./)] " + exit 1 +fi + +ls -1 ${path}models/Lora/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | +sed -n '1b;p' | jq '.__metadata__.ss_tag_frequency' 2>/dev/null | grep -o -E '\"[^\"]*${1}[^\"]*\": [0-9]+'| sed 's~^~'{}':~p'" From aa55f4e86a2a21bdfa3b2e26c531125d9870ef2f Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 3 Sep 2023 09:26:10 +0200 Subject: [PATCH 30/78] shell/scripts/compare_weighted_frequencies.py: If only interrogated in db.json with only one model, don't ask which model to compare for small corrections request directory with safetensors directly, don't assume a models/Lora subdir. --- shell_scripts/compare_weighted_frequencies.py | 33 +++++++++++-------- shell_scripts/create_safetensors_db.sh | 27 +++++++-------- shell_scripts/model_grep.sh | 23 ++++++------- 3 files changed, 45 insertions(+), 38 deletions(-) diff --git a/shell_scripts/compare_weighted_frequencies.py b/shell_scripts/compare_weighted_frequencies.py index 07d930a..a7c8496 100644 --- a/shell_scripts/compare_weighted_frequencies.py +++ b/shell_scripts/compare_weighted_frequencies.py @@ -59,24 +59,29 @@ indices = set() if args.id == "": - print("Missing interrogator id, contained are:") uniq = set() for k in data["query"]: if k not in uniq: - print(k[64:]) uniq.add(k[64:]) - exit(1) -else: - for k, t in data["query"].items(): - img_fn, idx = t - if k[64:] == args.id: - if len(args.images) > 0: - for i in args.images: - if img_fn[-len(i):] == i: - break - else: - continue - indices.add(int(idx)) + if len(uniq) != 1: + print("Missing interrogator id, contained are:") + for k in uniq: + print(k) + exit(1) + else: + # use the only one + args.id = uniq.pop() + +for k, t in data["query"].items(): + img_fn, idx = t + if k[64:] == args.id: + if len(args.images) > 0: + for i in args.images: + if img_fn[-len(i):] == i: + break + else: + continue + indices.add(int(idx)) interrogation_result = {} for t, lst in data["tag"].items(): diff --git a/shell_scripts/create_safetensors_db.sh b/shell_scripts/create_safetensors_db.sh index 2538f49..f2f0381 100644 --- a/shell_scripts/create_safetensors_db.sh +++ b/shell_scripts/create_safetensors_db.sh @@ -15,7 +15,7 @@ # cd stable-diffusion-webui/extensions/stable-diffusion-webui-wd14-tagger/ # git clone https://github.com/by321/safetensors_util.git # -# bash shell_scripts/create_safetensors_db.sh -f -p ../.. -u safetensors_util/ -o safetensors_db.json +# bash shell_scripts/create_safetensors_db.sh -f -p ../../models/Lora -u safetensors_util/ -o safetensors_db.json # ## now you can compare interrogation weights with the safetensors_db.json using ## shell_scripts/compare_weighted_frequencies.py, see there for usage. @@ -25,31 +25,32 @@ ncpu=$(nproc --all) [ $ncpu -gt 8 ] && ncpu=8 -utilpath=./ -path=./ -out=safetensors_db.json +path=. +utilpath=. force=0 +out=safetensors_db.json + while [ $# -gt 0 ]; do case "$1" in -h|--help) - echo "Usage: $0 [-f] [-p path] [-u utilpath] [-o out]" - echo " -p path path to stable-diffusion-webui" + echo "Usage: $0 [-j ncpu] [-p path] [-u utilpath] [-f] [-o out]" + echo " -j ncpu number of cpus to use (default: $ncpu)" + echo " -p path path to directory containing safetensor models (default: ./)" echo " -u utilpath path to safetensors_util.py" - echo " -o out output file (default: safetensors_db.json)" echo " -f force overwrite of output file" - echo " -j ncpu number of cpus to use (default: $ncpu)" + echo " -o out output file (default: safetensors_db.json)" exit 0 ;; + -j) ncpu="$2"; shift 2;; -p) path="$2/"; shift 2;; -u) utilpath="$2/"; shift 2;; - -o) out="$2"; shift 2;; -f) force=1; shift 1;; - -j) ncpu="$2"; shift 2;; + -o) out="$2"; shift 2;; esac done -if [ ! -d "${path}models/Lora/" ]; then - echo "Error: ${path}models/Lora/ does not exist (use -p to specify path)" +if [ ! -d "${path}" ]; then + echo "Error: '${path}' does not exist (use -p to specify path)" exit 1 fi @@ -63,7 +64,7 @@ if [ -e "${out}" -a $force -eq 0 ]; then exit 1 fi -ls -1 ${path}models/Lora/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | +ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | sed -n '1b;p' | jq -r 'select(.__metadata__ != null) | .__metadata__ | .ss_tag_frequency | select( . != null )' 2>/dev/null | sed 's/\" /\"/' | awk -v FS=': ' '{ if (index(\$2, \"null\") > 0) next diff --git a/shell_scripts/model_grep.sh b/shell_scripts/model_grep.sh index 69cb59d..ad095f8 100644 --- a/shell_scripts/model_grep.sh +++ b/shell_scripts/model_grep.sh @@ -7,26 +7,27 @@ ncpu=$(nproc --all) [ $ncpu -gt 8 ] && ncpu=8 -utilpath=./ -path=./ +utilpath=. +path=. out=safetensors_db.json while [ $# -gt 0 ]; do case "$1" in -h|--help) - echo "Usage: $0 [-p path] [-u utilpath]" - echo " -p path path to stable-diffusion-webui" - echo " -u utilpath path to safetensors_util.py" + echo "Usage: $0 [ -j ncpu] [-p path] [-u utilpath] " echo " -j ncpu number of cpus to use (default: $ncpu)" + echo " -p path path to directory containing safetensor models (default: ./)" + echo " -u utilpath path to safetensors_util.py" + echo " extended regex to match against model names" exit 0 ;; - -p) path="$2/"; shift 2;; - -u) utilpath="$2/"; shift 2;; -j) ncpu="$2"; shift 2;; + -p) path="$2"; shift 2;; + -u) utilpath="$2"; shift 2;; esac done -if [ ! -d "${path}models/Lora/" ]; then - echo "Error: ${path}models/Lora/ does not exist (use -p to specify path)" +if [ ! -d "${path}" ]; then + echo "Error: ${path} does not exist (use -p to specify path)" exit 1 fi @@ -36,9 +37,9 @@ if [ ! -e "${utilpath}/safetensors_util.py" ]; then fi if [ -z "$1" ]; then - echo "Usage: $0 [-p /path/to/stable-diffusion-webui (default: ./)] " + echo "Usage: $0 [-p /path/to/stable-diffusion-webui/models/Lora (default: .)] " exit 1 fi -ls -1 ${path}models/Lora/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | +ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | sed -n '1b;p' | jq '.__metadata__.ss_tag_frequency' 2>/dev/null | grep -o -E '\"[^\"]*${1}[^\"]*\": [0-9]+'| sed 's~^~'{}':~p'" From 836a6c8ec7b02ec8212289447dfb004ef2a1f454 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 3 Sep 2023 13:13:43 +0200 Subject: [PATCH 31/78] nit --- shell_scripts/model_grep.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/shell_scripts/model_grep.sh b/shell_scripts/model_grep.sh index ad095f8..7b21a52 100644 --- a/shell_scripts/model_grep.sh +++ b/shell_scripts/model_grep.sh @@ -7,9 +7,9 @@ ncpu=$(nproc --all) [ $ncpu -gt 8 ] && ncpu=8 -utilpath=. path=. -out=safetensors_db.json +utilpath=. + while [ $# -gt 0 ]; do case "$1" in -h|--help) From 08a4d81b33dd1af1fb61c24e3fd4812bd3524aba Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 17:19:30 +0200 Subject: [PATCH 32/78] deprecation warning # Conflicts: # tagger/settings.py --- tagger/settings.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tagger/settings.py b/tagger/settings.py index 07b58fe..cf309b9 100644 --- a/tagger/settings.py +++ b/tagger/settings.py @@ -2,7 +2,7 @@ import os from typing import List from modules import shared # pylint: disable=import-error -from gradio import inputs as gr +import gradio as gr from huggingface_hub import hf_hub_download # kaomoji from WD 1.4 tagger csv. thanks, Meow-San#5400! @@ -10,10 +10,8 @@ DEFAULT_OFF = '[name].[output_extension]' -HF_CACHE = os.environ.get( - 'HUGGINGFACE_HUB_CACHE', # defaults to "$HF_HOME/hub" - str(os.path.join(shared.models_path, 'interrogators'))) - +HF_CACHE = os.environ.get('HF_HOME', os.environ.get('HUGGINGFACE_HUB_CACHE', + str(os.path.join(shared.models_path, 'interrogators')))) def slider_wrapper(value, elem_id, **kwargs): # required or else gradio will throw errors @@ -124,12 +122,13 @@ def on_ui_settings(): section=section, ), ) + # see huggingface_hub guides/manage-cache shared.opts.add_option( - key='tagger_hf_hub_down_opts', + key='tagger_hf_cache_dir', info=shared.OptionInfo( - str(f'cache_dir="{HF_CACHE}"'), - label='HuggingFace parameters, Comma delimited: arg=value, ' - 'see huggingface_hub docs for available or leave alone.', + HF_CACHE, + label='HuggingFace cache directory, ' + 'see huggingface_hub guides/manage-cache', section=section, ), ) From 7e0c59526dc7d0661fc4e24b18bb894c903d4c8c Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Mon, 11 Sep 2023 20:38:43 +0200 Subject: [PATCH 33/78] observing dengrixionghnu's branch the version should probably not pass a callable but rather a version string. --- tagger/ui.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tagger/ui.py b/tagger/ui.py index 859a3d8..094c437 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -8,8 +8,7 @@ try: from tensorflow import __version__ as tf_version except ImportError: - def tf_version(): - return '0.0.0' + tf_version = '0.0.0' from html import escape as html_esc From 3486bf608da63b1a67efad922f86d4c517ac299c Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 13 Sep 2023 20:28:20 +0200 Subject: [PATCH 34/78] because people were trying to configure exclude tags from the settings tab zucht --- tagger/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tagger/settings.py b/tagger/settings.py index cf309b9..0695b9b 100644 --- a/tagger/settings.py +++ b/tagger/settings.py @@ -98,7 +98,7 @@ def on_ui_settings(): key='tagger_repl_us_excl', info=shared.OptionInfo( DEFAULT_KAMOJIS, - label='Excudes (split by comma)', + label='Underscore replacement excudes (split by comma)', section=section, ), ) From fdda828627a0f576d3eef59edf5b1bb4f1028c79 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 13 Sep 2023 20:33:05 +0200 Subject: [PATCH 35/78] excudes l --- tagger/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tagger/settings.py b/tagger/settings.py index 0695b9b..9cd2200 100644 --- a/tagger/settings.py +++ b/tagger/settings.py @@ -98,7 +98,7 @@ def on_ui_settings(): key='tagger_repl_us_excl', info=shared.OptionInfo( DEFAULT_KAMOJIS, - label='Underscore replacement excudes (split by comma)', + label='Underscore replacement excludes (split by comma)', section=section, ), ) From 3c186320947fd54fea8fbf2af3802da33c813316 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 18:08:04 +0200 Subject: [PATCH 36/78] more api tweaks # Conflicts: # tagger/api.py --- tagger/api.py | 33 +++++++++++++++++++++++---------- tagger/uiset.py | 4 ++-- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index e64ee30..01434d9 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -80,18 +80,31 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if req.model not in Interrogator.entries.keys(): raise HTTPException(404, 'Model not found') + req.model = [req.model] image = decode_base64_to_image(req.image) - interrogator = Interrogator.entries[req.model] - - with self.queue_lock: - QData.tags.clear() - QData.ratings.clear() - QData.in_db.clear() - QData.for_tags_file.clear() - data = ('', '', '') + interrogator.interrogate(image) - QData.apply_filters(data) - output = QData.finalize(1) + QData.tags.clear() + QData.ratings.clear() + QData.in_db.clear() + QData.for_tags_file.clear() + + # allow overriding of default values + if req.threshold: + QData.threshold = req.threshold + if req.tag_frac_threshold: + QData.tag_frac_threshold = req.tag_frac_threshold + if req.count_threshold: + QData.count_threshold = req.count_threshold + + for model in req.model: + with self.queue_lock: + interrogator = utils.interrogators[model] + data = ('', '', '') + interrogator.interrogate(image) + QData.apply_filters(data) + if req.auto_unload: + interrogator.unload() + + output = QData.finalize(1) return models.TaggerInterrogateResponse( caption={ diff --git a/tagger/uiset.py b/tagger/uiset.py index d44ebbb..d887875 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -204,6 +204,7 @@ class QData: replace_tags = [] threshold = 0.35 tag_frac_threshold = 0.05 + count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100) # read from db.json, update with what should be written to db.json: json_db = None @@ -500,8 +501,7 @@ def apply_filters(cls, data) -> None: cls.weighed[0][rating].append(val + index) cls.ratings[rating] += val - count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100) - max_ct = count_threshold - len(cls.add_tags) + max_ct = cls.count_threshold - len(cls.add_tags) count = 0 # loop over tags with db update for tag, val in tags: From 16239dc70f0dc0c7e90a1a2d3b82b61fbfc58e70 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 25 Jul 2023 23:50:11 +0200 Subject: [PATCH 37/78] or rather.. lock like this. --- tagger/api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 01434d9..cae0418 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -97,12 +97,12 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): QData.count_threshold = req.count_threshold for model in req.model: + interrogator = utils.interrogators[model] with self.queue_lock: - interrogator = utils.interrogators[model] data = ('', '', '') + interrogator.interrogate(image) - QData.apply_filters(data) - if req.auto_unload: - interrogator.unload() + QData.apply_filters(data) + if req.auto_unload: + interrogator.unload() output = QData.finalize(1) From bd15c8c7bf1fd159a8c5fd16f5a07c1840bfd938 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 18:11:15 +0200 Subject: [PATCH 38/78] The api model also needs to change. Allow a comma delimited list of interrogator to receive a combined results, fix some bugs. # Conflicts: # tagger/api.py --- tagger/api.py | 11 +++++++---- tagger/api_models.py | 21 ++++++++++++++++++++- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index cae0418..1789fc6 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -78,9 +78,12 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if req.image is None: raise HTTPException(404, 'Image not found') - if req.model not in Interrogator.entries.keys(): - raise HTTPException(404, 'Model not found') - req.model = [req.model] + req_models = [] + for i in map(str.strip, req.model.split(',')): + if i in utils.interrogators.keys(): + req_models.append(i) + else: + raise HTTPException(404, f"Model '{i}' not found") image = decode_base64_to_image(req.image) QData.tags.clear() @@ -96,7 +99,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if req.count_threshold: QData.count_threshold = req.count_threshold - for model in req.model: + for model in req_models: interrogator = utils.interrogators[model] with self.queue_lock: data = ('', '', '') + interrogator.interrogate(image) diff --git a/tagger/api_models.py b/tagger/api_models.py index 7f2cd20..c651fc0 100644 --- a/tagger/api_models.py +++ b/tagger/api_models.py @@ -9,7 +9,7 @@ class TaggerInterrogateRequest(sd_models.InterrogateRequest): """Interrogate request model""" model: str = Field( title='Model', - description='The interrogate model used.' + description='The interrogate model(s) used. Comma separated.', ) threshold: float = Field( @@ -19,6 +19,25 @@ class TaggerInterrogateRequest(sd_models.InterrogateRequest): ge=0, le=1 ) + tag_frac_threshold: float = Field( + default=0.05, + title='Amongst interrogations tag fraction threshold', + description='', + ge=0, + le=1 + ) + count_threshold: float = Field( + default=100, + title='Count threshold', + description='', + ge=1, + le=1000000 + ) + auto_unload: bool = Field( + default=True, + title='Auto unload', + description='Unload each model after interrogation.' + ) class TaggerInterrogateResponse(BaseModel): From 0d8055b608bf0470378ccc5d894d5979d90f0b1e Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 18:40:42 +0200 Subject: [PATCH 39/78] Actually I think first posting, then batch processing multiple is more feasible # Conflicts: # tagger/api.py --- tagger/api.py | 97 +++++++++++++++++++++++++++----------------- tagger/api_models.py | 55 +++++++++++++------------ 2 files changed, 90 insertions(+), 62 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 1789fc6..ad9346a 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -28,6 +28,7 @@ def __init__( self.app = app self.queue_lock = qlock self.prefix = prefix + self.images = {} self.add_api_route( 'interrogate', @@ -44,11 +45,23 @@ def __init__( ) self.add_api_route( - "unload-interrogators", + 'unload-interrogators', self.endpoint_unload_interrogators, - methods=["POST"], + methods=['POST'], response_model=str, ) + self.add_api_route( + 'queue-image', + self.endpoint_queue_image, + methods=['POST'], + response_model=models.QueueImageResponse + ) + self.add_api_route( + 'batch-process', + self.endpoint_batch, + methods=['POST'], + response_model=models.BatchResponse + ) def auth(self, creds: HTTPBasicCredentials = None): if creds is None: @@ -75,46 +88,56 @@ def add_api_route(self, path: str, endpoint: Callable, **kwargs): return self.app.add_api_route(path, endpoint, **kwargs) def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): + """ one file interrogation """ if req.image is None: raise HTTPException(404, 'Image not found') - req_models = [] - for i in map(str.strip, req.model.split(',')): - if i in utils.interrogators.keys(): - req_models.append(i) - else: - raise HTTPException(404, f"Model '{i}' not found") + if req.model not in utils.interrogators: + raise HTTPException(404, 'Model not found') image = decode_base64_to_image(req.image) - QData.tags.clear() - QData.ratings.clear() - QData.in_db.clear() - QData.for_tags_file.clear() - - # allow overriding of default values - if req.threshold: - QData.threshold = req.threshold - if req.tag_frac_threshold: - QData.tag_frac_threshold = req.tag_frac_threshold - if req.count_threshold: - QData.count_threshold = req.count_threshold - - for model in req_models: - interrogator = utils.interrogators[model] - with self.queue_lock: - data = ('', '', '') + interrogator.interrogate(image) - QData.apply_filters(data) - if req.auto_unload: - interrogator.unload() - - output = QData.finalize(1) - - return models.TaggerInterrogateResponse( - caption={ - **output[0], - **output[1], - **output[2], - }) + with self.queue_lock: + interrogator = utils.interrogators[req.model] + data = interrogator.interrogate(image) + res = {**data[0], **data[1]} + + return models.TaggerInterrogateResponse(res) + + def endpoint_queue_image(self, req: models.TaggerInterrogateRequest): + """ post image to queue """ + if req.image is None: + raise HTTPException(404, 'Image not found') + + # TODO make this a command line option + if len(self.images) >= getattr(shared.cmd_opts, 'queue_size', 512): + raise HTTPException(429, 'Queue is full') + + self.images[req.name] = decode_base64_to_image(req.image) + + return models.TaggerPostImageResponse(True) + + def endpoint_batch(self, req: models.TaggerBatchRequest): + """ batch interrogation """ + if req.image is None: + raise HTTPException(404, 'Image not found') + + if req.model not in utils.interrogators: + raise HTTPException(404, 'Model not found') + + res = {} + + with self.queue_lock: + interrogator = utils.interrogators[req.model] + for name, i in self.images.items(): + res[name] = interrogator.interrogate(i) + + # last image + image = decode_base64_to_image(req.image) + res[req.name] = interrogator.interrogate(image) + + self.images.clear() + + return models.TaggerBatchResponse(res) def endpoint_interrogators(self): return models.InterrogatorsResponse( diff --git a/tagger/api_models.py b/tagger/api_models.py index c651fc0..65caf49 100644 --- a/tagger/api_models.py +++ b/tagger/api_models.py @@ -9,34 +9,26 @@ class TaggerInterrogateRequest(sd_models.InterrogateRequest): """Interrogate request model""" model: str = Field( title='Model', - description='The interrogate model(s) used. Comma separated.', + description='The interrogate model used.', ) - threshold: float = Field( - default=0.35, - title='Threshold', - description='', - ge=0, - le=1 - ) - tag_frac_threshold: float = Field( - default=0.05, - title='Amongst interrogations tag fraction threshold', - description='', - ge=0, - le=1 + +class TaggerQueueImageRequest(sd_models.InterrogateRequest): + name: str = Field( + title='Name', + description='Only queue the image, under this name.', ) - count_threshold: float = Field( - default=100, - title='Count threshold', - description='', - ge=1, - le=1000000 + + +class TaggerBatchRequest(sd_models.InterrogateRequest): + """Batch request model""" + model: str = Field( + title='Model', + description='The interrogate model used.', ) - auto_unload: bool = Field( - default=True, - title='Auto unload', - description='Unload each model after interrogation.' + name: str = Field( + title='Name', + description='name of the last image', ) @@ -44,7 +36,7 @@ class TaggerInterrogateResponse(BaseModel): """Interrogate response model""" caption: Dict[str, float] = Field( title='Caption', - description='The generated caption for the image.' + description='The generated captions for the image.' ) @@ -54,3 +46,16 @@ class InterrogatorsResponse(BaseModel): title='Models', description='' ) + + +class QueueImageResponse(BaseModel): + """Queue image response model""" + pass + + +class BatchResponse(BaseModel): + """Batch response model""" + captions: Dict[str, Dict[str, float]] = Field( + title='Captions', + description='The generated captions for the images.' + ) From 66f8253ae4aca5fa4f7e8ee4b833137da448e9f8 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 26 Jul 2023 22:00:22 +0200 Subject: [PATCH 40/78] allow weight based filtering --- tagger/api.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tagger/api.py b/tagger/api.py index ad9346a..272e976 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -99,6 +99,10 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): with self.queue_lock: interrogator = utils.interrogators[req.model] data = interrogator.interrogate(image) + if req.threshold > 0.0: + data[1] = { + k: v for k, v in data[1].items() if v > req.threshold + } res = {**data[0], **data[1]} return models.TaggerInterrogateResponse(res) @@ -129,7 +133,12 @@ def endpoint_batch(self, req: models.TaggerBatchRequest): with self.queue_lock: interrogator = utils.interrogators[req.model] for name, i in self.images.items(): - res[name] = interrogator.interrogate(i) + data = interrogator.interrogate(i) + if req.threshold > 0.0: + data[1] = { + k: v for k, v in data[1].items() if v > req.threshold + } + res[name] = {**data[0], **data[1]} # last image image = decode_base64_to_image(req.image) From bf792e87227b809c5350e43a5d679b314137c0cf Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 26 Jul 2023 22:01:02 +0200 Subject: [PATCH 41/78] .. and this --- tagger/api_models.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tagger/api_models.py b/tagger/api_models.py index 65caf49..4e35cae 100644 --- a/tagger/api_models.py +++ b/tagger/api_models.py @@ -11,6 +11,10 @@ class TaggerInterrogateRequest(sd_models.InterrogateRequest): title='Model', description='The interrogate model used.', ) + threshold: float = Field( + title='Threshold', + description='The threshold used for the interrogate model.', + ) class TaggerQueueImageRequest(sd_models.InterrogateRequest): @@ -30,6 +34,10 @@ class TaggerBatchRequest(sd_models.InterrogateRequest): title='Name', description='name of the last image', ) + threshold: float = Field( + title='Threshold', + description='The threshold used for the interrogate model.', + ) class TaggerInterrogateResponse(BaseModel): From 189cc0c285a9988140d9c2dbe89f05f0225f9c2e Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 26 Aug 2023 12:48:31 +0200 Subject: [PATCH 42/78] typing --- tagger/api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 272e976..b106017 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -1,5 +1,5 @@ """API module for FastAPI""" -from typing import Callable +from typing import Callable, Dict, Optional from threading import Lock from secrets import compare_digest @@ -17,7 +17,7 @@ class Api: """Api class for FastAPI""" def __init__( - self, app: FastAPI, qlock: Lock, prefix: str = None + self, app: FastAPI, qlock: Lock, prefix: Optional[str] = None ) -> None: if shared.cmd_opts.api_auth: self.credentials = {} @@ -28,7 +28,7 @@ def __init__( self.app = app self.queue_lock = qlock self.prefix = prefix - self.images = {} + self.images: Dict[str, object] = {} self.add_api_route( 'interrogate', @@ -63,7 +63,7 @@ def __init__( response_model=models.BatchResponse ) - def auth(self, creds: HTTPBasicCredentials = None): + def auth(self, creds: Optional[HTTPBasicCredentials] = None): if creds is None: creds = Depends(HTTPBasic()) if creds.username in self.credentials: From ad5e47425a0140698396fea50c604b271b527de5 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 18:42:05 +0200 Subject: [PATCH 43/78] Fix/use consistent prefix in names for responses # Conflicts: # tagger/api.py --- tagger/api.py | 12 ++++++------ tagger/api_models.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index b106017..5fa88ff 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -41,7 +41,7 @@ def __init__( 'interrogators', self.endpoint_interrogators, methods=['GET'], - response_model=models.InterrogatorsResponse + response_model=models.TaggerInterrogatorsResponse ) self.add_api_route( @@ -54,13 +54,13 @@ def __init__( 'queue-image', self.endpoint_queue_image, methods=['POST'], - response_model=models.QueueImageResponse + response_model=models.TaggerQueueImageResponse ) self.add_api_route( 'batch-process', self.endpoint_batch, methods=['POST'], - response_model=models.BatchResponse + response_model=models.TaggerBatchResponse ) def auth(self, creds: Optional[HTTPBasicCredentials] = None): @@ -118,7 +118,7 @@ def endpoint_queue_image(self, req: models.TaggerInterrogateRequest): self.images[req.name] = decode_base64_to_image(req.image) - return models.TaggerPostImageResponse(True) + return models.TaggerQueueImageResponse(True) def endpoint_batch(self, req: models.TaggerBatchRequest): """ batch interrogation """ @@ -149,8 +149,8 @@ def endpoint_batch(self, req: models.TaggerBatchRequest): return models.TaggerBatchResponse(res) def endpoint_interrogators(self): - return models.InterrogatorsResponse( - models=list(Interrogator.entries.keys()) + return models.TaggerInterrogatorsResponse( + models=list(utils.interrogators.keys()) ) def endpoint_unload_interrogators(self): diff --git a/tagger/api_models.py b/tagger/api_models.py index 4e35cae..d77c427 100644 --- a/tagger/api_models.py +++ b/tagger/api_models.py @@ -48,7 +48,7 @@ class TaggerInterrogateResponse(BaseModel): ) -class InterrogatorsResponse(BaseModel): +class TaggerInterrogatorsResponse(BaseModel): """Interrogators response model""" models: List[str] = Field( title='Models', @@ -56,12 +56,12 @@ class InterrogatorsResponse(BaseModel): ) -class QueueImageResponse(BaseModel): +class TaggerQueueImageResponse(BaseModel): """Queue image response model""" pass -class BatchResponse(BaseModel): +class TaggerBatchResponse(BaseModel): """Batch response model""" captions: Dict[str, Dict[str, float]] = Field( title='Captions', From 55acbfa3f189aaad268ae255b729236f45b4b5fc Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 26 Aug 2023 12:49:56 +0200 Subject: [PATCH 44/78] clobber image if already queued --- tagger/api.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tagger/api.py b/tagger/api.py index 5fa88ff..38e4951 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -116,6 +116,13 @@ def endpoint_queue_image(self, req: models.TaggerInterrogateRequest): if len(self.images) >= getattr(shared.cmd_opts, 'queue_size', 512): raise HTTPException(429, 'Queue is full') + # clobber existing image + if req.name in self.images: + i = 0 + while f'{req.name}.{i}' in self.images: + i = i + 1 + req.name = f'{req.name}.{i}' + self.images[req.name] = decode_base64_to_image(req.image) return models.TaggerQueueImageResponse(True) From f4b0ee8c0a7c7110379444eb86ceebb6a1d1bed3 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 26 Aug 2023 19:36:15 +0200 Subject: [PATCH 45/78] no assignment to tuple --- tagger/api.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 38e4951..57f4ebe 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -98,12 +98,10 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): image = decode_base64_to_image(req.image) with self.queue_lock: interrogator = utils.interrogators[req.model] - data = interrogator.interrogate(image) + rating, tag = interrogator.interrogate(image) if req.threshold > 0.0: - data[1] = { - k: v for k, v in data[1].items() if v > req.threshold - } - res = {**data[0], **data[1]} + tag = {k: v for k, v in tag.items() if v > req.threshold} + res = {**rating, **tag} return models.TaggerInterrogateResponse(res) @@ -140,12 +138,10 @@ def endpoint_batch(self, req: models.TaggerBatchRequest): with self.queue_lock: interrogator = utils.interrogators[req.model] for name, i in self.images.items(): - data = interrogator.interrogate(i) + rating, tag = interrogator.interrogate(i) if req.threshold > 0.0: - data[1] = { - k: v for k, v in data[1].items() if v > req.threshold - } - res[name] = {**data[0], **data[1]} + tag = {k: v for k, v in tag.items() if v > req.threshold} + res[name] = {**rating, **tag} # last image image = decode_base64_to_image(req.image) From fe07da13574f7fb9dfb676f94227f37e84429f37 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 26 Aug 2023 20:25:25 +0200 Subject: [PATCH 46/78] The pydantic models require named, not positional arguments --- tagger/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 57f4ebe..83c5472 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -103,7 +103,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): tag = {k: v for k, v in tag.items() if v > req.threshold} res = {**rating, **tag} - return models.TaggerInterrogateResponse(res) + return models.TaggerInterrogateResponse(caption=res) def endpoint_queue_image(self, req: models.TaggerInterrogateRequest): """ post image to queue """ @@ -123,7 +123,7 @@ def endpoint_queue_image(self, req: models.TaggerInterrogateRequest): self.images[req.name] = decode_base64_to_image(req.image) - return models.TaggerQueueImageResponse(True) + return models.TaggerQueueImageResponse() def endpoint_batch(self, req: models.TaggerBatchRequest): """ batch interrogation """ @@ -149,7 +149,7 @@ def endpoint_batch(self, req: models.TaggerBatchRequest): self.images.clear() - return models.TaggerBatchResponse(res) + return models.TaggerBatchResponse(captions=res) def endpoint_interrogators(self): return models.TaggerInterrogatorsResponse( From cab60e7b5fd98977e4382d924ba4d0ce3b0efde1 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 2 Sep 2023 00:35:19 +0200 Subject: [PATCH 47/78] in progress --- .../tag_based_image_dedup.sh | 0 tagger/api.py | 163 +++++++++++------- tagger/api_models.py | 38 +--- 3 files changed, 104 insertions(+), 97 deletions(-) rename tag_based_image_dedup.sh => bash_scripts/tag_based_image_dedup.sh (100%) diff --git a/tag_based_image_dedup.sh b/bash_scripts/tag_based_image_dedup.sh similarity index 100% rename from tag_based_image_dedup.sh rename to bash_scripts/tag_based_image_dedup.sh diff --git a/tagger/api.py b/tagger/api.py index 83c5472..2614966 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -1,7 +1,10 @@ """API module for FastAPI""" -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Tuple from threading import Lock from secrets import compare_digest +import asyncio +from collections import defaultdict +from itertools import cycle from modules import shared # pylint: disable=import-error from modules.api.api import decode_base64_to_image # pylint: disable=E0401 @@ -26,9 +29,22 @@ def __init__( self.credentials[user] = password self.app = app + self.queue: Dict[str, asyncio.Queue] = {} + self.results = Dict[str, Dict[str, Dict[str, float]]] = {} self.queue_lock = qlock + self.running_batches: Dict[str, Dict[str, asyncio.Task]] = \ + defaultdict(dict) + + self.runner: Optional[asyncio.Task] = None self.prefix = prefix - self.images: Dict[str, object] = {} + + self.images: Dict[str, Dict[str, Dict[str, tuple[object, float]]]] = \ + defaultdict(lambda: defaultdict(dict)) + + self.finished_batches: Dict[str, Dict[str, asyncio.Task]] = \ + defaultdict(dict) + + self.reached_end: Dict[str, Dict[str, bool]] = defaultdict(dict) self.add_api_route( 'interrogate', @@ -50,18 +66,52 @@ def __init__( methods=['POST'], response_model=str, ) - self.add_api_route( - 'queue-image', - self.endpoint_queue_image, - methods=['POST'], - response_model=models.TaggerQueueImageResponse - ) - self.add_api_route( - 'batch-process', - self.endpoint_batch, - methods=['POST'], - response_model=models.TaggerBatchResponse - ) + + async def process_model(self, model: str) -> None: + """Process a batch of images""" + res: Dict[str, Dict[str, float]] = defaultdict(dict) + while len(self.queue[model]) > 0: + skipped = 0 + for queue in self.queue[model]: + while True: + try: + queue_name, name, image, threshold = await queue.get_nowait() + except asyncio.QueueEmpty: + skipped += 1 + break # if empty move on to next queue for same model (if any) + if name == "": + # This is the end of the queue + self.results[queue_name] = res[queue_name] + del res[queue_name] + skipped += 1 + break + # No queue or queued_name, processes instead of queuing + res[queue_name][name] = await self.endpoint_interrogate( + models.TaggerInterrogateRequest( + image=image, + model=model, + threshold=threshold, + queue="", + queued_name="" + ) + ) + if skipped == len(self.queue[model]): + # if all queues for this model are empty, postpone interrogration + # for this model and do other models first + break + + async def batch_process(self) -> None: + while len(self.queue) > 0: + for model in self.queue: + if model not in self.running_batches: + self.running_batches[model] = asyncio.create_task( + self.process_model(model) + ) + elif len(self.queue[model]) == 0: + await self.running_batches[model] + del self.running_batches[model] + await asyncio.sleep(0.1) + def auth(self, creds: Optional[HTTPBasicCredentials] = None): if creds is None: @@ -96,61 +146,46 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): raise HTTPException(404, 'Model not found') image = decode_base64_to_image(req.image) - with self.queue_lock: - interrogator = utils.interrogators[req.model] - rating, tag = interrogator.interrogate(image) + res: Dict[str, Dict[str, float]] = defaultdict(dict) + m, q, n = (req.model, req.queue, req.name) + + if n == '' and q in self.running_batches[m]: + # wait for batch to finish + res = self.running_batches[m][q].result() + del self.running_batches[m][q] + + elif q != '': + # queueing interrogation of the image + with self.queue_lock: + # check before populating the queue + initialize_runner = len(self.queue) == 0 + + # create to retreive data when a queue is finished + if q not in self.queue[m]: + self.queue[m][q] = asyncio.Queue() + + # clobber existing image + if n in self.images[m][q]: + i = 0 + while f'{n}.{i}' in self.images[m][q]: + i = i + 1 + n = f'{n}.{i}' + + # add image to queue + self.queue[m][q].put_nowait((n, image, req.threshold)) + if initialize_runner: + self.runner = asyncio.create_task(self.batch_process()) + else: + interrogator = utils.interrogators[m] + with self.queue_lock: + rating, tag = interrogator.interrogate(image) + if req.threshold > 0.0: tag = {k: v for k, v in tag.items() if v > req.threshold} - res = {**rating, **tag} + res[n] = {**rating, **tag} return models.TaggerInterrogateResponse(caption=res) - def endpoint_queue_image(self, req: models.TaggerInterrogateRequest): - """ post image to queue """ - if req.image is None: - raise HTTPException(404, 'Image not found') - - # TODO make this a command line option - if len(self.images) >= getattr(shared.cmd_opts, 'queue_size', 512): - raise HTTPException(429, 'Queue is full') - - # clobber existing image - if req.name in self.images: - i = 0 - while f'{req.name}.{i}' in self.images: - i = i + 1 - req.name = f'{req.name}.{i}' - - self.images[req.name] = decode_base64_to_image(req.image) - - return models.TaggerQueueImageResponse() - - def endpoint_batch(self, req: models.TaggerBatchRequest): - """ batch interrogation """ - if req.image is None: - raise HTTPException(404, 'Image not found') - - if req.model not in utils.interrogators: - raise HTTPException(404, 'Model not found') - - res = {} - - with self.queue_lock: - interrogator = utils.interrogators[req.model] - for name, i in self.images.items(): - rating, tag = interrogator.interrogate(i) - if req.threshold > 0.0: - tag = {k: v for k, v in tag.items() if v > req.threshold} - res[name] = {**rating, **tag} - - # last image - image = decode_base64_to_image(req.image) - res[req.name] = interrogator.interrogate(image) - - self.images.clear() - - return models.TaggerBatchResponse(captions=res) - def endpoint_interrogators(self): return models.TaggerInterrogatorsResponse( models=list(utils.interrogators.keys()) diff --git a/tagger/api_models.py b/tagger/api_models.py index d77c427..1f61881 100644 --- a/tagger/api_models.py +++ b/tagger/api_models.py @@ -15,34 +15,19 @@ class TaggerInterrogateRequest(sd_models.InterrogateRequest): title='Threshold', description='The threshold used for the interrogate model.', ) - - -class TaggerQueueImageRequest(sd_models.InterrogateRequest): - name: str = Field( - title='Name', - description='Only queue the image, under this name.', - ) - - -class TaggerBatchRequest(sd_models.InterrogateRequest): - """Batch request model""" - model: str = Field( - title='Model', - description='The interrogate model used.', + queue: str = Field( + title='Queue', + description='name of queue; leave empty for single response', ) name: str = Field( title='Name', - description='name of the last image', - ) - threshold: float = Field( - title='Threshold', - description='The threshold used for the interrogate model.', + description='name to queue an image as; empty for the final response', ) class TaggerInterrogateResponse(BaseModel): """Interrogate response model""" - caption: Dict[str, float] = Field( + caption: Dict[str, Dict[str, float]] = Field( title='Caption', description='The generated captions for the image.' ) @@ -54,16 +39,3 @@ class TaggerInterrogatorsResponse(BaseModel): title='Models', description='' ) - - -class TaggerQueueImageResponse(BaseModel): - """Queue image response model""" - pass - - -class TaggerBatchResponse(BaseModel): - """Batch response model""" - captions: Dict[str, Dict[str, float]] = Field( - title='Captions', - description='The generated captions for the images.' - ) From 1284e6c4057a894539f64d843a01a98d46463d76 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Mon, 11 Sep 2023 23:59:26 +0200 Subject: [PATCH 48/78] Forgot to commit these before pushing --- tagger/api.py | 111 ++++++++++++++++++------------------------------ tagger/uiset.py | 5 +-- 2 files changed, 44 insertions(+), 72 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 2614966..b393a22 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -1,10 +1,9 @@ """API module for FastAPI""" -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional from threading import Lock from secrets import compare_digest import asyncio from collections import defaultdict -from itertools import cycle from modules import shared # pylint: disable=import-error from modules.api.api import decode_base64_to_image # pylint: disable=E0401 @@ -30,7 +29,7 @@ def __init__( self.app = app self.queue: Dict[str, asyncio.Queue] = {} - self.results = Dict[str, Dict[str, Dict[str, float]]] = {} + self.results: Dict[str, Dict[str, float]] = {} self.queue_lock = qlock self.running_batches: Dict[str, Dict[str, asyncio.Task]] = \ defaultdict(dict) @@ -67,52 +66,31 @@ def __init__( response_model=str, ) - async def process_model(self, model: str) -> None: - """Process a batch of images""" - res: Dict[str, Dict[str, float]] = defaultdict(dict) - while len(self.queue[model]) > 0: - skipped = 0 - for queue in self.queue[model]: - while True: - try: - queue_name, name, image, threshold = await queue.get_nowait() - except asyncio.QueueEmpty: - skipped += 1 - break # if empty move on to next queue for same model (if any) - if name == "": - # This is the end of the queue - self.results[queue_name] = res[queue_name] - del res[queue_name] - skipped += 1 - break - # No queue or queued_name, processes instead of queuing - res[queue_name][name] = await self.endpoint_interrogate( - models.TaggerInterrogateRequest( - image=image, - model=model, - threshold=threshold, - queue="", - queued_name="" - ) - ) - if skipped == len(self.queue[model]): - # if all queues for this model are empty, postpone interrogration - # for this model and do other models first - break + async def batch_process(self, model: str) -> None: + done: Dict[str, bool] = {model: False} - async def batch_process(self) -> None: while len(self.queue) > 0: - for model in self.queue: - if model not in self.running_batches: - self.running_batches[model] = asyncio.create_task( - self.process_model(model) - ) - elif len(self.queue[model]) == 0: - await self.running_batches[model] - del self.running_batches[model] + while self.queue[model].qsize() > 0: + with self.queue_lock: + q, n, i, t = self.queue[model].get_nowait() + if n != "": + # Leaving queue and _name empty to process, not queue + self.results[q][n] = await self.endpoint_interrogate( + models.TaggerInterrogateRequest( + image=i, + model=model, + threshold=t, + queue="", + queued_name="" + ) + ) + else: + # This is the end of the queue + done[model] = True + if done[model]: + del self.queue[model] await asyncio.sleep(0.1) - def auth(self, creds: Optional[HTTPBasicCredentials] = None): if creds is None: creds = Depends(HTTPBasic()) @@ -149,21 +127,13 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): res: Dict[str, Dict[str, float]] = defaultdict(dict) m, q, n = (req.model, req.queue, req.name) - if n == '' and q in self.running_batches[m]: - # wait for batch to finish - res = self.running_batches[m][q].result() - del self.running_batches[m][q] - - elif q != '': - # queueing interrogation of the image - with self.queue_lock: - # check before populating the queue - initialize_runner = len(self.queue) == 0 - - # create to retreive data when a queue is finished - if q not in self.queue[m]: - self.queue[m][q] = asyncio.Queue() + with self.queue_lock: + if n == '' and q in self.running_batches[m]: + # wait for batch to finish + res = self.running_batches[m][q].result() + del self.running_batches[m][q] + elif q != '': # clobber existing image if n in self.images[m][q]: i = 0 @@ -171,18 +141,21 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): i = i + 1 n = f'{n}.{i}' - # add image to queue - self.queue[m][q].put_nowait((n, image, req.threshold)) - if initialize_runner: - self.runner = asyncio.create_task(self.batch_process()) - else: - interrogator = utils.interrogators[m] - with self.queue_lock: + if m in self.queue: + # add image to queue + self.queue[m].put_nowait((q, n, image, req.threshold)) + else: + self.queue[m] = asyncio.Queue() + self.queue[m].put_nowait((q, n, image, req.threshold)) + self.runner = asyncio.create_task(self.batch_process(m)) + else: + interrogator = utils.interrogators[m] rating, tag = interrogator.interrogate(image) - if req.threshold > 0.0: - tag = {k: v for k, v in tag.items() if v > req.threshold} - res[n] = {**rating, **tag} + res[n] = {**rating} + for k, v in tag.items(): + if v > req.threshold: + res[n][k] = v return models.TaggerInterrogateResponse(caption=res) diff --git a/tagger/uiset.py b/tagger/uiset.py index d887875..d078508 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -319,9 +319,8 @@ def update_add(cls, add: str) -> None: cls.test_add(tag, 'add', ['exclude', 'search']) # silently raise count threshold to avoid issue in apply_filters - count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100) - if len(cls.add_tags) > count_threshold: - shared.opts.tagger_count_threshold = len(cls.add_tags) + if len(cls.add_tags) > cls.count_threshold: + cls.count_threshold = len(cls.add_tags) @staticmethod def compile_rex(rex: str) -> Optional: From bae45dcf42bf4479c1af807336188fc2b7bdd3d1 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 12 Sep 2023 00:19:17 +0200 Subject: [PATCH 49/78] something like this --- tagger/api.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index b393a22..ceb5e24 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -29,10 +29,8 @@ def __init__( self.app = app self.queue: Dict[str, asyncio.Queue] = {} - self.results: Dict[str, Dict[str, float]] = {} + self.results: Dict[str, Dict[str, Dict[str, float]]] = {} self.queue_lock = qlock - self.running_batches: Dict[str, Dict[str, asyncio.Task]] = \ - defaultdict(dict) self.runner: Optional[asyncio.Task] = None self.prefix = prefix @@ -40,11 +38,6 @@ def __init__( self.images: Dict[str, Dict[str, Dict[str, tuple[object, float]]]] = \ defaultdict(lambda: defaultdict(dict)) - self.finished_batches: Dict[str, Dict[str, asyncio.Task]] = \ - defaultdict(dict) - - self.reached_end: Dict[str, Dict[str, bool]] = defaultdict(dict) - self.add_api_route( 'interrogate', self.endpoint_interrogate, @@ -128,12 +121,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): m, q, n = (req.model, req.queue, req.name) with self.queue_lock: - if n == '' and q in self.running_batches[m]: - # wait for batch to finish - res = self.running_batches[m][q].result() - del self.running_batches[m][q] - - elif q != '': + if q != '': # clobber existing image if n in self.images[m][q]: i = 0 @@ -148,6 +136,8 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): self.queue[m] = asyncio.Queue() self.queue[m].put_nowait((q, n, image, req.threshold)) self.runner = asyncio.create_task(self.batch_process(m)) + if n == '': + res = self.results[q] else: interrogator = utils.interrogators[m] rating, tag = interrogator.interrogate(image) From f6f1dc68cc5b88290cd435b1f4f6ed00a2695318 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 12 Sep 2023 16:15:59 +0200 Subject: [PATCH 50/78] If queue is empty, there is a single interrogation with response. If not, if name_in_queue is given, the interrogation is queued under that name. The response is the number of all processed for all active queues & all models. If name_in_queue is empty, the queue is marked as finished, A response is awaited for remaining interrogatioons. The response is only for this queue. at least that's how it's supposed to work. no testing yet, except compile. --- tagger/api.py | 122 +++++++++++++++++++++++++++---------------- tagger/api_models.py | 5 +- tagger/uiset.py | 6 +-- 3 files changed, 82 insertions(+), 51 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index ceb5e24..90cbb84 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -4,6 +4,7 @@ from secrets import compare_digest import asyncio from collections import defaultdict +from hashlib import sha256 from modules import shared # pylint: disable=import-error from modules.api.api import decode_base64_to_image # pylint: disable=E0401 @@ -29,14 +30,14 @@ def __init__( self.app = app self.queue: Dict[str, asyncio.Queue] = {} - self.results: Dict[str, Dict[str, Dict[str, float]]] = {} + self.res: Dict[str, Dict[str, Dict[str, float]]] = \ + defaultdict(dict) self.queue_lock = qlock self.runner: Optional[asyncio.Task] = None self.prefix = prefix - - self.images: Dict[str, Dict[str, Dict[str, tuple[object, float]]]] = \ - defaultdict(lambda: defaultdict(dict)) + self.running_batches: Dict[str, Dict[str, float]] = \ + defaultdict(lambda: defaultdict(int)) self.add_api_route( 'interrogate', @@ -59,30 +60,58 @@ def __init__( response_model=str, ) - async def batch_process(self, model: str) -> None: - done: Dict[str, bool] = {model: False} + async def add_to_queue(self, m, q, n='', i=None, th=0.0) -> Dict[ + str, Dict[str, float] + ]: + with self.queue_lock: + if m not in self.queue: + self.queue[m] = asyncio.Queue() + await self.queue[m].put((q, n, i, th)) + if i is not None: + if self.runner is None: + self.runner = asyncio.create_task(self.batch_process()) + # return how many interrogations are done so far per queue + return self.running_batches + # wait for the result to become available + while q in self.running_batches[m]: + await asyncio.sleep(0.1) + return self.res.pop(q) + async def batch_process(self) -> None: while len(self.queue) > 0: - while self.queue[model].qsize() > 0: - with self.queue_lock: - q, n, i, t = self.queue[model].get_nowait() + for m in self.queue: + # if zero the queue might just be pending + while self.queue[m].qsize() > 0: + with self.queue_lock: + q, n, i, t = self.queue[m].get_nowait() if n != "": - # Leaving queue and _name empty to process, not queue - self.results[q][n] = await self.endpoint_interrogate( + if self.running_batches[m][q] < 0: + print(f"Queue {q} is closed") + continue + self.running_batches[m][q] += 1.0 + # queue empty to process, not queue + self.res[m][n] = await self.endpoint_interrogate( models.TaggerInterrogateRequest( image=i, - model=model, + model=m, threshold=t, queue="", - queued_name="" + name_in_queue=n ) ) else: - # This is the end of the queue - done[model] = True - if done[model]: - del self.queue[model] - await asyncio.sleep(0.1) + # if there were any queries, mark it finished + del self.running_batches[m][q] + + for model in self.running_batches: + if len(self.running_batches[model]) == 0: + with self.queue_lock: + del self.queue[model] + else: + await asyncio.sleep(0.1) + + self.running_batches.clear() + self.runner = None def auth(self, creds: Optional[HTTPBasicCredentials] = None): if creds is None: @@ -109,43 +138,44 @@ def add_api_route(self, path: str, endpoint: Callable, **kwargs): return self.app.add_api_route(path, endpoint, **kwargs) def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): - """ one file interrogation """ + """ one file interrogation, queueing, or batch results """ if req.image is None: raise HTTPException(404, 'Image not found') if req.model not in utils.interrogators: raise HTTPException(404, 'Model not found') + m, q, n = (req.model, req.queue, req.name_in_queue) + if n == '' and q != '': + # indicate the end of a queue + tup = (q, n, None, 0.0) + return asyncio.create_task(self.add_to_queue(m, tup)).result() + image = decode_base64_to_image(req.image) res: Dict[str, Dict[str, float]] = defaultdict(dict) - m, q, n = (req.model, req.queue, req.name) - - with self.queue_lock: - if q != '': - # clobber existing image - if n in self.images[m][q]: - i = 0 - while f'{n}.{i}' in self.images[m][q]: - i = i + 1 - n = f'{n}.{i}' - - if m in self.queue: - # add image to queue - self.queue[m].put_nowait((q, n, image, req.threshold)) - else: - self.queue[m] = asyncio.Queue() - self.queue[m].put_nowait((q, n, image, req.threshold)) - self.runner = asyncio.create_task(self.batch_process(m)) - if n == '': - res = self.results[q] - else: - interrogator = utils.interrogators[m] - rating, tag = interrogator.interrogate(image) - res[n] = {**rating} - for k, v in tag.items(): - if v > req.threshold: - res[n][k] = v + if q != '': + if m not in self.queue: + self.queue[m] = asyncio.Queue() + if n == '': + n = sha256(image.tobytes()).hexdigest() + elif f'{q}#{n}' in self.res[m]: + # clobber name if it's already in the queue + i = 0 + while f'{q}#{n}#{i}' in self.res[m]: + i += 1 + n = f'{q}#{n}#{i}' + # add image to queue + res = asyncio.create_task( + self.add_to_queue(m, q, n, image, req.threshold) + ).result() + else: + interrogator = utils.interrogators[m] + res[n], tag = interrogator.interrogate(image) + + for k, v in tag.items(): + if v > req.threshold: + res[n][k] = v return models.TaggerInterrogateResponse(caption=res) diff --git a/tagger/api_models.py b/tagger/api_models.py index 1f61881..b935ed7 100644 --- a/tagger/api_models.py +++ b/tagger/api_models.py @@ -19,9 +19,10 @@ class TaggerInterrogateRequest(sd_models.InterrogateRequest): title='Queue', description='name of queue; leave empty for single response', ) - name: str = Field( + name_in_queue: str = Field( title='Name', - description='name to queue an image as; empty for the final response', + description='name to queue image as or use . leave empty to ' + 'retrieve the final response', ) diff --git a/tagger/uiset.py b/tagger/uiset.py index d078508..0d141f5 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -42,9 +42,9 @@ class IOData: last_path_mtimes = None base_dir = None output_root = None - paths = [] + paths: List[List[str]] = [] save_tags = True - err = set() + err: Set[str] = set() @classmethod def error_msg(cls) -> str: @@ -52,7 +52,7 @@ def error_msg(cls) -> str: "" @classmethod - def flip_save_tags(cls) -> callable: + def flip_save_tags(cls) -> Callable: def toggle(): cls.save_tags = not cls.save_tags return toggle From 83614c14cbadf438d0a6ed5afdf90a51835090ca Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 12 Sep 2023 18:45:34 +0200 Subject: [PATCH 51/78] add missing api_model defaults --- tagger/api_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tagger/api_models.py b/tagger/api_models.py index b935ed7..ec45d6d 100644 --- a/tagger/api_models.py +++ b/tagger/api_models.py @@ -14,15 +14,18 @@ class TaggerInterrogateRequest(sd_models.InterrogateRequest): threshold: float = Field( title='Threshold', description='The threshold used for the interrogate model.', + default=1.0, ) queue: str = Field( title='Queue', description='name of queue; leave empty for single response', + default='', ) name_in_queue: str = Field( title='Name', description='name to queue image as or use . leave empty to ' 'retrieve the final response', + default='', ) From 0be994c0ff4f87895948249dc797714678b9d9a5 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 12 Sep 2023 19:08:43 +0200 Subject: [PATCH 52/78] no running event loop --- tagger/api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 90cbb84..83201f2 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -33,6 +33,7 @@ def __init__( self.res: Dict[str, Dict[str, Dict[str, float]]] = \ defaultdict(dict) self.queue_lock = qlock + self.loop = asyncio.get_event_loop() self.runner: Optional[asyncio.Task] = None self.prefix = prefix @@ -67,9 +68,10 @@ async def add_to_queue(self, m, q, n='', i=None, th=0.0) -> Dict[ if m not in self.queue: self.queue[m] = asyncio.Queue() await self.queue[m].put((q, n, i, th)) + if i is not None: if self.runner is None: - self.runner = asyncio.create_task(self.batch_process()) + self.runner = self.loop.create_task(self.batch_process()) # return how many interrogations are done so far per queue return self.running_batches # wait for the result to become available @@ -149,7 +151,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if n == '' and q != '': # indicate the end of a queue tup = (q, n, None, 0.0) - return asyncio.create_task(self.add_to_queue(m, tup)).result() + return self.loop.create_task(self.add_to_queue(m, tup)).result() image = decode_base64_to_image(req.image) res: Dict[str, Dict[str, float]] = defaultdict(dict) @@ -166,7 +168,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): i += 1 n = f'{q}#{n}#{i}' # add image to queue - res = asyncio.create_task( + res = self.loop.create_task( self.add_to_queue(m, q, n, image, req.threshold) ).result() else: From dfd7c82b31282a411eb83cd071d5458b838c82cf Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 12 Sep 2023 20:58:31 +0200 Subject: [PATCH 53/78] better threshold default (no threshold) The name may be empty, better ignore it for single-image. The deeper nested object is required, because the queued query requires a per name interrogation, so maybe better to separate tag from rating, for single. To allow distinction prepend ratings in batch query with 'rating:' --- tagger/api.py | 11 ++++++----- tagger/api_models.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 83201f2..ee1c9db 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -92,15 +92,16 @@ async def batch_process(self) -> None: continue self.running_batches[m][q] += 1.0 # queue empty to process, not queue - self.res[m][n] = await self.endpoint_interrogate( + res = await self.endpoint_interrogate( models.TaggerInterrogateRequest( image=i, model=m, threshold=t, - queue="", - name_in_queue=n ) ) + self.res[m][n] = res["tag"] + for k, v in res["rating"].items(): + self.res[m][n]["rating:"+k] = v else: # if there were any queries, mark it finished del self.running_batches[m][q] @@ -173,11 +174,11 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): ).result() else: interrogator = utils.interrogators[m] - res[n], tag = interrogator.interrogate(image) + res["rating"], tag = interrogator.interrogate(image) for k, v in tag.items(): if v > req.threshold: - res[n][k] = v + res["tag"][k] = v return models.TaggerInterrogateResponse(caption=res) diff --git a/tagger/api_models.py b/tagger/api_models.py index ec45d6d..1cd2dde 100644 --- a/tagger/api_models.py +++ b/tagger/api_models.py @@ -14,7 +14,7 @@ class TaggerInterrogateRequest(sd_models.InterrogateRequest): threshold: float = Field( title='Threshold', description='The threshold used for the interrogate model.', - default=1.0, + default=0.0, ) queue: str = Field( title='Queue', From c1ce31e64156f955de7fa57a30f42f52b7236719 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Tue, 12 Sep 2023 22:55:34 +0200 Subject: [PATCH 54/78] first success with running using queue, but the queue is only started when the queue return is asked (by not providing a name) --- tagger/api.py | 48 +++++++++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index ee1c9db..a673727 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -99,9 +99,9 @@ async def batch_process(self) -> None: threshold=t, ) ) - self.res[m][n] = res["tag"] + self.res[q][n] = res["tag"] for k, v in res["rating"].items(): - self.res[m][n]["rating:"+k] = v + self.res[q][n]["rating:"+k] = v else: # if there were any queries, mark it finished del self.running_batches[m][q] @@ -140,6 +140,24 @@ def add_api_route(self, path: str, endpoint: Callable, **kwargs): Depends(self.auth)], **kwargs) return self.app.add_api_route(path, endpoint, **kwargs) + async def queue_interrogation(self, m, q, n='', i=None, t=0.0) -> Dict[ + str, Dict[str, float] + ]: + """ queue an interrogation, or add to batch """ + if n == '': + return await self.add_to_queue(m, q) + image = decode_base64_to_image(i) + if n == '': + n = sha256(image.tobytes()).hexdigest() + elif f'{q}#{n}' in self.res[q]: + # clobber name if it's already in the queue + i = 0 + while f'{q}#{n}#{i}' in self.res[q]: + i += 1 + n = f'{q}#{n}#{i}' + # add image to queue + return await self.add_to_queue(m, q, n, image, t) + def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): """ one file interrogation, queueing, or batch results """ if req.image is None: @@ -149,31 +167,15 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): raise HTTPException(404, 'Model not found') m, q, n = (req.model, req.queue, req.name_in_queue) - if n == '' and q != '': - # indicate the end of a queue - tup = (q, n, None, 0.0) - return self.loop.create_task(self.add_to_queue(m, tup)).result() - - image = decode_base64_to_image(req.image) - res: Dict[str, Dict[str, float]] = defaultdict(dict) + res: Dict[str, Dict[str, float]] = {} if q != '': - if m not in self.queue: - self.queue[m] = asyncio.Queue() - if n == '': - n = sha256(image.tobytes()).hexdigest() - elif f'{q}#{n}' in self.res[m]: - # clobber name if it's already in the queue - i = 0 - while f'{q}#{n}#{i}' in self.res[m]: - i += 1 - n = f'{q}#{n}#{i}' - # add image to queue - res = self.loop.create_task( - self.add_to_queue(m, q, n, image, req.threshold) - ).result() + res = asyncio.run(self.queue_interrogation(m, q, n, req.image, + req.threshold)) else: + image = decode_base64_to_image(req.image) interrogator = utils.interrogators[m] + res = {"tag": {}, "rating": {}} res["rating"], tag = interrogator.interrogate(image) for k, v in tag.items(): From 58f76e8a4a41280b522d0c8c7d9e192964ddd9b9 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 13 Sep 2023 18:16:56 +0200 Subject: [PATCH 55/78] This does something more, but still some tasks do not complete. --- tagger/api.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index a673727..97a0b19 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -33,7 +33,7 @@ def __init__( self.res: Dict[str, Dict[str, Dict[str, float]]] = \ defaultdict(dict) self.queue_lock = qlock - self.loop = asyncio.get_event_loop() + self.loop = None self.runner: Optional[asyncio.Task] = None self.prefix = prefix @@ -69,10 +69,12 @@ async def add_to_queue(self, m, q, n='', i=None, th=0.0) -> Dict[ self.queue[m] = asyncio.Queue() await self.queue[m].put((q, n, i, th)) - if i is not None: - if self.runner is None: + if n != '': + if self.loop is None: + self.loop = asyncio.get_event_loop() self.runner = self.loop.create_task(self.batch_process()) # return how many interrogations are done so far per queue + print("add_to_queue: " + repr(self.running_batches)) return self.running_batches # wait for the result to become available while q in self.running_batches[m]: @@ -145,7 +147,9 @@ async def queue_interrogation(self, m, q, n='', i=None, t=0.0) -> Dict[ ]: """ queue an interrogation, or add to batch """ if n == '': - return await self.add_to_queue(m, q) + res = await self.add_to_queue(m, q) + print("queue_interrogation1: " + repr(res)) + return res image = decode_base64_to_image(i) if n == '': n = sha256(image.tobytes()).hexdigest() @@ -156,7 +160,9 @@ async def queue_interrogation(self, m, q, n='', i=None, t=0.0) -> Dict[ i += 1 n = f'{q}#{n}#{i}' # add image to queue - return await self.add_to_queue(m, q, n, image, t) + res = await self.add_to_queue(m, q, n, image, t) + print("queue_interrogation2: " + repr(res)) + return res def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): """ one file interrogation, queueing, or batch results """ From 0a29c4fad50059d26f5b7ce79f017c4c2179f6d3 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 13 Sep 2023 22:07:07 +0200 Subject: [PATCH 56/78] asyncio.run already handles the event loop variable i was shadowed then in order corrected the issues: only before interrogation the image needs to be decoded. TypeError: object TaggerInterrogateResponse can't be used in 'await' expression TypeError: 'TaggerInterrogateResponse' object is not subscriptable and now there's something not awaited.. (no completion) --- tagger/api.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 97a0b19..69272eb 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -33,7 +33,6 @@ def __init__( self.res: Dict[str, Dict[str, Dict[str, float]]] = \ defaultdict(dict) self.queue_lock = qlock - self.loop = None self.runner: Optional[asyncio.Task] = None self.prefix = prefix @@ -70,9 +69,8 @@ async def add_to_queue(self, m, q, n='', i=None, th=0.0) -> Dict[ await self.queue[m].put((q, n, i, th)) if n != '': - if self.loop is None: - self.loop = asyncio.get_event_loop() - self.runner = self.loop.create_task(self.batch_process()) + if self.runner is None: + self.runner = await asyncio.create_task(self.batch_process()) # return how many interrogations are done so far per queue print("add_to_queue: " + repr(self.running_batches)) return self.running_batches @@ -89,20 +87,19 @@ async def batch_process(self) -> None: with self.queue_lock: q, n, i, t = self.queue[m].get_nowait() if n != "": - if self.running_batches[m][q] < 0: - print(f"Queue {q} is closed") - continue self.running_batches[m][q] += 1.0 # queue empty to process, not queue - res = await self.endpoint_interrogate( + res = self.endpoint_interrogate( models.TaggerInterrogateRequest( image=i, model=m, threshold=t, + name_in_queue=n, + queue='' ) ) - self.res[q][n] = res["tag"] - for k, v in res["rating"].items(): + self.res[q][n] = res.caption["tag"] + for k, v in res.caption["rating"].items(): self.res[q][n]["rating:"+k] = v else: # if there were any queries, mark it finished @@ -150,17 +147,16 @@ async def queue_interrogation(self, m, q, n='', i=None, t=0.0) -> Dict[ res = await self.add_to_queue(m, q) print("queue_interrogation1: " + repr(res)) return res - image = decode_base64_to_image(i) if n == '': - n = sha256(image.tobytes()).hexdigest() + n = sha256(i).hexdigest() elif f'{q}#{n}' in self.res[q]: # clobber name if it's already in the queue - i = 0 - while f'{q}#{n}#{i}' in self.res[q]: - i += 1 - n = f'{q}#{n}#{i}' + j = 0 + while f'{q}#{n}#{j}' in self.res[q]: + j += 1 + n = f'{q}#{n}#{j}' # add image to queue - res = await self.add_to_queue(m, q, n, image, t) + res = await self.add_to_queue(m, q, n, i, t) print("queue_interrogation2: " + repr(res)) return res @@ -177,7 +173,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if q != '': res = asyncio.run(self.queue_interrogation(m, q, n, req.image, - req.threshold)) + req.threshold)) else: image = decode_base64_to_image(req.image) interrogator = utils.interrogators[m] From 05901adce369d24bd2af3f0c5ddd6cad1bb368f9 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 13 Sep 2023 23:20:56 +0200 Subject: [PATCH 57/78] although an interrogation does not complete, it is queued and executed because the finish works, even for two images. Concurrent queues still seem to fail too, however. --- tagger/api.py | 60 ++++++++++++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 69272eb..79bc69c 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -33,6 +33,7 @@ def __init__( self.res: Dict[str, Dict[str, Dict[str, float]]] = \ defaultdict(dict) self.queue_lock = qlock + self.tasks: Dict[str, asyncio.Task] = {} self.runner: Optional[asyncio.Task] = None self.prefix = prefix @@ -68,15 +69,34 @@ async def add_to_queue(self, m, q, n='', i=None, th=0.0) -> Dict[ self.queue[m] = asyncio.Queue() await self.queue[m].put((q, n, i, th)) - if n != '': - if self.runner is None: - self.runner = await asyncio.create_task(self.batch_process()) - # return how many interrogations are done so far per queue - print("add_to_queue: " + repr(self.running_batches)) - return self.running_batches - # wait for the result to become available - while q in self.running_batches[m]: - await asyncio.sleep(0.1) + if self.runner is None: + self.runner = await asyncio.create_task(self.batch_process()) + + return await self.tasks[q+"\t"+n] + + async def do_queued_interrogation(self, m, q, n, i, t) -> Dict[ + str, Dict[str, float] + ]: + self.running_batches[m][q] += 1.0 + # queue empty to process, not queue + res = self.endpoint_interrogate( + models.TaggerInterrogateRequest( + image=i, + model=m, + threshold=t, + name_in_queue=n, + queue='' + ) + ) + self.res[q][n] = res.caption["tag"] + for k, v in res.caption["rating"].items(): + self.res[q][n]["rating:"+k] = v + return self.running_batches + + async def finish_queue(self, m, q) -> Dict[str, Dict[str, float]]: + with self.queue_lock: + if q in self.running_batches[m]: + del self.running_batches[m][q] return self.res.pop(q) async def batch_process(self) -> None: @@ -86,24 +106,10 @@ async def batch_process(self) -> None: while self.queue[m].qsize() > 0: with self.queue_lock: q, n, i, t = self.queue[m].get_nowait() - if n != "": - self.running_batches[m][q] += 1.0 - # queue empty to process, not queue - res = self.endpoint_interrogate( - models.TaggerInterrogateRequest( - image=i, - model=m, - threshold=t, - name_in_queue=n, - queue='' - ) - ) - self.res[q][n] = res.caption["tag"] - for k, v in res.caption["rating"].items(): - self.res[q][n]["rating:"+k] = v - else: - # if there were any queries, mark it finished - del self.running_batches[m][q] + self.tasks[q+"\t"+n] = asyncio.create_task( + self.do_queued_interrogation(m, q, n, i, t) if n != "" + else self.finish_queue(m, q) + ) for model in self.running_batches: if len(self.running_batches[model]) == 0: From 6407d4861a02f8952f1cfd50aaadfe126056f438 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Fri, 15 Sep 2023 21:38:59 +0200 Subject: [PATCH 58/78] (off-topic) allow grepping from stdin --- shell_scripts/model_grep.sh | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/shell_scripts/model_grep.sh b/shell_scripts/model_grep.sh index 7b21a52..616b22d 100644 --- a/shell_scripts/model_grep.sh +++ b/shell_scripts/model_grep.sh @@ -36,10 +36,15 @@ if [ ! -e "${utilpath}/safetensors_util.py" ]; then exit 1 fi -if [ -z "$1" ]; then - echo "Usage: $0 [-p /path/to/stable-diffusion-webui/models/Lora (default: .)] " - exit 1 +if [ -n "$1" ]; then + ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | + sed -n '1b;p' | jq '.__metadata__.ss_tag_frequency' 2>/dev/null | grep -o -E '\"[^\"]*${1}[^\"]*\": [0-9]+'| sed 's~^~'{}':~p'" +else + tmp=$(mktemp) + sed 's/^/"[^\"]*/;s/$/[^\"]*": [0-9]+/' < /dev/stdin > $tmp + ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | + sed -n '1b;p' | jq '.__metadata__.ss_tag_frequency' 2>/dev/null | grep -oE -f $tmp | sed 's~^~'{}':~p'" + echo rm $tmp fi -ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | -sed -n '1b;p' | jq '.__metadata__.ss_tag_frequency' 2>/dev/null | grep -o -E '\"[^\"]*${1}[^\"]*\": [0-9]+'| sed 's~^~'{}':~p'" + From 803ef56e5d4180ad4268f54d79d93c34a9bd57e1 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Fri, 15 Sep 2023 21:40:27 +0200 Subject: [PATCH 59/78] finally some progress --- tagger/api.py | 71 +++++++++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 79bc69c..916310f 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -61,17 +61,20 @@ def __init__( response_model=str, ) - async def add_to_queue(self, m, q, n='', i=None, th=0.0) -> Dict[ + async def add_to_queue(self, m, q, n='', i=None, t=0.0) -> Dict[ str, Dict[str, float] ]: - with self.queue_lock: - if m not in self.queue: - self.queue[m] = asyncio.Queue() - await self.queue[m].put((q, n, i, th)) + if m not in self.queue: + self.queue[m] = asyncio.Queue() + # loop = asyncio.get_running_loop() + # asyncio.run_coroutine_threadsafe( + task = asyncio.create_task(self.queue[m].put((q, n, i, t))) + # , loop) if self.runner is None: - self.runner = await asyncio.create_task(self.batch_process()) - + loop = asyncio.get_running_loop() + asyncio.ensure_future(self.batch_process(), loop=loop) + await task return await self.tasks[q+"\t"+n] async def do_queued_interrogation(self, m, q, n, i, t) -> Dict[ @@ -94,18 +97,24 @@ async def do_queued_interrogation(self, m, q, n, i, t) -> Dict[ return self.running_batches async def finish_queue(self, m, q) -> Dict[str, Dict[str, float]]: - with self.queue_lock: - if q in self.running_batches[m]: - del self.running_batches[m][q] - return self.res.pop(q) + if q in self.running_batches[m]: + del self.running_batches[m][q] + if q in self.res: + return self.res.pop(q) + return self.running_batches async def batch_process(self) -> None: + # loop = asyncio.get_running_loop() while len(self.queue) > 0: for m in self.queue: # if zero the queue might just be pending - while self.queue[m].qsize() > 0: - with self.queue_lock: + while True: + try: + # q, n, i, t = asyncio.run_coroutine_threadsafe( + # self.queue[m].get_nowait(), loop).result() q, n, i, t = self.queue[m].get_nowait() + except asyncio.QueueEmpty: + break self.tasks[q+"\t"+n] = asyncio.create_task( self.do_queued_interrogation(m, q, n, i, t) if n != "" else self.finish_queue(m, q) @@ -113,8 +122,7 @@ async def batch_process(self) -> None: for model in self.running_batches: if len(self.running_batches[model]) == 0: - with self.queue_lock: - del self.queue[model] + del self.queue[model] else: await asyncio.sleep(0.1) @@ -150,20 +158,20 @@ async def queue_interrogation(self, m, q, n='', i=None, t=0.0) -> Dict[ ]: """ queue an interrogation, or add to batch """ if n == '': - res = await self.add_to_queue(m, q) - print("queue_interrogation1: " + repr(res)) - return res - if n == '': - n = sha256(i).hexdigest() - elif f'{q}#{n}' in self.res[q]: - # clobber name if it's already in the queue - j = 0 - while f'{q}#{n}#{j}' in self.res[q]: - j += 1 - n = f'{q}#{n}#{j}' - # add image to queue - res = await self.add_to_queue(m, q, n, i, t) - print("queue_interrogation2: " + repr(res)) + task = asyncio.create_task(self.add_to_queue(m, q)) + else: + if n == '': + n = sha256(i).hexdigest() + elif f'{q}#{n}' in self.res[q]: + # clobber name if it's already in the queue + j = 0 + while f'{q}#{n}#{j}' in self.res[q]: + j += 1 + n = f'{q}#{n}#{j}' + # add image to queue + task = asyncio.create_task(self.add_to_queue(m, q, n, i, t)) + res = await task + print( "queue_interrogation: " + repr(res)) return res def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): @@ -179,12 +187,13 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if q != '': res = asyncio.run(self.queue_interrogation(m, q, n, req.image, - req.threshold)) + req.threshold), debug =True) else: image = decode_base64_to_image(req.image) interrogator = utils.interrogators[m] res = {"tag": {}, "rating": {}} - res["rating"], tag = interrogator.interrogate(image) + with self.queue_lock: + res["rating"], tag = interrogator.interrogate(image) for k, v in tag.items(): if v > req.threshold: From b057db1a7e07c29f3989bbd252db799b2de4a84f Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Fri, 15 Sep 2023 22:14:19 +0200 Subject: [PATCH 60/78] fix name clobber issue --- tagger/api.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 916310f..c322190 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -162,17 +162,15 @@ async def queue_interrogation(self, m, q, n='', i=None, t=0.0) -> Dict[ else: if n == '': n = sha256(i).hexdigest() - elif f'{q}#{n}' in self.res[q]: + elif n in self.res[q]: # clobber name if it's already in the queue j = 0 - while f'{q}#{n}#{j}' in self.res[q]: + while f'{n}#{j}' in self.res[q]: j += 1 - n = f'{q}#{n}#{j}' + n = f'{n}#{j}' # add image to queue task = asyncio.create_task(self.add_to_queue(m, q, n, i, t)) - res = await task - print( "queue_interrogation: " + repr(res)) - return res + return await task def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): """ one file interrogation, queueing, or batch results """ @@ -187,7 +185,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if q != '': res = asyncio.run(self.queue_interrogation(m, q, n, req.image, - req.threshold), debug =True) + req.threshold)) else: image = decode_base64_to_image(req.image) interrogator = utils.interrogators[m] From 54add7fa4d2b8d191298c8fd63c0a29e142adb2d Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Fri, 15 Sep 2023 23:13:12 +0200 Subject: [PATCH 61/78] prevent sha256 dup --- tagger/api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tagger/api.py b/tagger/api.py index c322190..9778971 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -162,12 +162,15 @@ async def queue_interrogation(self, m, q, n='', i=None, t=0.0) -> Dict[ else: if n == '': n = sha256(i).hexdigest() + if n in self.res[q]: + return self.running_batches elif n in self.res[q]: # clobber name if it's already in the queue j = 0 while f'{n}#{j}' in self.res[q]: j += 1 n = f'{n}#{j}' + self.res[q][n] = {} # add image to queue task = asyncio.create_task(self.add_to_queue(m, q, n, i, t)) return await task From cc1768a8fb1c09982b7a35245f814a38ed07e756 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 00:21:00 +0200 Subject: [PATCH 62/78] For the first image, if no queue name is provided, generate a random queue name, not in use. --- tagger/api.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index 9778971..b9216bb 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -5,6 +5,8 @@ import asyncio from collections import defaultdict from hashlib import sha256 +import string +from random import choices from modules import shared # pylint: disable=import-error from modules.api.api import decode_base64_to_image # pylint: disable=E0401 @@ -81,13 +83,13 @@ async def do_queued_interrogation(self, m, q, n, i, t) -> Dict[ str, Dict[str, float] ]: self.running_batches[m][q] += 1.0 - # queue empty to process, not queue + # queue and name empty to process, not queue res = self.endpoint_interrogate( models.TaggerInterrogateRequest( image=i, model=m, threshold=t, - name_in_queue=n, + name_in_queue='', queue='' ) ) @@ -186,9 +188,17 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): m, q, n = (req.model, req.queue, req.name_in_queue) res: Dict[str, Dict[str, float]] = {} - if q != '': + if q != '' or n != '': + if q == '': + # generate a random queue name, not in use + while True: + q = ''.join(choices(string.ascii_uppercase + + string.digits, k=8)) + if q not in self.queue: + break + print(f'WD14 tagger api generated queue name: {q}') res = asyncio.run(self.queue_interrogation(m, q, n, req.image, - req.threshold)) + req.threshold), debug=True) else: image = decode_base64_to_image(req.image) interrogator = utils.interrogators[m] From 45799518406929f2fd1e1996cf7db41b4f98d2f2 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 01:01:57 +0200 Subject: [PATCH 63/78] move bash_scripts/tag_based_image_dedup.sh --- {bash_scripts => shell_scripts}/tag_based_image_dedup.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {bash_scripts => shell_scripts}/tag_based_image_dedup.sh (100%) diff --git a/bash_scripts/tag_based_image_dedup.sh b/shell_scripts/tag_based_image_dedup.sh similarity index 100% rename from bash_scripts/tag_based_image_dedup.sh rename to shell_scripts/tag_based_image_dedup.sh From 34759d6d1c70faf08e956a5e094713713e5c2419 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 01:33:26 +0200 Subject: [PATCH 64/78] update chagelog --- CHANGELOG.md | 44 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e701c0..baa87be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,46 @@ -# v1.1.2 (2023-08-26) + +Api changes: +Image interrogation via api receives two extra parameters; empty strings by +default. `queue`: the name for a queue, which could be e.g. the person or +subject name. You can leave it empty for the first interrogation, then the +response will que in a new auto-generated unique name, listed in the response. + +Make sure you use this same name as queue, for all interrogations that you want +to be grouped together. The second parameter is `name_in_queue`: the name for +that particular image that is being queued, e.g. a file name. + +If both queue and name are empty, there is a single interrogation with response, +which includes nested objects "ratings" and "tags", so: +`{"ratings": {"sensitive": 0.5, ..}, "tags": {"tag1": 0.5, ..}}` + +If neither name nor queue are empty, the interrogation is queued under that name. +If already in queue, that name is changed - clobbered - with #. An exception is if +the given name is in which case an image checksum will be used instead of +a name. Duplicates are ignored. + +During queuing, the response is the number of all processed interrogations for all +active queues. + +If name_in_queue is empty, but queue is not, that particular queue is finalized, +A response is awaited for remaining interrogations in this queue (if any still). +The response, only for this queue, is an object with the name_in_queue as key, +and the tag with weights contained. Ratings have ther tag name prefixed with +"rating:". Example: +`{"name_in_queue": {"rating:sensitive": 0.5, "tag1": 0.5, ..}}` + +Fix in absence of tensrflow_io +Fix deprecation warning +Added three scripts in shell scripts under shell_scripts: + * A bash script to generate per safetensors file the fraction of images + that the model was trained on that was tagged with particular tokens. + * A python script to compare the interrogation results (read from db.json) + and find the top -c safetensors files that contain similar weights (or at + least, that was the intention, there may be better algorithms to compare, + but it seems to do the job). + * And finally a model_grep script which listts the tags and number of trained + images in a safetensors model. + +# v1.1.2 c9f8efd (2023-08-26) Explain recursive path usage better in ui Fix sending tags via buttons to txt2img and img2img From 0ef398abf33aa12180b845329faa4d1a5400598a Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 01:46:45 +0200 Subject: [PATCH 65/78] bump version --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index baa87be..c4c5000 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ Image interrogation via api receives two extra parameters; empty strings by default. `queue`: the name for a queue, which could be e.g. the person or subject name. You can leave it empty for the first interrogation, then the response will que in a new auto-generated unique name, listed in the response. +# v1.2.0 (2023-09-16) Make sure you use this same name as queue, for all interrogations that you want to be grouped together. The second parameter is `name_in_queue`: the name for From cc4b4c2a1cb9d09937c100543cd37314b849570b Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 18:47:00 +0200 Subject: [PATCH 66/78] gvi # Conflicts: # tagger/settings.py # Conflicts: # tagger/interrogator.py # tagger/utils.py --- tagger/interrogator.py | 5 +++-- tagger/settings.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 1ae5cc2..1716322 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -3,9 +3,9 @@ from pathlib import Path import io import json +import inspect from re import match as re_match from jsonschema import validate -import inspect from platform import uname from typing import Tuple, List, Dict, Callable from pandas import read_csv @@ -14,7 +14,6 @@ from tqdm import tqdm from huggingface_hub import hf_hub_download -from prload import root_dir from modules import shared from tagger import settings # pylint: disable=import-error from tagger.uiset import QData, IOData # pylint: disable=import-error @@ -452,6 +451,8 @@ def __init__( self.model_path = model_path self.tags_path = tags_path self.model = None + self.local_model = None + self.local_tags = None # tagger_hf_hub_down_opts contains args to hf_hub_download(). Parse # and pass only the supported args. diff --git a/tagger/settings.py b/tagger/settings.py index 9cd2200..979e880 100644 --- a/tagger/settings.py +++ b/tagger/settings.py @@ -10,8 +10,10 @@ DEFAULT_OFF = '[name].[output_extension]' -HF_CACHE = os.environ.get('HF_HOME', os.environ.get('HUGGINGFACE_HUB_CACHE', - str(os.path.join(shared.models_path, 'interrogators')))) +HF_CACHE = os.environ.get( + 'HUGGINGFACE_HUB_CACHE', # defaults to "$HF_HOME/hub" + str(os.path.join(shared.models_path, 'interrogators'))) + def slider_wrapper(value, elem_id, **kwargs): # required or else gradio will throw errors @@ -124,11 +126,11 @@ def on_ui_settings(): ) # see huggingface_hub guides/manage-cache shared.opts.add_option( - key='tagger_hf_cache_dir', + key='tagger_hf_hub_down_opts', info=shared.OptionInfo( - HF_CACHE, - label='HuggingFace cache directory, ' - 'see huggingface_hub guides/manage-cache', + str(f'cache_dir="{HF_CACHE}"'), + label='HuggingFace parameters, Comma delimited: arg=value, ' + 'see huggingface_hub docs for available or leave alone.', section=section, ), ) From 9dffaad44a7f466903813abbca46efa12ea6e779 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 18:51:23 +0200 Subject: [PATCH 67/78] add interrogators.json move refresh to interrogator as a static, and pick up the configured interrogators there. presets in tagger/presets.py and tagger/utils.py can go. write info alongside model so we can check its up to date status # Conflicts: # tagger/ui.py # tagger/utils.py # Conflicts: # tagger/interrogator.py # tagger/ui.py # Conflicts: # tagger/api.py # tagger/ui.py # Conflicts: # tagger/interrogator.py # tagger/preset.py --- tagger/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tagger/api.py b/tagger/api.py index b9216bb..74ee0d2 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -182,7 +182,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): if req.image is None: raise HTTPException(404, 'Image not found') - if req.model not in utils.interrogators: + if req.model not in Interrogator.entries.keys(): raise HTTPException(404, 'Model not found') m, q, n = (req.model, req.queue, req.name_in_queue) @@ -201,7 +201,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): req.threshold), debug=True) else: image = decode_base64_to_image(req.image) - interrogator = utils.interrogators[m] + interrogator = Interrogator.entries[m] res = {"tag": {}, "rating": {}} with self.queue_lock: res["rating"], tag = interrogator.interrogate(image) @@ -214,7 +214,7 @@ def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): def endpoint_interrogators(self): return models.TaggerInterrogatorsResponse( - models=list(utils.interrogators.keys()) + models=list(Interrogator.entries.keys()) ) def endpoint_unload_interrogators(self): From 571c56c50a26917610c85939377904193a0fb1d3 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sat, 16 Sep 2023 17:11:01 +0200 Subject: [PATCH 68/78] broken --- interrogators.json | 69 ------------------------------------------ tagger/interrogator.py | 1 + 2 files changed, 1 insertion(+), 69 deletions(-) delete mode 100644 interrogators.json diff --git a/interrogators.json b/interrogators.json deleted file mode 100644 index ce55031..0000000 --- a/interrogators.json +++ /dev/null @@ -1,69 +0,0 @@ -{ - "mld-caformer.dec-5-97527" : { - "class" : "MLDanbooruInterrogator", - "repo_specs" : { - "model_path" : "ml_caformer_m36_dec-5-97527.onnx", - "name" : "ML-Danbooru Caformer dec-5-97527", - "repo_id" : "deepghs/ml-danbooru-onnx" - } - }, - "mld-tresnetd.6-30000" : { - "class" : "MLDanbooruInterrogator", - "repo_specs" : { - "model_path" : "TResnet-D-FLq_ema_6-30000.onnx", - "name" : "ML-Danbooru TResNet-D 6-30000", - "repo_id" : "deepghs/ml-danbooru-onnx" - } - }, - "wd-v1-4-moat-tagger.v2" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 moat tagger v2", - "repo_id" : "SmilingWolf/wd-v1-4-moat-tagger-v2" - } - }, - "wd14-convnext.v1" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ConvNeXT v1", - "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger" - } - }, - "wd14-convnext.v2" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ConvNeXT v2", - "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger-v2" - } - }, - "wd14-convnextv2.v1" : { - "remark" : "the repo_id name is misleading, but it's v1", - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ConvNeXTV2 v1", - "repo_id" : "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" - } - }, - "wd14-swinv2-v1" : { - "remark" : "the repo_id name is misleading, but it's v1", - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 SwinV2 v1", - "repo_id" : "SmilingWolf/wd-v1-4-swinv2-tagger-v2" - } - }, - "wd14-vit.v1" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ViT v1", - "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger" - } - }, - "wd14-vit.v2" : { - "class" : "WaifuDiffusionInterrogator", - "repo_specs" : { - "name" : "WD14 ViT v2", - "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger-v2" - } - } -} diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 1716322..ba70ec1 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -3,6 +3,7 @@ from pathlib import Path import io import json +from jsonschema import validate, ValidationError import inspect from re import match as re_match from jsonschema import validate From 0880cd9d7fab555a61be8fac1a4567b88b81cd67 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 6 Aug 2023 13:18:51 +0200 Subject: [PATCH 69/78] cleanups --- tagger/interrogator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index ba70ec1..d974e18 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -3,7 +3,7 @@ from pathlib import Path import io import json -from jsonschema import validate, ValidationError +from jsonschema import validate import inspect from re import match as re_match from jsonschema import validate From b5648582d02fa8f0dcc4695dfda8ad6479bef75d Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 17 Sep 2023 18:43:19 +0200 Subject: [PATCH 70/78] Fixes, Tagger is loaded without initialization errors --- {defaults => default}/interrogators.json | 9 +- json_schema/interrogators_v1_schema.json | 71 +++++++++++++++ preload.py | 4 +- tagger/interrogator.py | 108 ++++++++++++----------- tagger/ui.py | 30 +++---- tagger/uiset.py | 20 +++-- 6 files changed, 159 insertions(+), 83 deletions(-) rename {defaults => default}/interrogators.json (80%) create mode 100644 json_schema/interrogators_v1_schema.json diff --git a/defaults/interrogators.json b/default/interrogators.json similarity index 80% rename from defaults/interrogators.json rename to default/interrogators.json index a6b71af..2787b6f 100644 --- a/defaults/interrogators.json +++ b/default/interrogators.json @@ -43,14 +43,11 @@ }, "DeepDanbooruInterrogator": { "deepdanbooru-v3-20211112-sgd-e28": { - "name": "DeepDanbooru v3 20211112 sgd e28", - "repo_id": "KichangKim/DeepDanbooru", - "filename": "https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip" + "zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip" }, "deepdanbooru-v4-20200814-sgd-e30": { - "name": "DeepDanbooru v4 20200814 sgd e30", - "repo_id": "KichangKim/DeepDanbooru", - "filename": "https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip" + "zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip" + } }, "MLDanbooruInterrogator": { "mld-caformer.dec-5-97527" : { diff --git a/json_schema/interrogators_v1_schema.json b/json_schema/interrogators_v1_schema.json new file mode 100644 index 0000000..51fc17c --- /dev/null +++ b/json_schema/interrogators_v1_schema.json @@ -0,0 +1,71 @@ +{ + "type": "object", + "properties": { + "MLDanbooruInterrogator": { + "type": "object", + "patternProperties": { + ".*": { "$ref": "#/$defs/huggingfaceInterrogator" } + } + }, + "WaifuDiffusionInterrogator": { + "type": "object", + "patternProperties": { + ".*": { "$ref": "#/$defs/huggingfaceInterrogator" } + } + }, + "DeepDanbooruInterrogator": { + "type": "object", + "patternProperties": { + ".*": { + "type": "object", + "properties": { + "zip": { "type": "string" } + }, + "additionalProperties": false + } + } + }, + "MLDanbooruInterrogator": { + "type": "object", + "patternProperties": { + ".*": { "$ref": "#/$defs/huggingfaceInterrogator" } + } + }, + "additionalProperties": false + }, + "$defs": { + "huggingfaceInterrogator": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "model_path": { "type": "string" }, + "repo_id": { "type": "string" }, + "filename": { "type": "string" }, + "subfolder": { "type": "string" }, + "repo_type": { "type": "string" }, + "revision": { "type": "string" }, + "endpoint": { "type": "string" }, + "library_name": { "type": "string" }, + "library_version": { "type": "string" }, + "cache_dir": { "type": "string" }, + "local_dir": { "type": "string" }, + "local_dir_use_symlinks": { "type": "string" }, + "user_agent": { "type": "string" }, + "force_download": { "type": "boolean" }, + "force_filename": { "type": "string" }, + "proxies": { "type": "string" }, + "etag_timeout": { "type": "number" }, + "resume_download": { "type": "boolean" }, + "token": { "type": "string" }, + "local_files_only": { "type": "boolean" }, + "legacy_cache_layout": { "type": "boolean" } + }, + "required": [ + "name", + "repo_id" + ], + "additionalProperties": false + } + } +} + diff --git a/preload.py b/preload.py index 92ff73f..6be1b60 100644 --- a/preload.py +++ b/preload.py @@ -1,7 +1,9 @@ """ Preload module for DeepDanbooru or onnxtagger. """ from argparse import ArgumentParser +from pathlib import Path + +root_dir = Path(__file__).parent -root_dir = Path(__file__).parent.parent def preload(parser: ArgumentParser): """ Preload module for DeepDanbooru or onnxtagger. """ diff --git a/tagger/interrogator.py b/tagger/interrogator.py index d974e18..54959eb 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -3,8 +3,10 @@ from pathlib import Path import io import json -from jsonschema import validate import inspect +import sys +import shutil +import requests from re import match as re_match from jsonschema import validate from platform import uname @@ -18,6 +20,7 @@ from modules import shared from tagger import settings # pylint: disable=import-error from tagger.uiset import QData, IOData # pylint: disable=import-error +from preload import root_dir from . import dbimutils # pylint: disable=import-error # noqa Its = settings.InterrogatorSettings @@ -45,6 +48,15 @@ print(f'== WD14 tagger {TF_DEVICE_NAME}, {uname()} ==') +ddp_path = shared.cmd_opts.deepdanbooru_projects_path +if ddp_path is None: + ddp_path = Path(shared.models_path, 'deepdanbooru') +os.makedirs(ddp_path, exist_ok=True) + +onnx_path = shared.cmd_opts.onnxtagger_path +if onnx_path is None: + onnx_path = Path(shared.models_path, 'TaggerOnnx') +os.makedirs(onnx_path, exist_ok=True) class Interrogator: """ Interrogator class for tagger """ @@ -105,37 +117,20 @@ def refresh(cls) -> List[str]: if not it_path.exists(): raise FileNotFoundError(f'{it_path} not found.') - raw = json.loads(it_path) - schema = root_dir.joinpath('json_schema', - 'interrogators_v1_schema.json') - validate(raw, json.loads(schema.read_text())) - - for class_name, it in raw.items(): - if class_name == "DeepDanbooruInterrogator": - It_type = DeepDanbooruInterrogator - elif class_name == "WaifuDiffusionInterrogator": - It_type = WaifuDiffusionInterrogator - elif class_name == "MLDanbooruInterrogator": - It_type = MLDanbooruInterrogator - else: - raise ValueError(f'Unimplemented: {it["class"]}') - for name, obj in it.items(): - if name not in obj: - obj[name] = name - cls.entries[name] = It_type(**obj) - - cls.entries[name] = It_type(**it["repo_specs"]) - - # load deepdanbooru project - ddp_path = shared.cmd_opts.deepdanbooru_projects_path - if ddp_path is None: - ddp_path = Path(shared.models_path, 'deepdanbooru') - onnx_path = shared.cmd_opts.onnx_path - if onnx_path is None: - onnx_path = Path(shared.models_path, 'TaggerOnnx') - os.makedirs(ddp_path, exist_ok=True) - os.makedirs(onnx_path, exist_ok=True) + raw = json.loads(it_path.read_text()) + schema = root_dir.joinpath('json_schema', + 'interrogators_v1_schema.json') + validate(raw, json.loads(schema.read_text())) + + for class_name, it in raw.items(): + if class_name[-12:] == "Interrogator": + It_type = getattr(sys.modules[__name__], class_name) + for name, obj in it.items(): + if "name" not in obj: + obj["name"] = name + cls.entries[name] = It_type(**obj) + # load deepdanbooru project for path in os.scandir(ddp_path): print(f"Scanning {path} as deepdanbooru project") if not path.is_dir(): @@ -162,32 +157,20 @@ def refresh(cls) -> List[str]: if len(onnx_files) != 1: print(f"Warning: {path}: multiple .onnx models => skipped") continue - local_path = Path(path, onnx_files[0].name) - csv = [x for x in os.scandir(path) if x.name.endswith('.csv')] - if len(csv) == 0: + for csv in os.scandir(path): + if csv.name.endswith('.csv') and "tag" in csv.name.lower() \ + or "select" in csv.name.lower(): + tags_path = Path(path, csv.name) + break + else: print(f"Warning: {path}: no selected tags .csv file, skipped") continue - def tag_select_csvs_up_front(k): - k = k.name.lower() - return -1 if "tag" in k or "select" in k else 1 - - csv.sort(key=tag_select_csvs_up_front) - tags_path = Path(path, csv[0]) - + local_path = Path(path, onnx_files[0].name) if path.name not in cls.entries: - if path.name == 'wd-v1-4-convnextv2-tagger-v2': - cls.entries[path.name] = WaifuDiffusionInterrogator( - path.name, - repo_id='SmilingWolf/SW-CV-ModelZoo' - ) - elif path.name == 'Z3D-E621-Convnext': - cls.entries[path.name] = WaifuDiffusionInterrogator( - 'Z3D-E621-Convnext') - else: - raise NotImplementedError(f"Add {path.name} resolution " - "similar to above here") + print(f"Warning: {path} not configured in interrogators.json") + continue cls.entries[path.name].local_model = str(local_path) cls.entries[path.name].local_tags = str(tags_path) @@ -345,12 +328,31 @@ def interrogate( class DeepDanbooruInterrogator(Interrogator): """ Interrogator for DeepDanbooru models """ - def __init__(self, name: str, project_path: os.PathLike) -> None: + def __init__(self, name: str, project_path=None, zip='') -> None: super().__init__(name) + + if project_path is None: + project_path = Path(ddp_path, name) + if not project_path.is_dir(): + if zip == '': + raise FileNotFoundError(f'{project_path} does not exist') + os.makedirs(project_path) + response = requests.get(zip, stream=True) + local_zip = project_path / zip.split('/')[-1] + with open(local_zip, 'wb') as f: + for data in tqdm(response.iter_content(), + desc=f'Downloading {zip}'): + f.write(data) + shutil.unpack_archive(local_zip, project_path) + os.remove(local_zip) + self.project_path = project_path self.model = None self.tags = None + def set_project_path(self, path: os.PathLike) -> None: + self.project_path = path + def load(self) -> None: print(f'Loading {self.name} from {str(self.project_path)}') diff --git a/tagger/ui.py b/tagger/ui.py index 094c437..c958f6a 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -187,7 +187,7 @@ def on_ui_tabs(): interactive=True, type="pil" ) - image_submit = gr.Button( + img_submit = gr.Button( value='Interrogate image', variant='primary' ) @@ -239,7 +239,7 @@ def on_ui_tabs(): with gr.Column(): # preset selector with gr.Row(variant='compact'): - available_presets = utils.preset.list() + available_presets = preset.list() selected_preset = gr.Dropdown( label='Preset', choices=available_presets, @@ -253,7 +253,7 @@ def on_ui_tabs(): ui.create_refresh_button( selected_preset, lambda: None, - lambda: {'choices': utils.preset.list()}, + lambda: {'choices': preset.list()}, 'refresh_preset' ) @@ -281,20 +281,20 @@ def on_ui_tabs(): value='Unload all interrogate models' ) with gr.Row(variant='compact'): - tag_input["add"] = utils.preset.component( + tag_input["add"] = preset.component( gr.Textbox, label='Additional tags (comma split)', elem_id='additional-tags' ) with gr.Row(variant='compact'): - threshold = utils.preset.component( + threshold = preset.component( gr.Slider, label='Weight threshold', minimum=0, maximum=1, value=QData.threshold ) - tag_frac_threshold = utils.preset.component( + tag_frac_threshold = preset.component( gr.Slider, label='Min tag fraction in batch and ' 'interrogations', @@ -303,34 +303,34 @@ def on_ui_tabs(): value=QData.tag_frac_threshold, ) with gr.Row(variant='compact'): - cumulative = utils.preset.component( + cumulative = preset.component( gr.Checkbox, label='Combine interrogations', value=False ) - unload_after = utils.preset.component( + unload_after = preset.component( gr.Checkbox, label='Unload model after running', value=False ) with gr.Row(variant='compact'): - tag_input["search"] = utils.preset.component( + tag_input["search"] = preset.component( gr.Textbox, label='Search tag, .. ->', elem_id='search-tags' ) - tag_input["replace"] = utils.preset.component( + tag_input["replace"] = preset.component( gr.Textbox, label='-> Replace tag, ..', elem_id='replace-tags' ) with gr.Row(variant='compact'): - tag_input["keep"] = utils.preset.component( + tag_input["keep"] = preset.component( gr.Textbox, label='Keep tag, ..', elem_id='keep-tags' ) - tag_input["exclude"] = utils.preset.component( + tag_input["exclude"] = preset.component( gr.Textbox, label='Exclude tag, ..', elem_id='exclude-tags' @@ -464,11 +464,11 @@ def on_ui_tabs(): [tag_input[tag] for tag in TAG_INPUTS] # interrogation events - image_submit.click(fn=wrap_gradio_gpu_call(on_interrogate_image_submit), - inputs=[image] + common_input, outputs=common_output) + img_submit.click(fn=wrap_gradio_gpu_call(on_interrogate_image_submit), + inputs=[image] + common_input, outputs=common_output) image.change(fn=wrap_gradio_gpu_call(on_interrogate_image), - inputs=[image] + common_input, outputs=common_output) + inputs=[image] + common_input, outputs=common_output) batch_submit.click(fn=wrap_gradio_gpu_call(on_interrogate), inputs=[input_glob, output_dir] + common_input, diff --git a/tagger/uiset.py b/tagger/uiset.py index 0d141f5..577d0ad 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -45,6 +45,7 @@ class IOData: paths: List[List[str]] = [] save_tags = True err: Set[str] = set() + base_dir_last: Optional[str] = None @classmethod def error_msg(cls) -> str: @@ -152,10 +153,10 @@ def set_batch_io(cls, paths: List[str]) -> None: """ set input and output paths for batch mode """ checked_dirs = set() cls.paths = [] - for path in paths: - path = Path(path) + for path_str in paths: + path = Path(path_str) if not cls.save_tags: - cls.paths.append([path, '', '']) + cls.paths.append([path_str, '', '']) continue # guess the output path @@ -175,6 +176,8 @@ def set_batch_io(cls, paths: List[str]) -> None: except (TypeError, ValueError): cls.err.add(msg) + if not cls.output_root: + raise ValueError('output_root not set') output_dir = cls.output_root.joinpath( *path.parts[base_dir_last_idx + 1:]).parent @@ -197,18 +200,19 @@ def set_batch_io(cls, paths: List[str]) -> None: class QData: """ Query data: contains parameters for the query """ - add_tags = [] + add_tags: List[str] = [] keep_tags = set() - exclude_tags = [] - search_tags = {} - replace_tags = [] + exclude_tags: List[str] = [] + search_tags: Dict[str, str] = {} + replace_tags: List[str] = [] threshold = 0.35 tag_frac_threshold = 0.05 count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100) # read from db.json, update with what should be written to db.json: json_db = None - weighed = (defaultdict(list), defaultdict(list)) + weighed: Tuple[Dict[str, list], Dict[str, list]] = \ + (defaultdict(list), defaultdict(list)) query = {} # representing the (cumulative) current interrogations From 06927ea9c64d2f5188019b4df2c242752f3abb33 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Sun, 17 Sep 2023 22:27:25 +0200 Subject: [PATCH 71/78] clean up --- tagger/interrogator.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 54959eb..680ebed 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -5,8 +5,8 @@ import json import inspect import sys -import shutil -import requests +from shutil import unpack_archive +from requests import get as http_get from re import match as re_match from jsonschema import validate from platform import uname @@ -326,6 +326,18 @@ def interrogate( raise NotImplementedError() +def download_and_extract(url: str, path: Path, rm_zip=True) -> None: + """ Download and extract a zip file """ + response = http_get(url, stream=True) + local_zip = path / url.split('/')[-1] + with open(local_zip, 'wb') as f: + for data in tqdm(response.iter_content(), desc=f'Downloading {url}'): + f.write(data) + unpack_archive(local_zip, path) + if rm_zip: + os.remove(local_zip) + + class DeepDanbooruInterrogator(Interrogator): """ Interrogator for DeepDanbooru models """ def __init__(self, name: str, project_path=None, zip='') -> None: @@ -337,22 +349,12 @@ def __init__(self, name: str, project_path=None, zip='') -> None: if zip == '': raise FileNotFoundError(f'{project_path} does not exist') os.makedirs(project_path) - response = requests.get(zip, stream=True) - local_zip = project_path / zip.split('/')[-1] - with open(local_zip, 'wb') as f: - for data in tqdm(response.iter_content(), - desc=f'Downloading {zip}'): - f.write(data) - shutil.unpack_archive(local_zip, project_path) - os.remove(local_zip) + download_and_extract(zip, project_path) self.project_path = project_path self.model = None self.tags = None - def set_project_path(self, path: os.PathLike) -> None: - self.project_path = path - def load(self) -> None: print(f'Loading {self.name} from {str(self.project_path)}') From 5d470efabf8acd1835408593dc19ea5b7b1f4179 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Mon, 18 Sep 2023 01:10:24 +0200 Subject: [PATCH 72/78] fix --- tagger/interrogator.py | 4 ++-- tagger/uiset.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 680ebed..a893fba 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -456,8 +456,8 @@ def __init__( self.model_path = model_path self.tags_path = tags_path self.model = None - self.local_model = None - self.local_tags = None + self.local_model = '' + self.local_tags = '' # tagger_hf_hub_down_opts contains args to hf_hub_download(). Parse # and pass only the supported args. diff --git a/tagger/uiset.py b/tagger/uiset.py index 577d0ad..a7147fa 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -128,7 +128,7 @@ def update_input_glob(cls, input_glob: str) -> None: # interrogating in a directory with no pics, still flush the cache if len(path_mtimes) > 0 and cls.last_path_mtimes == path_mtimes: - print('No changed images') + # No (changed) images, keep the data return QData.clear(2) From 8bbad9398670733cfdf38f37e45f082c4a45591b Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Mon, 18 Sep 2023 21:32:32 +0200 Subject: [PATCH 73/78] clean up --- tagger/interrogator.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index a893fba..4e98a3b 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -132,10 +132,9 @@ def refresh(cls) -> List[str]: # load deepdanbooru project for path in os.scandir(ddp_path): - print(f"Scanning {path} as deepdanbooru project") if not path.is_dir(): - print(f"Warning: {path} is not a directory, skipped") continue + print(f"Scanning {path} as deepdanbooru project") if not Path(path, 'project.json').is_file(): print(f"Warning: {path} has no project.json, skipped") @@ -144,10 +143,9 @@ def refresh(cls) -> List[str]: cls.entries[path.name] = DeepDanbooruInterrogator(path.name, path) # scan for onnx models as well for path in os.scandir(onnx_path): - print(f"Scanning {path} as onnx model") if not path.is_dir(): - print(f"Warning: {path} is not a directory, skipped") continue + print(f"Scanning {path} as onnx model") onnx_files = [] for file_name in os.scandir(path): @@ -167,11 +165,10 @@ def refresh(cls) -> List[str]: print(f"Warning: {path}: no selected tags .csv file, skipped") continue - local_path = Path(path, onnx_files[0].name) if path.name not in cls.entries: print(f"Warning: {path} not configured in interrogators.json") continue - + local_path = Path(path, onnx_files[0].name) cls.entries[path.name].local_model = str(local_path) cls.entries[path.name].local_tags = str(tags_path) From 274143527ff95a80c09188ceba2f91b55e014ca6 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Mon, 18 Sep 2023 23:58:36 +0200 Subject: [PATCH 74/78] this seems to work again, requires json edits for custom models, though --- json_schema/interrogators_v1_schema.json | 22 ++++- tagger/interrogator.py | 103 ++++------------------- 2 files changed, 37 insertions(+), 88 deletions(-) diff --git a/json_schema/interrogators_v1_schema.json b/json_schema/interrogators_v1_schema.json index 51fc17c..ea643b8 100644 --- a/json_schema/interrogators_v1_schema.json +++ b/json_schema/interrogators_v1_schema.json @@ -1,4 +1,6 @@ { + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Interrogator Configurations", "type": "object", "properties": { "MLDanbooruInterrogator": { @@ -38,6 +40,8 @@ "type": "object", "properties": { "name": { "type": "string" }, + "local_model": { "type": "string" }, + "local_tags": { "type": "string" }, "model_path": { "type": "string" }, "repo_id": { "type": "string" }, "filename": { "type": "string" }, @@ -60,9 +64,21 @@ "local_files_only": { "type": "boolean" }, "legacy_cache_layout": { "type": "boolean" } }, - "required": [ - "name", - "repo_id" + "$comment": "repo_id is required if you want to actually download the model", + "oneOf": [ + { + "required": [ + "name", + "repo_id" + ] + }, + { + "required": [ + "name", + "local_model", + "local_tags" + ] + } ], "additionalProperties": false } diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 4e98a3b..4c44581 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -58,6 +58,7 @@ onnx_path = Path(shared.models_path, 'TaggerOnnx') os.makedirs(onnx_path, exist_ok=True) + class Interrogator: """ Interrogator class for tagger """ # the raw input and output. @@ -126,52 +127,11 @@ def refresh(cls) -> List[str]: if class_name[-12:] == "Interrogator": It_type = getattr(sys.modules[__name__], class_name) for name, obj in it.items(): + print(f"Loading {name} as {class_name}") if "name" not in obj: obj["name"] = name cls.entries[name] = It_type(**obj) - # load deepdanbooru project - for path in os.scandir(ddp_path): - if not path.is_dir(): - continue - print(f"Scanning {path} as deepdanbooru project") - - if not Path(path, 'project.json').is_file(): - print(f"Warning: {path} has no project.json, skipped") - continue - - cls.entries[path.name] = DeepDanbooruInterrogator(path.name, path) - # scan for onnx models as well - for path in os.scandir(onnx_path): - if not path.is_dir(): - continue - print(f"Scanning {path} as onnx model") - - onnx_files = [] - for file_name in os.scandir(path): - if file_name.name.endswith('.onnx'): - onnx_files.append(file_name) - - if len(onnx_files) != 1: - print(f"Warning: {path}: multiple .onnx models => skipped") - continue - - for csv in os.scandir(path): - if csv.name.endswith('.csv') and "tag" in csv.name.lower() \ - or "select" in csv.name.lower(): - tags_path = Path(path, csv.name) - break - else: - print(f"Warning: {path}: no selected tags .csv file, skipped") - continue - - if path.name not in cls.entries: - print(f"Warning: {path} not configured in interrogators.json") - continue - local_path = Path(path, onnx_files[0].name) - cls.entries[path.name].local_model = str(local_path) - cls.entries[path.name].local_tags = str(tags_path) - return sorted(i.name for i in cls.entries.values()) @staticmethod @@ -194,8 +154,8 @@ def __init__(self, name: str) -> None: # run_mode 0 is dry run, 1 means run (alternating), 2 means disabled self.run_mode = 0 if hasattr(self, "large_batch_interrogate") else 2 # default path if not overridden by download - self.local_model = None - self.local_tags = None + self.local_model = '' + self.local_tags = '' # XXX don't Interrogator.refresh()-ception here def load(self) -> bool: @@ -453,8 +413,6 @@ def __init__( self.model_path = model_path self.tags_path = tags_path self.model = None - self.local_model = '' - self.local_tags = '' # tagger_hf_hub_down_opts contains args to hf_hub_download(). Parse # and pass only the supported args. @@ -471,10 +429,6 @@ def __init__( print(f"Warning: interrogators.json: model {self.name}: " f"parameter {k} unsupported or or wrong type.") - if 'repo_id' not in self.hf_params: - print(f"Warning: interrogators.json: HuggingFace model {self.name}" - " lacks a repo_id. If not already local, download may fail.") - attrs = getattr(shared.opts, 'tagger_hf_hub_down_opts', f'cache_dir="{Its.hf_cache}"') attrs = [attr.split('=') for attr in map(str.strip, attrs.split(','))] @@ -506,43 +460,22 @@ def __init__( "Invalid for hf_hub_download() => ignored.") def download(self) -> Tuple[str, str]: - repo_id = self.hf_params.get('repo_id', '(?)') - print(f"Loading {self.name} model file from {repo_id}") - if self.local_model == '': - Interrogator.refresh() paths = [self.local_model, self.local_tags] + try: + # To prevent download don't set repo_id, export HF_HUB_OFFLINE=1 + repo_id = self.hf_params['repo_id'] + + print(f"(Down)Loading {self.name} model file from {repo_id}") + for i, filename in enumerate([self.model_path, self.tags_path]): + self.hf_params['filename'] = filename + paths[i] = hf_hub_download(**self.hf_params) + except Exception as err: + print(f"Warning: {self.name}: {err} (might be as expected)") + for i in range(2): + if not os.path.isabs(paths[i]): + paths[i] = os.path.join(shared.models_path, paths[i]) + pass - data = {} - for k in self.repo_specs: - if k in self.hf_params: - data[k] = self.hf_params[k] - - # check if the model is up to date - info_path = Path(self.local_model).with_suffix('.info') - if info_path.exists(): - - if all(os.path.exists(p) for p in paths): - with open(info_path, 'r') as filen: - try: - old_data = json.load(filen) - if old_data == data: - print(f"Model {self.name} is up to date.") - return paths - except json.decoder.JSONDecodeError: - pass - - try: - for i, filen in enumerate([self.model_path, self.tags_path]): - self.hf_params['filename'] = filen - paths[i] = hf_hub_download(**self.hf_params) - except Exception as err: - print(f"hf_hub_download({self.hf_params}: {err}") - return paths - - # write the repo_specs to a json alongside the model so we can - # check if the model is up to date - with open(info_path, 'w') as filen: - json.dump(data, filen) return paths def load_model(self, model_path) -> None: From 6d68c9687bfafa900e8058bebf7229d98c507271 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 20 Sep 2023 00:04:45 +0200 Subject: [PATCH 75/78] add fromfile interrogator --- json_schema/interrogators_v1_schema.json | 18 +++++++++- tagger/interrogator.py | 43 ++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/json_schema/interrogators_v1_schema.json b/json_schema/interrogators_v1_schema.json index ea643b8..b193506 100644 --- a/json_schema/interrogators_v1_schema.json +++ b/json_schema/interrogators_v1_schema.json @@ -33,8 +33,24 @@ ".*": { "$ref": "#/$defs/huggingfaceInterrogator" } } }, + "FromFileInterrogator": { + "type": "object", + "patternProperties": { + ".*": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "val": { "type": "number", "default": 1.0 } + }, + "required": [ + "path" + ], + "additionalProperties": false + } + } + }, "additionalProperties": false - }, + }, "$defs": { "huggingfaceInterrogator": { "type": "object", diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 4c44581..0616634 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -485,6 +485,49 @@ def load_model(self, model_path) -> None: print(f'Loaded {self.name} model from {model_path}') +class FromFileInterrogator(Interrogator): + """ Pseudo Interrogator reading preinterrogated tags files """ + def __init__(self, name: str, path: os.PathLike, val=1.0) -> None: + super().__init__(name) + self.path = path + self.val = val + self.tags = None + + def load(self) -> None: + print(f'Loading {self.name} from {str(self.path)}') + # self.path is a directory + if not os.path.isdir(self.path): + raise ValueError(f'{self.path} is not a directory') + else: + self.tags = {} + for f in os.listdir(self.path): + self.tags[f] = {} + self.load_file(f) + + def load_file(self, tags_file: str) -> None: + image_name = str(tags_file).split('/')[-1].split('.')[0] + with open(tags_file, 'r') as f: + for line in f: + for x in map(str.split, line.split(',')): + if x[0] == '(' and x[-1] == ')' and ':' in x: + tag, val = x[1:-1].split(':') + self.tags[image_name][tag] = float(val) + else: + self.tags[image_name][x] = self.val + + def unload(self) -> None: + self.tags = {} + + def interrogate( + self, + image: Image + ) -> Tuple[ + Dict[str, float], # rating confidences + Dict[str, float] # tag confidences + ]: + return {}, self.tags[image.filename] + + class WaifuDiffusionInterrogator(HFInterrogator): """ Interrogator for Waifu Diffusion models """ def __init__( From 486edca5963b8b6b11c9347edb166d3bfcf9e842 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 20 Sep 2023 01:59:37 +0200 Subject: [PATCH 76/78] fix fromfileinterrogator --- default/interrogators.json | 7 ++++ json_schema/interrogators_v1_schema.json | 3 +- tagger/interrogator.py | 46 +++++++++++++++--------- tagger/uiset.py | 18 +++++----- 4 files changed, 48 insertions(+), 26 deletions(-) diff --git a/default/interrogators.json b/default/interrogators.json index 2787b6f..3b10889 100644 --- a/default/interrogators.json +++ b/default/interrogators.json @@ -49,6 +49,13 @@ "zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip" } }, + "FromFileInterrogator": { + "[name].[hash:sha1].[output_extension]": { + "format": "[name].[hash:sha1].[output_extension]", + "path": "", + "value": 1.0 + } + }, "MLDanbooruInterrogator": { "mld-caformer.dec-5-97527" : { "model_path" : "ml_caformer_m36_dec-5-97527.onnx", diff --git a/json_schema/interrogators_v1_schema.json b/json_schema/interrogators_v1_schema.json index b193506..35fde69 100644 --- a/json_schema/interrogators_v1_schema.json +++ b/json_schema/interrogators_v1_schema.json @@ -40,7 +40,8 @@ "type": "object", "properties": { "path": { "type": "string" }, - "val": { "type": "number", "default": 1.0 } + "format": { "type": "string", "default": "[name].[output_extension]" }, + "value": { "type": "number", "default": 1.0 } }, "required": [ "path" diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 0616634..9da3f28 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -19,7 +19,8 @@ from modules import shared from tagger import settings # pylint: disable=import-error -from tagger.uiset import QData, IOData # pylint: disable=import-error +from tagger.uiset import QData, IOData, supported_extensions, \ + format_output_filename from preload import root_dir from . import dbimutils # pylint: disable=import-error # noqa @@ -487,33 +488,37 @@ def load_model(self, model_path) -> None: class FromFileInterrogator(Interrogator): """ Pseudo Interrogator reading preinterrogated tags files """ - def __init__(self, name: str, path: os.PathLike, val=1.0) -> None: + def __init__( + self, name: str, path, format='[name].[output_extension]', value=1.0 + ) -> None: super().__init__(name) - self.path = path - self.val = val - self.tags = None + self.path = Path(self.path) + self.val = value + self.format = format + if format == Its.output_filename_format and path == '': + raise ValueError(f"tagsfiles will ne overwritten with {format}") + self.tags = {} def load(self) -> None: print(f'Loading {self.name} from {str(self.path)}') + if self.path == '': + return + # self.path is a directory - if not os.path.isdir(self.path): + if not os.path.isdir(Path(self.path)): raise ValueError(f'{self.path} is not a directory') - else: - self.tags = {} - for f in os.listdir(self.path): - self.tags[f] = {} - self.load_file(f) def load_file(self, tags_file: str) -> None: - image_name = str(tags_file).split('/')[-1].split('.')[0] + basename = '.'.join(str(tags_file).split('/')[-1].split('.')[:-1]) + self.tags[basename] = {} with open(tags_file, 'r') as f: for line in f: - for x in map(str.split, line.split(',')): + for x in map(str.strip, line.split(',')): if x[0] == '(' and x[-1] == ')' and ':' in x: tag, val = x[1:-1].split(':') - self.tags[image_name][tag] = float(val) + self.tags[basename][tag] = float(val) else: - self.tags[image_name][x] = self.val + self.tags[basename][x] = self.val def unload(self) -> None: self.tags = {} @@ -525,7 +530,16 @@ def interrogate( Dict[str, float], # rating confidences Dict[str, float] # tag confidences ]: - return {}, self.tags[image.filename] + basename = '.'.join(image.filename.split('/')[-1].split('.')[:-1]) + path = Path(image.filename) + tags_filename = format_output_filename(path, self.format) + if self.path == '': + dir = path.parent + else: + dir = self.path + tags_filename = os.path.join(dir, tags_filename) + self.load_file(tags_filename) + return {}, self.tags[basename] class WaifuDiffusionInterrogator(HFInterrogator): diff --git a/tagger/uiset.py b/tagger/uiset.py index a7147fa..ae21b09 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -37,6 +37,12 @@ ] +def format_output_filename(path: Path, format='[name].[output_extension]') -> str: + info = tags_format.Info(path, 'txt') + fmt = partial(lambda info, m: tags_format.parse(m, info), info) + return tags_format.pattern.sub(fmt, format) + + class IOData: """ data class for input and output paths """ last_path_mtimes = None @@ -119,7 +125,7 @@ def update_input_glob(cls, input_glob: str) -> None: path_mtimes = [] for filename in glob(input_glob, recursive=recursive): if not os.path.isdir(filename): - ext = os.path.splitext(filename)[1].lower() + ext = os.path.splitext(filename)[-1].lower() if ext in supported_extensions: path_mtimes.append(os.path.getmtime(filename)) paths.append(filename) @@ -163,16 +169,11 @@ def set_batch_io(cls, paths: List[str]) -> None: base_dir_last_idx = path.parts.index(cls.base_dir_last) # format output filename - info = tags_format.Info(path, 'txt') - fmt = partial(lambda info, m: tags_format.parse(m, info), info) - msg = 'Invalid output format' cls.err.discard(msg) try: - formatted_output_filename = tags_format.pattern.sub( - fmt, - Its.output_filename_format - ) + formatted_output_filename = format_output_filename( + path, format=Its.output_filename_format) except (TypeError, ValueError): cls.err.add(msg) @@ -483,7 +484,6 @@ def correct_tag(cls, tag: str) -> str: if re_match(regex, tag): tag = re_sub(regex, cls.replace_tags[i], tag) break - return tag @classmethod From ac32226372bb2d07c51806aaca5edf07852c669b Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 20 Sep 2023 19:01:56 +0200 Subject: [PATCH 77/78] fix --- tagger/interrogator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 9da3f28..a8aca86 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -492,7 +492,7 @@ def __init__( self, name: str, path, format='[name].[output_extension]', value=1.0 ) -> None: super().__init__(name) - self.path = Path(self.path) + self.path = Path(path) self.val = value self.format = format if format == Its.output_filename_format and path == '': From 6262435d2bcd61289eaaf5c393cdbe35e244f478 Mon Sep 17 00:00:00 2001 From: Roel Kluin Date: Wed, 20 Sep 2023 22:17:10 +0200 Subject: [PATCH 78/78] The intent was to allow editing the interrogator properties in settings using json schema and json entries. But I'd like to --- tagger/interrogator.py | 8 ++------ tagger/settings.py | 19 ++++++++++++++++++- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index a8aca86..4b8db25 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -119,12 +119,8 @@ def refresh(cls) -> List[str]: if not it_path.exists(): raise FileNotFoundError(f'{it_path} not found.') - raw = json.loads(it_path.read_text()) - schema = root_dir.joinpath('json_schema', - 'interrogators_v1_schema.json') - validate(raw, json.loads(schema.read_text())) - - for class_name, it in raw.items(): + entries = settings.load_interrogator_and_schema()[0] + for class_name, it in entries.items(): if class_name[-12:] == "Interrogator": It_type = getattr(sys.modules[__name__], class_name) for name, obj in it.items(): diff --git a/tagger/settings.py b/tagger/settings.py index 979e880..d737b15 100644 --- a/tagger/settings.py +++ b/tagger/settings.py @@ -1,9 +1,12 @@ """Settings tab entries for the tagger module""" import os -from typing import List +from typing import List, Tuple from modules import shared # pylint: disable=import-error import gradio as gr from huggingface_hub import hf_hub_download +from preload import root_dir +from jsonschema import validate +import json # kaomoji from WD 1.4 tagger csv. thanks, Meow-San#5400! DEFAULT_KAMOJIS = '0_0, (o)_(o), +_+, +_-, ._., _, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||' # pylint: disable=line-too-long # noqa: E501 @@ -15,6 +18,20 @@ str(os.path.join(shared.models_path, 'interrogators'))) +def load_interrogator_and_schema() -> Tuple[List[str], List[str]]: + it_path = root_dir.joinpath("interrogators.json") + if not it_path.exists(): + it_path = root_dir.joinpath("default/interrogators.json") + if not it_path.exists(): + raise FileNotFoundError(f'{it_path} not found.') + + entries = json.loads(it_path.read_text()) + schema = json.loads(root_dir.joinpath('json_schema', + 'interrogators_v1_schema.json').read_text()) + validate(entries, schema) + return entries, schema + + def slider_wrapper(value, elem_id, **kwargs): # required or else gradio will throw errors return gr.Slider(**kwargs)