diff --git a/.gitignore b/.gitignore index fd6106c..8ea680e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,5 @@ __pycache__/ .vscode/ .venv/ .env - -presets/ \ No newline at end of file +presets/ +addons/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c3f6b2..c4c5000 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,73 @@ -# v1.1.1 + +Api changes: +Image interrogation via api receives two extra parameters; empty strings by +default. `queue`: the name for a queue, which could be e.g. the person or +subject name. You can leave it empty for the first interrogation, then the +response will queue in a new auto-generated unique name, listed in the response. +# v1.2.0 (2023-09-16) + +Make sure you use this same name as queue, for all interrogations that you want +to be grouped together. The second parameter is `name_in_queue`: the name for +that particular image that is being queued, e.g. a file name. + +If both queue and name are empty, there is a single interrogation with response, +which includes nested objects "ratings" and "tags", so: +`{"ratings": {"sensitive": 0.5, ..}, "tags": {"tag1": 0.5, ..}}` + +If neither name nor queue are empty, the interrogation is queued under that name. +If already in queue, that name is changed - clobbered - with # and a number. An exception is if +the given name is `<sha256>`, in which case an image checksum will be used instead of +a name. Duplicates are ignored. + +During queuing, the response is the number of all processed interrogations for all +active queues. + +If name_in_queue is empty, but queue is not, that particular queue is finalized. +A response is awaited for remaining interrogations in this queue (if any still). +The response, only for this queue, is an object with the name_in_queue as key, +and the tag with weights contained. Ratings have their tag name prefixed with +"rating:". 
Example: +`{"name_in_queue": {"rating:sensitive": 0.5, "tag1": 0.5, ..}}` + +Fix in absence of tensorflow_io +Fix deprecation warning +Added three scripts under shell_scripts: + * A bash script to generate per safetensors file the fraction of images + that the model was trained on that was tagged with particular tokens. + * A python script to compare the interrogation results (read from db.json) + and find the top -c safetensors files that contain similar weights (or at + least, that was the intention, there may be better algorithms to compare, + but it seems to do the job). + * And finally a model_grep script which lists the tags and number of trained + images in a safetensors model. + +# v1.1.2 c9f8efd (2023-08-26) + +Explain recursive path usage better in ui +Fix sending tags via buttons to txt2img and img2img +type additions, inadvertently pushed, later retouched. +allow setting gpu device via flag +Fix inverted cumulative checkbox +wrap_gradio_gpu_call fallback +Fix for preload shared access +preload update +A few ui changes +Fix not clearing the tags after writing them to files +Fix: Tags were still added, beyond count threshold +fix search/replace bug +(here int based weights were reverted) +circumvent when unable to load tensorflow +fix for too many exclude_tags +add db.json validation schema, add schema validation +return fix for fastapi +pick up huggingface cache dir from env, with default, configurable also via settings. +leave tensorflow requirements to the user. 
+Fix for Reappearance of gradio bug: duplicate image edit +(index based weights, but later reverted) +Instead of cache_dir use local_dir, leav + + +# v1.1.1 eada050 (2023-07-20) Internal cleanup, no separate interrogation for inverse Fix issues with search and sending selection to keep/exclude @@ -11,7 +80,9 @@ fix some hf download issues fixes for fastapi added ML-Danbooru support, thanks to [CCRcmcpe](github.com/CCRcmcpe) -# v1.1.0 + +# v1.1.0 87706b7 (2023-07-16) + fix: failed to install onnxruntime package on MacOS thanks to heady713 fastapi: remote unload model, picked up from [here](https://github.com/toriato/stable-diffusion-webui-wd14-tagger/pull/109) attribute error fix from aria1th also reported by yjunej @@ -39,7 +110,7 @@ changed internal error handling, It is a bit quirky, which I intend to fix, stil If you find it keeps complaining about an input field without reason, just try editing that one again (e.g. add a space there and remove it). -# v1.0.0 +# v1.0.0 a1b59d6 (2023-07-10) You may have to remove the presets/default.json and save a new one.witth your desired defaults. Otherwise checkboxes may not have the right default values. 
diff --git a/default/interrogators.json b/default/interrogators.json new file mode 100644 index 0000000..3b10889 --- /dev/null +++ b/default/interrogators.json @@ -0,0 +1,71 @@ +{ + "MLDanbooruInterrogator": { + "mld-caformer.dec-5-97527" : { + "model_path" : "ml_caformer_m36_dec-5-97527.onnx", + "name" : "ML-Danbooru Caformer dec-5-97527", + "repo_id" : "deepghs/ml-danbooru-onnx" + }, + "mld-tresnetd.6-30000" : { + "model_path" : "TResnet-D-FLq_ema_6-30000.onnx", + "name" : "ML-Danbooru TResNet-D 6-30000", + "repo_id" : "deepghs/ml-danbooru-onnx" + } + }, + "WaifuDiffusionInterrogator": { + "wd-v1-4-moat-tagger.v2" : { + "name" : "WD14 moat tagger v2", + "repo_id" : "SmilingWolf/wd-v1-4-moat-tagger-v2" + }, + "wd14-convnext.v1" : { + "name" : "WD14 ConvNeXT v1", + "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger" + }, + "wd14-convnext.v2" : { + "name" : "WD14 ConvNeXT v2", + "repo_id" : "SmilingWolf/wd-v1-4-convnext-tagger-v2" + }, + "wd14-convnextv2.v1" : { + "name" : "WD14 ConvNeXTV2 v1", + "repo_id" : "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" + }, + "wd14-swinv2-v1" : { + "name" : "WD14 SwinV2 v1", + "repo_id" : "SmilingWolf/wd-v1-4-swinv2-tagger-v2" + }, + "wd14-vit.v1" : { + "name" : "WD14 ViT v1", + "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger" + }, + "wd14-vit.v2" : { + "name" : "WD14 ViT v2", + "repo_id" : "SmilingWolf/wd-v1-4-vit-tagger-v2" + } + }, + "DeepDanbooruInterrogator": { + "deepdanbooru-v3-20211112-sgd-e28": { + "zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip" + }, + "deepdanbooru-v4-20200814-sgd-e30": { + "zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip" + } + }, + "FromFileInterrogator": { + "[name].[hash:sha1].[output_extension]": { + "format": "[name].[hash:sha1].[output_extension]", + "path": "", + "value": 1.0 + } + }, + "MLDanbooruInterrogator": { + "mld-caformer.dec-5-97527" : { + 
"model_path" : "ml_caformer_m36_dec-5-97527.onnx", + "name" : "ML-Danbooru Caformer dec-5-97527", + "repo_id" : "deepghs/ml-danbooru-onnx" + }, + "mld-tresnetd.6-30000" : { + "model_path" : "TResnet-D-FLq_ema_6-30000.onnx", + "name" : "ML-Danbooru TResNet-D 6-30000", + "repo_id" : "deepghs/ml-danbooru-onnx" + } + } +} diff --git a/install.py b/install.py index 37408bf..411df64 100644 --- a/install.py +++ b/install.py @@ -1,12 +1,13 @@ """Install requirements for WD14-tagger.""" import os import sys - +import json from launch import run # pylint: disable=import-error +local_dir = os.path.dirname(os.path.realpath(__file__)) + NAME = "WD14-tagger" -req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), - "requirements.txt") +req_file = os.path.join(local_dir, "requirements.txt") print(f"loading {NAME} reqs from {req_file}") run(f'"{sys.executable}" -m pip install -q -r "{req_file}"', f"Checking {NAME} requirements.", diff --git a/json_schema/db_json_v1_schema.json b/json_schema/db_json_v1_schema.json index 31e2696..c36aee4 100644 --- a/json_schema/db_json_v1_schema.json +++ b/json_schema/db_json_v1_schema.json @@ -10,28 +10,28 @@ "type": "array", "prefixItems": [ {"type": "string" }, - {"type": "number", "minimum": 0} - ], - "minContains": 2, - "maxContains": 2 - } - } - }, - "meta": { - "type": "object", - "properties": { - "index_shift": { - "type": "integer", - "minimum": 0, - "maximum": 16 + {"type": "number", "minimum": 0} + ], + "minContains": 2, + "maxContains": 2 } + } + }, + "meta": { + "type": "object", + "properties": { + "index_shift": { + "type": "integer", + "minimum": 0, + "maximum": 16 } - }, - "add": { "type": "string" }, - "exclude": { "type": "string" }, - "keep": { "type": "string" }, - "repl": { "type": "string" }, - "search": { "type": "string" } + } + }, + "add": { "type": "string" }, + "exclude": { "type": "string" }, + "keep": { "type": "string" }, + "repl": { "type": "string" }, + "search": { "type": "string" } }, "required": 
["rating", "tag", "query"], "additionalProperties": false, diff --git a/json_schema/interrogators_v1_schema.json b/json_schema/interrogators_v1_schema.json new file mode 100644 index 0000000..35fde69 --- /dev/null +++ b/json_schema/interrogators_v1_schema.json @@ -0,0 +1,104 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Interrogator Configurations", + "type": "object", + "properties": { + "MLDanbooruInterrogator": { + "type": "object", + "patternProperties": { + ".*": { "$ref": "#/$defs/huggingfaceInterrogator" } + } + }, + "WaifuDiffusionInterrogator": { + "type": "object", + "patternProperties": { + ".*": { "$ref": "#/$defs/huggingfaceInterrogator" } + } + }, + "DeepDanbooruInterrogator": { + "type": "object", + "patternProperties": { + ".*": { + "type": "object", + "properties": { + "zip": { "type": "string" } + }, + "additionalProperties": false + } + } + }, + "MLDanbooruInterrogator": { + "type": "object", + "patternProperties": { + ".*": { "$ref": "#/$defs/huggingfaceInterrogator" } + } + }, + "FromFileInterrogator": { + "type": "object", + "patternProperties": { + ".*": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "format": { "type": "string", "default": "[name].[output_extension]" }, + "value": { "type": "number", "default": 1.0 } + }, + "required": [ + "path" + ], + "additionalProperties": false + } + } + }, + "additionalProperties": false + }, + "$defs": { + "huggingfaceInterrogator": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "local_model": { "type": "string" }, + "local_tags": { "type": "string" }, + "model_path": { "type": "string" }, + "repo_id": { "type": "string" }, + "filename": { "type": "string" }, + "subfolder": { "type": "string" }, + "repo_type": { "type": "string" }, + "revision": { "type": "string" }, + "endpoint": { "type": "string" }, + "library_name": { "type": "string" }, + "library_version": { "type": "string" }, + "cache_dir": { "type": "string" }, 
+ "local_dir": { "type": "string" }, + "local_dir_use_symlinks": { "type": "string" }, + "user_agent": { "type": "string" }, + "force_download": { "type": "boolean" }, + "force_filename": { "type": "string" }, + "proxies": { "type": "string" }, + "etag_timeout": { "type": "number" }, + "resume_download": { "type": "boolean" }, + "token": { "type": "string" }, + "local_files_only": { "type": "boolean" }, + "legacy_cache_layout": { "type": "boolean" } + }, + "$comment": "repo_id is required if you want to actually download the model", + "oneOf": [ + { + "required": [ + "name", + "repo_id" + ] + }, + { + "required": [ + "name", + "local_model", + "local_tags" + ] + } + ], + "additionalProperties": false + } + } +} + diff --git a/preload.py b/preload.py index b16969a..6be1b60 100644 --- a/preload.py +++ b/preload.py @@ -1,6 +1,8 @@ """ Preload module for DeepDanbooru or onnxtagger. """ -from pathlib import Path from argparse import ArgumentParser +from pathlib import Path + +root_dir = Path(__file__).parent def preload(parser: ArgumentParser): @@ -19,8 +21,9 @@ def preload(parser: ArgumentParser): type=str, help='Path to directory with Onnyx project(s).' ) + # TODO allow using devices in parallel, specified as comma separed list parser.add_argument( '--additional-device-ids', type=str, - help='Extra device ID to use. cpu:0,gpu:1..', + help='Device ID to use. 
cpu:0, gpu:0 or gpu:1, etc.', ) diff --git a/shell_scripts/compare_weighted_frequencies.py b/shell_scripts/compare_weighted_frequencies.py new file mode 100644 index 0000000..a7c8496 --- /dev/null +++ b/shell_scripts/compare_weighted_frequencies.py @@ -0,0 +1,132 @@ +from typing import Dict +from math import ceil +from json import load +import argparse +import re +from collections import defaultdict + +# read two json files, compare the weighted frequencies of the tags in the two +# files. The first file is json and contains all safetensor files, with major +# sections and weighted tags + +# the second file is the result of an interrogation of images, with weighted +# tags. tags may be substrings of the tags in the first file + +# next argument is an interrogator id, a string +# optionally there are space separated images as arguments + +# in the end print the top ten safetensors and major sections that contain +# the tags that are most similar to the tags in the second file + +# all weights are between 0 and 1, higher is more important + + +# example usage: +# first run shell_scripts/create_safetensors_db.sh +# then interrogate an image in a subdirectory test/ + + +# cd stable-diffusion-webui/extensions/stable-diffusion-webui-wd14-tagger/ +# +# python shell_scripts/compare_weighted_frequencies.py safetensors_db.json \ +# test/db.json + +# # .. 
lists used interrogation models + + +# python shell_scripts/compare_weighted_frequencies.py safetensors_db.json \ +# -c 20 test/db.json + + +desc = 'Compare weighted frequencies of tags in two json file' +parser = argparse.ArgumentParser(description=desc) +hlp = 'number of results to print' +parser.add_argument('-c', '--count', default=10, type=int, help=hlp) +parser.add_argument('file1', help='all safetensors json file') +parser.add_argument('file2', help='image interrogation json file') +parser.add_argument('id', help='interrogator id', nargs='?', default="") +parser.add_argument('images', nargs='*', help='images', default=[]) +args = parser.parse_args() + + +with open(args.file1) as f: + all_sftns = load(f) + +with open(args.file2) as f: + data = load(f) + +query = data["query"] + +indices = set() +if args.id == "": + uniq = set() + for k in data["query"]: + if k not in uniq: + uniq.add(k[64:]) + if len(uniq) != 1: + print("Missing interrogator id, contained are:") + for k in uniq: + print(k) + exit(1) + else: + # use the only one + args.id = uniq.pop() + +for k, t in data["query"].items(): + img_fn, idx = t + if k[64:] == args.id: + if len(args.images) > 0: + for i in args.images: + if img_fn[-len(i):] == i: + break + else: + continue + indices.add(int(idx)) + +interrogation_result = {} +for t, lst in data["tag"].items(): + wt = 0.0 + for stored in lst: + i = ceil(stored) - 1 + if i in indices: + wt += stored - i + if wt > 0.0: + interrogation_result[t] = wt / len(indices) + +scores: Dict[str, float] = defaultdict(float) + +for safetensor in all_sftns: + for major in all_sftns[safetensor]: + ct = len(all_sftns[safetensor][major]) + if ct == 0: + continue + + for tag, wt in interrogation_result.items(): + if tag in all_sftns[safetensor][major]: + sftns_wt = all_sftns[safetensor][major][tag] + n = (1.0 - abs(sftns_wt - wt)) + scores[safetensor + "\t" + major] += n / ct + else: + rex = re.compile(r'\b{}\b'.format(tag)) + t_len = len(tag) + # the tag may be a 
substring of a tag in the safetensor + # however only entire words are considered and a penalty if the + # string lenghts are close to each other + highest = 0.0 + for sftns_tag in all_sftns[safetensor][major]: + if rex.search(sftns_tag): + sftns_tag_len = len(sftns_tag) + sftns_wt = all_sftns[safetensor][major][sftns_tag] + n = (sftns_tag_len - t_len) / sftns_tag_len + n -= abs(sftns_wt - wt) + highest = max(highest, n) + scores[safetensor + "\t" + major] += highest / ct + +# sort the scores +sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) + +# print the top ten safetensors and major sections +for i in range(args.count): + if i >= len(sorted_scores): + break + print(sorted_scores[i][0] + "\t" + str(sorted_scores[i][1])) diff --git a/shell_scripts/create_safetensors_db.sh b/shell_scripts/create_safetensors_db.sh new file mode 100644 index 0000000..f2f0381 --- /dev/null +++ b/shell_scripts/create_safetensors_db.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# +# Create a database for safetensors wherein for the +# models each major tag, the occurrence frequency of +# each associated subtag is listed. +# +# requires https://github.com/by321/safetensors_util.git +# gnu parallel, jq, sed, awk +# + +# To build the safetensors_db.json database with +# "file.safetensors" { "major tag": { "tag1": , "tag2": .. } }: + + +# cd stable-diffusion-webui/extensions/stable-diffusion-webui-wd14-tagger/ +# git clone https://github.com/by321/safetensors_util.git +# +# bash shell_scripts/create_safetensors_db.sh -f -p ../../models/Lora -u safetensors_util/ -o safetensors_db.json +# +## now you can compare interrogation weights with the safetensors_db.json using +## shell_scripts/compare_weighted_frequencies.py, see there for usage. + + +# number of cpus to use by default or use -j to specify +ncpu=$(nproc --all) +[ $ncpu -gt 8 ] && ncpu=8 + +path=. +utilpath=. 
+force=0 +out=safetensors_db.json + +while [ $# -gt 0 ]; do + case "$1" in + -h|--help) + echo "Usage: $0 [-j ncpu] [-p path] [-u utilpath] [-f] [-o out]" + echo " -j ncpu number of cpus to use (default: $ncpu)" + echo " -p path path to directory containing safetensor models (default: ./)" + echo " -u utilpath path to safetensors_util.py" + echo " -f force overwrite of output file" + echo " -o out output file (default: safetensors_db.json)" + exit 0 + ;; + -j) ncpu="$2"; shift 2;; + -p) path="$2/"; shift 2;; + -u) utilpath="$2/"; shift 2;; + -f) force=1; shift 1;; + -o) out="$2"; shift 2;; + esac +done + +if [ ! -d "${path}" ]; then + echo "Error: '${path}' does not exist (use -p to specify path)" + exit 1 +fi + +if [ ! -e "${utilpath}/safetensors_util.py" ]; then + echo "Error: ${utilpath}/safetensors_util.py does not exist (use -u to specify path)" + exit 1 +fi + +if [ -e "${out}" -a $force -eq 0 ]; then + echo "Error: ${out} already exists (use -f to overwrite)" + exit 1 +fi + +ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | +sed -n '1b;p' | jq -r 'select(.__metadata__ != null) | .__metadata__ | .ss_tag_frequency | select( . != null )' 2>/dev/null | sed 's/\" /\"/' | +awk -v FS=': ' '{ + if (index(\$2, \"null\") > 0) next + o = index(\$0, \"{\") + if (o == 1) printf \"\\\"'{}'\\\": \" + if (o > 0) { + print \$0 + m = 0 + } else { + c = index(\$0, \"}\") + if (c > 0) { + L=\"\" + for (i in a) { + if (L != \"\") print \",\" + printf \"%s: %.6f\", i, a[i] / m + L = "x" + } + delete a + if (c == 1) print \$0\",\" + else print \"\n\"\$0 + } else { + x = index(\$2, \",\") + v = int(x != 0 ? 
substr(\$2, 1, x - 1) : \$2) + if (v > m) m = v + a[\$1] = v + } + } +}'" | sed -r ' +s/^/ /; +1s/^/{\n/; +s/\\"//g +s/^([ \t]+"[^"]+):*(: [01]+(\.[0-9]+)?,?)$/\1"\2/ +$s/,?$/\n}/ +' > "${out}" + + diff --git a/shell_scripts/model_grep.sh b/shell_scripts/model_grep.sh new file mode 100644 index 0000000..616b22d --- /dev/null +++ b/shell_scripts/model_grep.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# requires jq, grep +# +# Usage: ./model_grep.sh [-p path] [-u utilpath] + +# number of cpus to use by default or use -j to specify +ncpu=$(nproc --all) +[ $ncpu -gt 8 ] && ncpu=8 + +path=. +utilpath=. + +while [ $# -gt 0 ]; do + case "$1" in + -h|--help) + echo "Usage: $0 [ -j ncpu] [-p path] [-u utilpath] " + echo " -j ncpu number of cpus to use (default: $ncpu)" + echo " -p path path to directory containing safetensor models (default: ./)" + echo " -u utilpath path to safetensors_util.py" + echo " extended regex to match against model names" + exit 0 + ;; + -j) ncpu="$2"; shift 2;; + -p) path="$2"; shift 2;; + -u) utilpath="$2"; shift 2;; + esac +done + +if [ ! -d "${path}" ]; then + echo "Error: ${path} does not exist (use -p to specify path)" + exit 1 +fi + +if [ ! 
-e "${utilpath}/safetensors_util.py" ]; then + echo "Error: ${utilpath}/safetensors_util.py does not exist (use -u to specify path)" + exit 1 +fi + +if [ -n "$1" ]; then + ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | + sed -n '1b;p' | jq '.__metadata__.ss_tag_frequency' 2>/dev/null | grep -o -E '\"[^\"]*${1}[^\"]*\": [0-9]+'| sed 's~^~'{}':~p'" +else + tmp=$(mktemp) + sed 's/^/"[^\"]*/;s/$/[^\"]*": [0-9]+/' < /dev/stdin > $tmp + ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null | + sed -n '1b;p' | jq '.__metadata__.ss_tag_frequency' 2>/dev/null | grep -oE -f $tmp | sed 's~^~'{}':~p'" + echo rm $tmp +fi + + diff --git a/tag_based_image_dedup.sh b/shell_scripts/tag_based_image_dedup.sh similarity index 100% rename from tag_based_image_dedup.sh rename to shell_scripts/tag_based_image_dedup.sh diff --git a/tagger/api.py b/tagger/api.py index c50de70..74ee0d2 100644 --- a/tagger/api.py +++ b/tagger/api.py @@ -1,7 +1,12 @@ """API module for FastAPI""" -from typing import Callable +from typing import Callable, Dict, Optional from threading import Lock from secrets import compare_digest +import asyncio +from collections import defaultdict +from hashlib import sha256 +import string +from random import choices from modules import shared # pylint: disable=import-error from modules.api.api import decode_base64_to_image # pylint: disable=E0401 @@ -9,15 +14,15 @@ from fastapi import FastAPI, Depends, HTTPException from fastapi.security import HTTPBasic, HTTPBasicCredentials -from tagger import utils # pylint: disable=import-error from tagger import api_models as models # pylint: disable=import-error from tagger.uiset import QData # pylint: disable=import-error +from tagger.interrogator import Interrogator class Api: """Api class for FastAPI""" def __init__( - self, app: FastAPI, qlock: Lock, prefix: str = None + self, app: FastAPI, 
qlock: Lock, prefix: Optional[str] = None ) -> None: if shared.cmd_opts.api_auth: self.credentials = {} @@ -26,8 +31,16 @@ def __init__( self.credentials[user] = password self.app = app + self.queue: Dict[str, asyncio.Queue] = {} + self.res: Dict[str, Dict[str, Dict[str, float]]] = \ + defaultdict(dict) self.queue_lock = qlock + self.tasks: Dict[str, asyncio.Task] = {} + + self.runner: Optional[asyncio.Task] = None self.prefix = prefix + self.running_batches: Dict[str, Dict[str, float]] = \ + defaultdict(lambda: defaultdict(int)) self.add_api_route( 'interrogate', @@ -40,17 +53,85 @@ def __init__( 'interrogators', self.endpoint_interrogators, methods=['GET'], - response_model=models.InterrogatorsResponse + response_model=models.TaggerInterrogatorsResponse ) self.add_api_route( - "unload-interrogators", + 'unload-interrogators', self.endpoint_unload_interrogators, - methods=["POST"], + methods=['POST'], response_model=str, ) - def auth(self, creds: HTTPBasicCredentials = None): + async def add_to_queue(self, m, q, n='', i=None, t=0.0) -> Dict[ + str, Dict[str, float] + ]: + if m not in self.queue: + self.queue[m] = asyncio.Queue() + # loop = asyncio.get_running_loop() + # asyncio.run_coroutine_threadsafe( + task = asyncio.create_task(self.queue[m].put((q, n, i, t))) + # , loop) + + if self.runner is None: + loop = asyncio.get_running_loop() + asyncio.ensure_future(self.batch_process(), loop=loop) + await task + return await self.tasks[q+"\t"+n] + + async def do_queued_interrogation(self, m, q, n, i, t) -> Dict[ + str, Dict[str, float] + ]: + self.running_batches[m][q] += 1.0 + # queue and name empty to process, not queue + res = self.endpoint_interrogate( + models.TaggerInterrogateRequest( + image=i, + model=m, + threshold=t, + name_in_queue='', + queue='' + ) + ) + self.res[q][n] = res.caption["tag"] + for k, v in res.caption["rating"].items(): + self.res[q][n]["rating:"+k] = v + return self.running_batches + + async def finish_queue(self, m, q) -> Dict[str, 
Dict[str, float]]: + if q in self.running_batches[m]: + del self.running_batches[m][q] + if q in self.res: + return self.res.pop(q) + return self.running_batches + + async def batch_process(self) -> None: + # loop = asyncio.get_running_loop() + while len(self.queue) > 0: + for m in self.queue: + # if zero the queue might just be pending + while True: + try: + # q, n, i, t = asyncio.run_coroutine_threadsafe( + # self.queue[m].get_nowait(), loop).result() + q, n, i, t = self.queue[m].get_nowait() + except asyncio.QueueEmpty: + break + self.tasks[q+"\t"+n] = asyncio.create_task( + self.do_queued_interrogation(m, q, n, i, t) if n != "" + else self.finish_queue(m, q) + ) + + for model in self.running_batches: + if len(self.running_batches[model]) == 0: + del self.queue[model] + else: + await asyncio.sleep(0.1) + + self.running_batches.clear() + self.runner = None + + def auth(self, creds: Optional[HTTPBasicCredentials] = None): if creds is None: creds = Depends(HTTPBasic()) if creds.username in self.credentials: @@ -74,41 +155,72 @@ def add_api_route(self, path: str, endpoint: Callable, **kwargs): Depends(self.auth)], **kwargs) return self.app.add_api_route(path, endpoint, **kwargs) + async def queue_interrogation(self, m, q, n='', i=None, t=0.0) -> Dict[ + str, Dict[str, float] + ]: + """ queue an interrogation, or add to batch """ + if n == '': + task = asyncio.create_task(self.add_to_queue(m, q)) + else: + if n == '': + n = sha256(i).hexdigest() + if n in self.res[q]: + return self.running_batches + elif n in self.res[q]: + # clobber name if it's already in the queue + j = 0 + while f'{n}#{j}' in self.res[q]: + j += 1 + n = f'{n}#{j}' + self.res[q][n] = {} + # add image to queue + task = asyncio.create_task(self.add_to_queue(m, q, n, i, t)) + return await task + def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): + """ one file interrogation, queueing, or batch results """ if req.image is None: raise HTTPException(404, 'Image not found') - if 
req.model not in utils.interrogators.keys(): + if req.model not in Interrogator.entries.keys(): raise HTTPException(404, 'Model not found') - image = decode_base64_to_image(req.image) - interrogator = utils.interrogators[req.model] - - with self.queue_lock: - QData.tags.clear() - QData.ratings.clear() - QData.in_db.clear() - QData.for_tags_file.clear() - data = ('', '', '') + interrogator.interrogate(image) - QData.apply_filters(data) - output = QData.finalize(1) - - return models.TaggerInterrogateResponse( - caption={ - **output[0], - **output[1], - **output[2], - }) + m, q, n = (req.model, req.queue, req.name_in_queue) + res: Dict[str, Dict[str, float]] = {} + + if q != '' or n != '': + if q == '': + # generate a random queue name, not in use + while True: + q = ''.join(choices(string.ascii_uppercase + + string.digits, k=8)) + if q not in self.queue: + break + print(f'WD14 tagger api generated queue name: {q}') + res = asyncio.run(self.queue_interrogation(m, q, n, req.image, + req.threshold), debug=True) + else: + image = decode_base64_to_image(req.image) + interrogator = Interrogator.entries[m] + res = {"tag": {}, "rating": {}} + with self.queue_lock: + res["rating"], tag = interrogator.interrogate(image) + + for k, v in tag.items(): + if v > req.threshold: + res["tag"][k] = v + + return models.TaggerInterrogateResponse(caption=res) def endpoint_interrogators(self): - return models.InterrogatorsResponse( - models=list(utils.interrogators.keys()) + return models.TaggerInterrogatorsResponse( + models=list(Interrogator.entries.keys()) ) def endpoint_unload_interrogators(self): unloaded_models = 0 - for i in utils.interrogators.values(): + for i in Interrogator.entries.values(): if i.unload(): unloaded_models = unloaded_models + 1 diff --git a/tagger/api_models.py b/tagger/api_models.py index 7f2cd20..1cd2dde 100644 --- a/tagger/api_models.py +++ b/tagger/api_models.py @@ -9,27 +9,35 @@ class TaggerInterrogateRequest(sd_models.InterrogateRequest): """Interrogate 
request model""" model: str = Field( title='Model', - description='The interrogate model used.' + description='The interrogate model used.', ) - threshold: float = Field( - default=0.35, title='Threshold', - description='', - ge=0, - le=1 + description='The threshold used for the interrogate model.', + default=0.0, + ) + queue: str = Field( + title='Queue', + description='name of queue; leave empty for single response', + default='', + ) + name_in_queue: str = Field( + title='Name', + description='name to queue image as or use . leave empty to ' + 'retrieve the final response', + default='', ) class TaggerInterrogateResponse(BaseModel): """Interrogate response model""" - caption: Dict[str, float] = Field( + caption: Dict[str, Dict[str, float]] = Field( title='Caption', - description='The generated caption for the image.' + description='The generated captions for the image.' ) -class InterrogatorsResponse(BaseModel): +class TaggerInterrogatorsResponse(BaseModel): """Interrogators response model""" models: List[str] = Field( title='Models', diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 9e56a2d..4b8db25 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -4,8 +4,12 @@ import io import json import inspect +import sys +from shutil import unpack_archive +from requests import get as http_get from re import match as re_match -from platform import system, uname +from jsonschema import validate +from platform import uname from typing import Tuple, List, Dict, Callable from pandas import read_csv from PIL import Image, UnidentifiedImageError @@ -13,10 +17,11 @@ from tqdm import tqdm from huggingface_hub import hf_hub_download -from modules.paths import extensions_dir from modules import shared from tagger import settings # pylint: disable=import-error -from tagger.uiset import QData, IOData # pylint: disable=import-error +from tagger.uiset import QData, IOData, supported_extensions, \ + format_output_filename +from preload import root_dir 
from . import dbimutils # pylint: disable=import-error # noqa Its = settings.InterrogatorSettings @@ -29,13 +34,13 @@ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958 onnxrt_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] -if shared.cmd_opts.device_id is not None: - m = re_match(r'([cg])pu:\d+$', shared.cmd_opts.device_id) +if shared.cmd_opts.additional_device_ids is not None: + m = re_match(r'([cg])pu:\d+$', shared.cmd_opts.additional_device_ids) if m is None: raise ValueError('--device-id is not cpu: or gpu:') if m.group(1) == 'c': onnxrt_providers.pop(0) - TF_DEVICE_NAME = f'/{shared.cmd_opts.device_id}' + TF_DEVICE_NAME = f'/{shared.cmd_opts.additional_device_ids}' elif use_cpu: TF_DEVICE_NAME = '/cpu:0' onnxrt_providers.pop(0) @@ -44,6 +49,16 @@ print(f'== WD14 tagger {TF_DEVICE_NAME}, {uname()} ==') +ddp_path = shared.cmd_opts.deepdanbooru_projects_path +if ddp_path is None: + ddp_path = Path(shared.models_path, 'deepdanbooru') +os.makedirs(ddp_path, exist_ok=True) + +onnx_path = shared.cmd_opts.onnxtagger_path +if onnx_path is None: + onnx_path = Path(shared.models_path, 'TaggerOnnx') +os.makedirs(onnx_path, exist_ok=True) + class Interrogator: """ Interrogator class for tagger """ @@ -61,6 +76,7 @@ class Interrogator: } output = None odd_increment = 0 + entries = {} @classmethod def flip(cls, key): @@ -75,7 +91,7 @@ def get_errors() -> str: # write errors in html pointer list, every error in a
  • tag errors = IOData.error_msg() if len(QData.err) > 0: - errors += 'Fix to write correct output:
    • ' + \ + errors += 'Possible issues:
      • ' + \ '
      • '.join(QData.err) + '
      ' return errors @@ -93,6 +109,28 @@ def setter(val) -> Tuple[str, str]: return setter + @classmethod + def refresh(cls) -> List[str]: + """Refreshes the interrogator entries""" + if len(cls.entries) == 0: + it_path = root_dir.joinpath("interrogators.json") + if not it_path.exists(): + it_path = root_dir.joinpath("default/interrogators.json") + if not it_path.exists(): + raise FileNotFoundError(f'{it_path} not found.') + + entries = settings.load_interrogator_and_schema()[0] + for class_name, it in entries.items(): + if class_name[-12:] == "Interrogator": + It_type = getattr(sys.modules[__name__], class_name) + for name, obj in it.items(): + print(f"Loading {name} as {class_name}") + if "name" not in obj: + obj["name"] = name + cls.entries[name] = It_type(**obj) + + return sorted(i.name for i in cls.entries.values()) + @staticmethod def load_image(path: str) -> Image: try: @@ -112,8 +150,12 @@ def __init__(self, name: str) -> None: self.tags = None # run_mode 0 is dry run, 1 means run (alternating), 2 means disabled self.run_mode = 0 if hasattr(self, "large_batch_interrogate") else 2 + # default path if not overridden by download + self.local_model = '' + self.local_tags = '' + # XXX don't Interrogator.refresh()-ception here - def load(self): + def load(self) -> bool: raise NotImplementedError() def large_batch_interrogate(self, images: List, dry_run=False) -> str: @@ -238,10 +280,31 @@ def interrogate( raise NotImplementedError() +def download_and_extract(url: str, path: Path, rm_zip=True) -> None: + """ Download and extract a zip file """ + response = http_get(url, stream=True) + local_zip = path / url.split('/')[-1] + with open(local_zip, 'wb') as f: + for data in tqdm(response.iter_content(), desc=f'Downloading {url}'): + f.write(data) + unpack_archive(local_zip, path) + if rm_zip: + os.remove(local_zip) + + class DeepDanbooruInterrogator(Interrogator): """ Interrogator for DeepDanbooru models """ - def __init__(self, name: str, project_path: os.PathLike) 
-> None: + def __init__(self, name: str, project_path=None, zip='') -> None: super().__init__(name) + + if project_path is None: + project_path = Path(ddp_path, name) + if not project_path.is_dir(): + if zip == '': + raise FileNotFoundError(f'{project_path} does not exist') + os.makedirs(project_path) + download_and_extract(zip, project_path) + self.project_path = project_path self.model = None self.tags = None @@ -299,7 +362,8 @@ def interrogate( ]: # init model if self.model is None: - self.load() + if not self.load(): + return {}, {} import deepdanbooru.data as ddd @@ -333,78 +397,166 @@ def large_batch_interrogate(self, images: List, dry_run=False) -> str: raise NotImplementedError() -# FIXME this is silly, in what scenario would the env change from MacOS to -# another OS? TODO: remove if the author does not respond. -def get_onnxrt(): - try: - import onnxruntime - return onnxruntime - except ImportError: - # only one of these packages should be installed at one time in an env - # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime - # TODO: remove old package when the environment changes? - from launch import is_installed, run_pip - if not is_installed('onnxruntime'): - if system() == "Darwin": - package_name = "onnxruntime-silicon" +class HFInterrogator(Interrogator): + """ Interrogator for HuggingFace models """ + def __init__( + self, + name: str, + model_path: str, + tags_path: str, + **kwargs, + ) -> None: + super().__init__(name) + self.model_path = model_path + self.tags_path = tags_path + self.model = None + # tagger_hf_hub_down_opts contains args to hf_hub_download(). Parse + # and pass only the supported args. 
+ + signature = inspect.signature(hf_hub_download) + self.repo_specs = {'repo_id', 'revision', 'library_name', + 'library_version'} + self.hf_params = {} + for k in kwargs: + if k in signature.parameters: + tp = signature.parameters[k].annotation + if isinstance(kwargs[k], tp): + self.hf_params[k] = kwargs[k] + continue + print(f"Warning: interrogators.json: model {self.name}: " + f"parameter {k} unsupported or or wrong type.") + + attrs = getattr(shared.opts, 'tagger_hf_hub_down_opts', + f'cache_dir="{Its.hf_cache}"') + attrs = [attr.split('=') for attr in map(str.strip, attrs.split(','))] + + signature = inspect.signature(hf_hub_download) + for arg, val in attrs: + if arg == 'filename' or arg in self.repo_specs: + + print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " + "Specific options need to go in the interrogators.json.") + + elif arg in signature.parameters: + try: + tp = signature.parameters[arg].annotation + self.hf_params[arg] = tp(val) + + except TypeError: + # unions, used for str or PathLike and a few. 
+ if val == 'None': + self.hf_params[arg] = None + elif arg == 'token' and val in {'True', 'False'}: + self.hf_params[arg] = val == 'True' + else: + if val[0] == val[-1] and val[0] in "'\"": + val = val[1:-1] + self.hf_params[arg] = str(val) else: - package_name = "onnxruntime-gpu" - package = os.environ.get( - 'ONNXRUNTIME_PACKAGE', - package_name - ) + print(f"Settings -> Tagger -> HuggingFace parameters: {arg}: " + "Invalid for hf_hub_download() => ignored.") + + def download(self) -> Tuple[str, str]: + paths = [self.local_model, self.local_tags] + try: + # To prevent download don't set repo_id, export HF_HUB_OFFLINE=1 + repo_id = self.hf_params['repo_id'] + + print(f"(Down)Loading {self.name} model file from {repo_id}") + for i, filename in enumerate([self.model_path, self.tags_path]): + self.hf_params['filename'] = filename + paths[i] = hf_hub_download(**self.hf_params) + except Exception as err: + print(f"Warning: {self.name}: {err} (might be as expected)") + for i in range(2): + if not os.path.isabs(paths[i]): + paths[i] = os.path.join(shared.models_path, paths[i]) + pass + + return paths + + def load_model(self, model_path) -> None: + import onnxruntime + self.model = onnxruntime.InferenceSession(model_path, + providers=onnxrt_providers) + print(f'Loaded {self.name} model from {model_path}') + + +class FromFileInterrogator(Interrogator): + """ Pseudo Interrogator reading preinterrogated tags files """ + def __init__( + self, name: str, path, format='[name].[output_extension]', value=1.0 + ) -> None: + super().__init__(name) + self.path = Path(path) + self.val = value + self.format = format + if format == Its.output_filename_format and path == '': + raise ValueError(f"tagsfiles will ne overwritten with {format}") + self.tags = {} - run_pip(f'install {package}', 'onnxruntime') + def load(self) -> None: + print(f'Loading {self.name} from {str(self.path)}') + if self.path == '': + return + + # self.path is a directory + if not os.path.isdir(Path(self.path)): + 
raise ValueError(f'{self.path} is not a directory') + + def load_file(self, tags_file: str) -> None: + basename = '.'.join(str(tags_file).split('/')[-1].split('.')[:-1]) + self.tags[basename] = {} + with open(tags_file, 'r') as f: + for line in f: + for x in map(str.strip, line.split(',')): + if x[0] == '(' and x[-1] == ')' and ':' in x: + tag, val = x[1:-1].split(':') + self.tags[basename][tag] = float(val) + else: + self.tags[basename][x] = self.val + + def unload(self) -> None: + self.tags = {} - import onnxruntime - return onnxruntime + def interrogate( + self, + image: Image + ) -> Tuple[ + Dict[str, float], # rating confidences + Dict[str, float] # tag confidences + ]: + basename = '.'.join(image.filename.split('/')[-1].split('.')[:-1]) + path = Path(image.filename) + tags_filename = format_output_filename(path, self.format) + if self.path == '': + dir = path.parent + else: + dir = self.path + tags_filename = os.path.join(dir, tags_filename) + self.load_file(tags_filename) + return {}, self.tags[basename] -class WaifuDiffusionInterrogator(Interrogator): +class WaifuDiffusionInterrogator(HFInterrogator): """ Interrogator for Waifu Diffusion models """ def __init__( self, name: str, model_path='model.onnx', tags_path='selected_tags.csv', - repo_id=None, - is_hf=True, + **kwargs, ) -> None: - super().__init__(name) - self.repo_id = repo_id - self.model_path = model_path - self.tags_path = tags_path - self.tags = None - self.model = None + super().__init__(name, model_path, tags_path, **kwargs) self.tags = None - self.local_model = None - self.local_tags = None - self.is_hf = is_hf - - def download(self) -> None: - mdir = Path(shared.models_path, 'interrogators') - if self.is_hf: - cache = getattr(shared.opts, 'tagger_hf_cache_dir', Its.hf_cache) - print(f"Loading {self.name} model file from {self.repo_id}, " - f"{self.model_path}") - - model_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.model_path, - cache_dir=cache) - tags_path = 
hf_hub_download( - repo_id=self.repo_id, - filename=self.tags_path, - cache_dir=cache) - else: - model_path = self.local_model - tags_path = self.local_tags + def update_model_json(self, model_path, tags_path): download_model = { 'name': self.name, 'model_path': model_path, 'tags_path': tags_path, } + mdir = Path(shared.models_path, 'interrogators') mpath = Path(mdir, 'model.json') data = [download_model] @@ -425,16 +577,22 @@ def download(self) -> None: with io.open(mpath, 'w', encoding='utf-8') as filename: json.dump(data, filename) - return model_path, tags_path - def load(self) -> None: + def load(self) -> bool: model_path, tags_path = self.download() - ort = get_onnxrt() - self.model = ort.InferenceSession(model_path, - providers=onnxrt_providers) - print(f'Loaded {self.name} model from {self.repo_id}') + if not os.path.exists(model_path): + print(f'Model path {model_path} not found.') + return False + + if not os.path.exists(tags_path): + print(f'Tags path {tags_path} not found.') + return False + + self.load_model(model_path) + self.update_model_json(model_path, tags_path) self.tags = read_csv(tags_path) + return True def interrogate( self, @@ -445,7 +603,8 @@ def interrogate( ]: # init model if self.model is None: - self.load() + if not self.load(): + return {}, {} # code for converting the image and running the model is taken from the # link below. thanks, SmilingWolf! @@ -537,7 +696,8 @@ def large_batch_interrogate(self, images, dry_run=True) -> None: # init model if not hasattr(self, 'model') or self.model is None: - self.load() + if not self.load(): + return os.environ["TF_XLA_FLAGS"] = '--tf_xla_auto_jit=2 '\ '--tf_xla_cpu_global_jit' @@ -583,49 +743,36 @@ def pred_model(model): del os.environ["TF_XLA_FLAGS"] -class MLDanbooruInterrogator(Interrogator): +class MLDanbooruInterrogator(HFInterrogator): """ Interrogator for the MLDanbooru model. 
""" def __init__( self, name: str, - repo_id: str, model_path: str, tags_path='classes.json', + **kwargs ) -> None: - super().__init__(name) - self.model_path = model_path - self.tags_path = tags_path - self.repo_id = repo_id + super().__init__(name, model_path, tags_path, **kwargs) self.tags = None - self.model = None - def download(self) -> Tuple[str, str]: - print(f"Loading {self.name} model file from {self.repo_id}") - cache = getattr(shared.opts, 'tagger_hf_cache_dir', Its.hf_cache) + def load(self) -> bool: + model_path, tags_path = self.download() - model_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.model_path, - cache_dir=cache - ) - tags_path = hf_hub_download( - repo_id=self.repo_id, - filename=self.tags_path, - cache_dir=cache - ) - return model_path, tags_path + if not os.path.exists(model_path): + print(f'Model path {model_path} not found.') + return False - def load(self) -> None: - model_path, tags_path = self.download() + if not os.path.exists(tags_path): + print(f'Tags path {tags_path} not found.') + return False - ort = get_onnxrt() - self.model = ort.InferenceSession(model_path, - providers=onnxrt_providers) - print(f'Loaded {self.name} model from {model_path}') + self.load_model(model_path) with open(tags_path, 'r', encoding='utf-8') as filen: self.tags = json.load(filen) + return True + def interrogate( self, image: Image @@ -635,7 +782,8 @@ def interrogate( ]: # init model if self.model is None: - self.load() + if not self.load(): + return {}, {} image = dbimutils.fill_transparent(image) image = dbimutils.resize(image, 448) # TODO CUSTOMIZE diff --git a/tagger/preset.py b/tagger/preset.py index 9189535..7a132d9 100644 --- a/tagger/preset.py +++ b/tagger/preset.py @@ -106,3 +106,6 @@ def list(self) -> List[str]: presets.append(self.default_filename) return presets + + +preset = Preset(Path(__file__).parent.parent.joinpath('presets')) diff --git a/tagger/settings.py b/tagger/settings.py index 8510468..d737b15 100644 --- 
a/tagger/settings.py +++ b/tagger/settings.py @@ -1,16 +1,36 @@ """Settings tab entries for the tagger module""" import os -from typing import List +from typing import List, Tuple from modules import shared # pylint: disable=import-error -from gradio import inputs as gr +import gradio as gr +from huggingface_hub import hf_hub_download +from preload import root_dir +from jsonschema import validate +import json # kaomoji from WD 1.4 tagger csv. thanks, Meow-San#5400! DEFAULT_KAMOJIS = '0_0, (o)_(o), +_+, +_-, ._., _, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||' # pylint: disable=line-too-long # noqa: E501 DEFAULT_OFF = '[name].[output_extension]' -HF_CACHE = os.environ.get('HF_HOME', os.environ.get('HUGGINGFACE_HUB_CACHE', - str(os.path.join(shared.models_path, 'interrogators')))) +HF_CACHE = os.environ.get( + 'HUGGINGFACE_HUB_CACHE', # defaults to "$HF_HOME/hub" + str(os.path.join(shared.models_path, 'interrogators'))) + + +def load_interrogator_and_schema() -> Tuple[List[str], List[str]]: + it_path = root_dir.joinpath("interrogators.json") + if not it_path.exists(): + it_path = root_dir.joinpath("default/interrogators.json") + if not it_path.exists(): + raise FileNotFoundError(f'{it_path} not found.') + + entries = json.loads(it_path.read_text()) + schema = json.loads(root_dir.joinpath('json_schema', + 'interrogators_v1_schema.json').read_text()) + validate(entries, schema) + return entries, schema + def slider_wrapper(value, elem_id, **kwargs): # required or else gradio will throw errors @@ -97,7 +117,7 @@ def on_ui_settings(): key='tagger_repl_us_excl', info=shared.OptionInfo( DEFAULT_KAMOJIS, - label='Excudes (split by comma)', + label='Underscore replacement excludes (split by comma)', section=section, ), ) @@ -123,11 +143,11 @@ def on_ui_settings(): ) # see huggingface_hub guides/manage-cache shared.opts.add_option( - key='tagger_hf_cache_dir', + key='tagger_hf_hub_down_opts', info=shared.OptionInfo( - HF_CACHE, - label='HuggingFace 
cache directory, ' - 'see huggingface_hub guides/manage-cache', + str(f'cache_dir="{HF_CACHE}"'), + label='HuggingFace parameters, Comma delimited: arg=value, ' + 'see huggingface_hub docs for available or leave alone.', section=section, ), ) diff --git a/tagger/ui.py b/tagger/ui.py index 8102dc4..c958f6a 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -8,8 +8,7 @@ try: from tensorflow import __version__ as tf_version except ImportError: - def tf_version(): - return '0.0.0' + tf_version = '0.0.0' from html import escape as html_esc @@ -20,13 +19,14 @@ def tf_version(): from modules.call_queue import wrap_gradio_gpu_call except ImportError: from webui import wrap_gradio_gpu_call # pylint: disable=import-error -from tagger import utils # pylint: disable=import-error from tagger.interrogator import Interrogator as It # pylint: disable=E0401 from tagger.uiset import IOData, QData # pylint: disable=import-error +from tagger.preset import preset TAG_INPUTS = ["add", "keep", "exclude", "search", "replace"] COMMON_OUTPUT = Tuple[ Optional[str], # tags as string + Optional[str], # html tags as string Optional[str], # discarded tags as string Optional[Dict[str, float]], # rating confidences Optional[Dict[str, float]], # tag confidences @@ -35,11 +35,11 @@ def tf_version(): ] -def unload_interrogators() -> List[str]: +def unload_interrogators() -> Tuple[str]: unloaded_models = 0 remaining_models = '' - for i in utils.interrogators.values(): + for i in It.entries.values(): if i.unload(): unloaded_models = unloaded_models + 1 elif i.model is not None: @@ -52,7 +52,7 @@ def unload_interrogators() -> List[str]: "not be unloaded, a known issue." 
QData.clear(1) - return [f'{unloaded_models} model(s) unloaded{remaining_models}'] + return (f'{unloaded_models} model(s) unloaded{remaining_models}',) def on_interrogate( @@ -65,7 +65,7 @@ def on_interrogate( It.input["output_dir"] = output_dir if len(IOData.err) > 0: - return None, None, None, None, None, IOData.error_msg() + return (None,) * 6 + (IOData.error_msg(),) for i, val in enumerate(args): part = TAG_INPUTS[i] @@ -73,10 +73,10 @@ def on_interrogate( getattr(QData, "update_" + part)(val) It.input[part] = val - interrogator: It = next((i for i in utils.interrogators.values() if + interrogator: It = next((i for i in It.entries.values() if i.name == name), None) if interrogator is None: - return None, None, None, None, None, f"'{name}': invalid interrogator" + return (None,) * 6 + (f"'{name}': invalid interrogator",) interrogator.batch_interrogate() return search_filter(filt) @@ -90,7 +90,7 @@ def on_interrogate_image(*args) -> COMMON_OUTPUT: # hack brcause image interrogaion occurs twice It.odd_increment = It.odd_increment + 1 if It.odd_increment & 1 == 1: - return (None, None, None, None, None, '') + return (None,) * 6 + ('',) return on_interrogate_image_submit(*args) @@ -104,11 +104,11 @@ def on_interrogate_image_submit( It.input[part] = val if image is None: - return None, None, None, None, None, 'No image selected' - interrogator: It = next((i for i in utils.interrogators.values() if + return (None,) * 6 + ('No image selected',) + interrogator: It = next((i for i in It.entries.values() if i.name == name), None) if interrogator is None: - return None, None, None, None, None, f"'{name}': invalid interrogator" + return (None,) * 6 + (f"'{name}': invalid interrogator",) interrogator.interrogate_image(image) return search_filter(filt) @@ -155,18 +155,18 @@ def search_filter(filt: str) -> COMMON_OUTPUT: """ filters the tags and lost tags for the search field """ ratings, tags, lost, info = It.output if ratings is None: - return (None, None, None, None, None, 
info) + return (None,) * 6 + (info,) if filt: re_part = re.compile('(' + re.sub(', ?', '|', filt) + ')') tags = {k: v for k, v in tags.items() if re_part.search(k)} lost = {k: v for k, v in lost.items() if re_part.search(k)} - tags_str = ', '.join(f'{k}' for k, v in tags.items()) - lost_str = ', '.join(f'{k}' for k, v in lost.items()) + h_tags = ', '.join(f'{k}' for k in tags.keys()) + h_lost = ', '.join(f'{k}' for k in lost.keys()) - return (tags_str, lost_str, ratings, tags, lost, info) + return (', '.join(tags.keys()), h_tags, h_lost, ratings, tags, lost, info) def on_ui_tabs(): @@ -187,19 +187,20 @@ def on_ui_tabs(): interactive=True, type="pil" ) - image_submit = gr.Button( + img_submit = gr.Button( value='Interrogate image', variant='primary' ) with gr.TabItem(label='Batch from directory'): - input_glob = utils.preset.component( + input_glob = preset.component( gr.Textbox, value='', - label='Input directory - See also settings tab.', + label='Input directory - To recurse use ** or */* ' + 'in your glob; also check the settings tab.', placeholder='/path/to/images or to/images/**/*' ) - output_dir = utils.preset.component( + output_dir = preset.component( gr.Textbox, value=It.input["output_dir"], label='Output directory', @@ -213,7 +214,7 @@ def on_ui_tabs(): ) with gr.Row(variant='compact'): with gr.Column(variant='panel'): - large_query = utils.preset.component( + large_query = preset.component( gr.Checkbox, label='huge batch query (TF 2.10, ' 'experimental)', @@ -222,7 +223,7 @@ def on_ui_tabs(): version.parse('2.10') ) with gr.Column(variant='panel'): - save_tags = utils.preset.component( + save_tags = preset.component( gr.Checkbox, label='Save to tags files', value=True @@ -238,7 +239,7 @@ def on_ui_tabs(): with gr.Column(): # preset selector with gr.Row(variant='compact'): - available_presets = utils.preset.list() + available_presets = preset.list() selected_preset = gr.Dropdown( label='Preset', choices=available_presets, @@ -252,17 +253,13 @@ def 
on_ui_tabs(): ui.create_refresh_button( selected_preset, lambda: None, - lambda: {'choices': utils.preset.list()}, + lambda: {'choices': preset.list()}, 'refresh_preset' ) with gr.Row(variant='compact'): - def refresh(): - utils.refresh_interrogators() - return sorted(x.name for x in utils.interrogators - .values()) - interrogator_names = refresh() - interrogator = utils.preset.component( + interrogator_names = It.refresh() + interrogator = preset.component( gr.Dropdown, label='Interrogator', choices=interrogator_names, @@ -276,7 +273,7 @@ def refresh(): ui.create_refresh_button( interrogator, lambda: None, - lambda: {'choices': refresh()}, + lambda: {'choices': It.refresh()}, 'refresh_interrogator' ) @@ -284,20 +281,20 @@ def refresh(): value='Unload all interrogate models' ) with gr.Row(variant='compact'): - tag_input["add"] = utils.preset.component( + tag_input["add"] = preset.component( gr.Textbox, label='Additional tags (comma split)', elem_id='additional-tags' ) with gr.Row(variant='compact'): - threshold = utils.preset.component( + threshold = preset.component( gr.Slider, label='Weight threshold', minimum=0, maximum=1, value=QData.threshold ) - tag_frac_threshold = utils.preset.component( + tag_frac_threshold = preset.component( gr.Slider, label='Min tag fraction in batch and ' 'interrogations', @@ -306,34 +303,34 @@ def refresh(): value=QData.tag_frac_threshold, ) with gr.Row(variant='compact'): - cumulative = utils.preset.component( + cumulative = preset.component( gr.Checkbox, label='Combine interrogations', value=False ) - unload_after = utils.preset.component( + unload_after = preset.component( gr.Checkbox, label='Unload model after running', value=False ) with gr.Row(variant='compact'): - tag_input["search"] = utils.preset.component( + tag_input["search"] = preset.component( gr.Textbox, label='Search tag, .. 
->', elem_id='search-tags' ) - tag_input["replace"] = utils.preset.component( + tag_input["replace"] = preset.component( gr.Textbox, label='-> Replace tag, ..', elem_id='replace-tags' ) with gr.Row(variant='compact'): - tag_input["keep"] = utils.preset.component( + tag_input["keep"] = preset.component( gr.Textbox, label='Keep tag, ..', elem_id='keep-tags' ) - tag_input["exclude"] = utils.preset.component( + tag_input["exclude"] = preset.component( gr.Textbox, label='Exclude tag, ..', elem_id='exclude-tags' @@ -352,7 +349,7 @@ def refresh(): variant='secondary' ) with gr.Column(variant='compact'): - tag_search_selection = utils.preset.component( + tag_search_selection = preset.component( gr.Textbox, label='Multi string search: part1, part2.. ' '(Enter key to update)', @@ -360,7 +357,8 @@ def refresh(): with gr.Tabs(): with gr.TabItem(label='Ratings and included tags'): # clickable tags to populate excluded tags - tags = gr.HTML( + tags = gr.State(value="") + html_tags = gr.HTML( label='Tags', elem_id='tags', ) @@ -411,11 +409,11 @@ def refresh(): save_tags.input(fn=IOData.flip_save_tags(), inputs=[], outputs=[]) # Preset and unload buttons - selected_preset.change(fn=utils.preset.apply, inputs=[selected_preset], - outputs=[*utils.preset.components, info]) + selected_preset.change(fn=preset.apply, inputs=[selected_preset], + outputs=[*preset.components, info]) - save_preset_button.click(fn=utils.preset.save, inputs=[selected_preset, - *utils.preset.components], outputs=[info]) + save_preset_button.click(fn=preset.save, inputs=[selected_preset, + *preset.components], outputs=[info]) unload_all_models.click(fn=unload_interrogators, outputs=[info]) @@ -443,7 +441,7 @@ def refresh(): tab_gallery.select(fn=on_gallery, inputs=[], outputs=[gallery]) - common_output = [tags, discarded_tags, rating_confidences, + common_output = [tags, html_tags, discarded_tags, rating_confidences, tag_confidences, excluded_tag_confidences, info] # search input textbox @@ -466,13 +464,11 @@ 
def refresh(): [tag_input[tag] for tag in TAG_INPUTS] # interrogation events - image_submit.click( - fn=wrap_gradio_gpu_call(on_interrogate_image_submit), - inputs=[image] + common_input, outputs=common_output) + img_submit.click(fn=wrap_gradio_gpu_call(on_interrogate_image_submit), + inputs=[image] + common_input, outputs=common_output) - image.change( - fn=wrap_gradio_gpu_call(on_interrogate_image), - inputs=[image] + common_input, outputs=common_output) + image.change(fn=wrap_gradio_gpu_call(on_interrogate_image), + inputs=[image] + common_input, outputs=common_output) batch_submit.click(fn=wrap_gradio_gpu_call(on_interrogate), inputs=[input_glob, output_dir] + common_input, diff --git a/tagger/uiset.py b/tagger/uiset.py index a3aa4ed..ae21b09 100644 --- a/tagger/uiset.py +++ b/tagger/uiset.py @@ -1,6 +1,6 @@ """ for handling ui settings """ -from typing import List, Dict, Tuple, Callable, Set, Optional, Pattern +from typing import List, Dict, Tuple, Callable, Set, Optional import os from pathlib import Path from glob import glob @@ -17,6 +17,7 @@ from modules.deepbooru import re_special # pylint: disable=import-error from tagger import format as tags_format # pylint: disable=import-error from tagger import settings # pylint: disable=import-error +from preload import root_dir Its = settings.InterrogatorSettings @@ -29,24 +30,28 @@ # interrogator return type ItRetTP = Tuple[ - Optional[Dict[str, float]], # rating confidences - Optional[Dict[str, float]], # tag confidences - Optional[Dict[str, float]], # excluded tag confidences - str, # error message + Dict[str, float], # rating confidences + Dict[str, float], # tag confidences + Dict[str, float], # excluded tag confidences + str, # error message ] +def format_output_filename(path: Path, format='[name].[output_extension]') -> str: + info = tags_format.Info(path, 'txt') + fmt = partial(lambda info, m: tags_format.parse(m, info), info) + return tags_format.pattern.sub(fmt, format) + + class IOData: """ data class 
for input and output paths """ last_path_mtimes = None base_dir = None output_root = None - paths: List[Tuple[ - Path, Optional[Path], Optional[Path], Optional[str] - ]] = [] + paths: List[List[str]] = [] save_tags = True err: Set[str] = set() - base_dir_last = None + base_dir_last: Optional[str] = None @classmethod def error_msg(cls) -> str: @@ -68,7 +73,7 @@ def update_output_dir(cls, output_dir: str) -> None: """ update output directory, and set input and output paths """ pout = Path(output_dir) if pout != cls.output_root: - paths = [str(x[0]) for x in cls.paths] + paths = [x[0] for x in cls.paths] cls.paths = [] cls.output_root = pout cls.set_batch_io(paths) @@ -84,13 +89,13 @@ def get_hashes(cls) -> Set[str]: """ get hashes of all files """ ret = set() for entries in cls.paths: - if entries[3] is not None: + if len(entries) == 4: ret.add(entries[3]) else: # if there is no checksum, calculate it - image = Image.open(Path(entries[0])) + image = Image.open(entries[0]) checksum = cls.get_bytes_hash(image.tobytes()) - entries = (entries[0], entries[1], entries[2], checksum) + entries.append(checksum) ret.add(checksum) return ret @@ -119,16 +124,17 @@ def update_input_glob(cls, input_glob: str) -> None: recursive = getattr(shared.opts, 'tagger_batch_recursive', True) path_mtimes = [] for filename in glob(input_glob, recursive=recursive): - ext = os.path.splitext(filename)[1].lower() - if ext in supported_extensions: - path_mtimes.append(os.path.getmtime(filename)) - paths.append(filename) - elif ext != '.txt' and 'db.json' not in filename: - print(f'{filename}: not an image extension: "{ext}"') + if not os.path.isdir(filename): + ext = os.path.splitext(filename)[-1].lower() + if ext in supported_extensions: + path_mtimes.append(os.path.getmtime(filename)) + paths.append(filename) + elif ext != '.txt' and 'db.json' not in filename: + print(f'{filename}: not an image extension: "{ext}"') # interrogating in a directory with no pics, still flush the cache if 
len(path_mtimes) > 0 and cls.last_path_mtimes == path_mtimes: - print('No changed images') + # No (changed) images, keep the data return QData.clear(2) @@ -153,84 +159,79 @@ def set_batch_io(cls, paths: List[str]) -> None: """ set input and output paths for batch mode """ checked_dirs = set() cls.paths = [] - for filename in paths: - path = Path(filename) + for path_str in paths: + path = Path(path_str) if not cls.save_tags: - cls.paths.append((path, None, None, None)) + cls.paths.append([path_str, '', '']) continue # guess the output path base_dir_last_idx = path.parts.index(cls.base_dir_last) # format output filename - info = tags_format.Info(path, 'txt') - fmt = partial(lambda info, m: tags_format.parse(m, info), info) - msg = 'Invalid output format' cls.err.discard(msg) try: - formatted_output_filename = tags_format.pattern.sub( - fmt, - Its.output_filename_format - ) + formatted_output_filename = format_output_filename( + path, format=Its.output_filename_format) except (TypeError, ValueError): cls.err.add(msg) - if cls.output_root is None: - raise ValueError + if not cls.output_root: + raise ValueError('output_root not set') output_dir = cls.output_root.joinpath( *path.parts[base_dir_last_idx + 1:]).parent tags_out = output_dir.joinpath(formatted_output_filename) if output_dir in checked_dirs: - cls.paths.append((path, tags_out, None, None)) + cls.paths.append([path, tags_out, '']) else: checked_dirs.add(output_dir) if os.path.exists(output_dir): msg = 'output_dir: not a directory.' 
if os.path.isdir(output_dir): - cls.paths.append((path, tags_out, None, None)) + cls.paths.append([path, tags_out, '']) cls.err.discard(msg) else: cls.err.add(msg) else: - cls.paths.append((path, tags_out, output_dir, None)) + cls.paths.append([path, tags_out, output_dir]) class QData: """ Query data: contains parameters for the query """ add_tags: List[str] = [] - keep_tags: Set[str] = set() + keep_tags = set() exclude_tags: List[str] = [] - search_tags: Dict[int, Pattern[str]] = {} + search_tags: Dict[str, str] = {} replace_tags: List[str] = [] threshold = 0.35 tag_frac_threshold = 0.05 + count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100) # read from db.json, update with what should be written to db.json: json_db = None - weighed: Tuple[ - Dict[str, List[float]], - Dict[str, List[float]] - ] = (defaultdict(list), defaultdict(list)) - query: Dict[str, Tuple[str, int]] = {} + weighed: Tuple[Dict[str, list], Dict[str, list]] = \ + (defaultdict(list), defaultdict(list)) + query = {} # representing the (cumulative) current interrogations - ratings: Dict[str, float] = defaultdict(float) - tags: Dict[str, List[float]] = defaultdict(list) - discarded_tags: Dict[str, List[float]] = defaultdict(list) - in_db: Dict[ - int, - Tuple[str, str, str, Dict[str, float], Dict[str, float]] - ] = {} - for_tags_file: Dict[ - str, Dict[str, float] - ] = defaultdict(lambda: defaultdict(float)) + ratings = defaultdict(float) + tags = defaultdict(list) + discarded_tags = defaultdict(list) + in_db = {} + for_tags_file = defaultdict(lambda: defaultdict(float)) had_new = False - err: Set[str] = set() - image_dups: Dict[str, Set[str]] = defaultdict(set) + err = set() + image_dups = defaultdict(set) + + @classmethod + def set(cls, key: str) -> Callable[[str], Tuple[str]]: + def setter(val) -> Tuple[str]: + setattr(cls, key, val) + return setter @classmethod def set(cls, key: str) -> Callable[[str], Tuple[str]]: @@ -242,8 +243,8 @@ def setter(val) -> Tuple[str]: def 
clear(cls, mode: int) -> None: """ clear tags and ratings """ cls.tags.clear() - cls.discarded_tags.clear() cls.ratings.clear() + cls.discarded_tags.clear() cls.for_tags_file.clear() if mode > 0: cls.in_db.clear() @@ -323,12 +324,11 @@ def update_add(cls, add: str) -> None: cls.test_add(tag, 'add', ['exclude', 'search']) # silently raise count threshold to avoid issue in apply_filters - count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100) - if len(cls.add_tags) > count_threshold: - shared.opts.tagger_count_threshold = len(cls.add_tags) + if len(cls.add_tags) > cls.count_threshold: + cls.count_threshold = len(cls.add_tags) @staticmethod - def compile_rex(rex: str) -> Optional[Pattern[str]]: + def compile_rex(rex: str) -> Optional: if rex in {'', '^', '$', '^$'}: return None if rex[0] == '^': @@ -381,7 +381,7 @@ def update_replace(cls, replace: str) -> None: cls.err.discard(msg) @classmethod - def get_i_wt(cls, stored: float) -> Tuple[int, float]: + def get_i_wt(cls, stored: int) -> Tuple[int, float]: """ in db.json or QData.weighed, the weights & increment in the list are encoded. 
Each filestamp-interrogation corresponds to an incrementing @@ -404,9 +404,8 @@ def read_json(cls, outdir) -> None: # validate json using either json_schema/db_jon_v1_schema.json # or json_schema/db_jon_v2_schema.json - schema = Path(__file__).parent.parent.joinpath( - 'json_schema', 'db_json_v1_schema.json' - ) + schema = root_dir.joinpath('json_schema', + 'db_json_v1_schema.json') try: data = loads(cls.json_db.read_text()) validate(data, loads(schema.read_text())) @@ -455,8 +454,8 @@ def get_index(cls, fi_key: str, path='') -> int: @classmethod def single_data(cls, fi_key: str) -> None: """ get tags and ratings for filestamp-interrogator """ - index = cls.query[fi_key][1] - data: Tuple[Dict[str, float], Dict[str, float]] = ({}, {}) + index = cls.query.get(fi_key)[1] + data = ({}, {}) for j in range(2): for ent, lst in cls.weighed[j].items(): for i, val in map(cls.get_i_wt, lst): @@ -480,12 +479,11 @@ def correct_tag(cls, tag: str) -> str: if getattr(shared.opts, 'tagger_escape', False): tag = re_special.sub(r'\\\1', tag) # tag_escape_pattern - if len(cls.search_tags) != len(cls.replace_tags): + if len(cls.search_tags) == len(cls.replace_tags): for i, regex in cls.search_tags.items(): - m = re_match(regex, tag) - if m: - return re_sub(regex, cls.replace_tags[i], tag) - + if re_match(regex, tag): + tag = re_sub(regex, cls.replace_tags[i], tag) + break return tag @classmethod @@ -506,8 +504,7 @@ def apply_filters(cls, data) -> None: cls.weighed[0][rating].append(val + index) cls.ratings[rating] += val - count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100) - max_ct = count_threshold - len(cls.add_tags) + max_ct = cls.count_threshold - len(cls.add_tags) count = 0 # loop over tags with db update for tag, val in tags: @@ -625,10 +622,10 @@ def finalize(cls, count: int) -> ItRetTP: for file, remaining_tags in cls.for_tags_file.items(): sorted_tags = cls.sort_tags(remaining_tags) if weighted_tags_files: - joinable = [f'({k}:{v})' for k, v in sorted_tags] 
+ sorted_tags = [f'({k}:{v})' for k, v in sorted_tags] else: - joinable = [k for k, v in sorted_tags] - Path(file).write_text(', '.join(joinable), encoding='utf-8') + sorted_tags = [k for k, v in sorted_tags] + file.write_text(', '.join(sorted_tags), encoding='utf-8') warn = "" if len(QData.err) > 0: diff --git a/tagger/utils.py b/tagger/utils.py deleted file mode 100644 index e30101c..0000000 --- a/tagger/utils.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Utility functions for the tagger module""" -import os - -from typing import List, Dict -from pathlib import Path - -from modules import shared, scripts # pylint: disable=import-error -from modules.shared import models_path # pylint: disable=import-error - -default_ddp_path = Path(models_path, 'deepdanbooru') -default_onnx_path = Path(models_path, 'TaggerOnnx') -from tagger.preset import Preset # pylint: disable=import-error -from tagger.interrogator import Interrogator, DeepDanbooruInterrogator, \ - MLDanbooruInterrogator # pylint: disable=E0401 # noqa: E501 -from tagger.interrogator import WaifuDiffusionInterrogator # pylint: disable=E0401 # noqa: E501 - -preset = Preset(Path(scripts.basedir(), 'presets')) - -interrogators: Dict[str, Interrogator] = { - 'wd14-vit.v1': WaifuDiffusionInterrogator( - 'WD14 ViT v1', - repo_id='SmilingWolf/wd-v1-4-vit-tagger' - ), - 'wd14-vit.v2': WaifuDiffusionInterrogator( - 'WD14 ViT v2', - repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2', - ), - 'wd14-convnext.v1': WaifuDiffusionInterrogator( - 'WD14 ConvNeXT v1', - repo_id='SmilingWolf/wd-v1-4-convnext-tagger' - ), - 'wd14-convnext.v2': WaifuDiffusionInterrogator( - 'WD14 ConvNeXT v2', - repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2', - ), - 'wd14-convnextv2.v1': WaifuDiffusionInterrogator( - 'WD14 ConvNeXTV2 v1', - # the name is misleading, but it's v1 - repo_id='SmilingWolf/wd-v1-4-convnextv2-tagger-v2', - ), - 'wd14-swinv2-v1': WaifuDiffusionInterrogator( - 'WD14 SwinV2 v1', - # again misleading name - 
repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2', - ), - 'wd-v1-4-moat-tagger.v2': WaifuDiffusionInterrogator( - 'WD14 moat tagger v2', - repo_id='SmilingWolf/wd-v1-4-moat-tagger-v2' - ), - 'mld-caformer.dec-5-97527': MLDanbooruInterrogator( - 'ML-Danbooru Caformer dec-5-97527', - repo_id='deepghs/ml-danbooru-onnx', - model_path='ml_caformer_m36_dec-5-97527.onnx' - ), - 'mld-tresnetd.6-30000': MLDanbooruInterrogator( - 'ML-Danbooru TResNet-D 6-30000', - repo_id='deepghs/ml-danbooru-onnx', - model_path='TResnet-D-FLq_ema_6-30000.onnx' - ), -} - - -def refresh_interrogators() -> List[str]: - """Refreshes the interrogators list""" - # load deepdanbooru project - ddp_path = shared.cmd_opts.deepdanbooru_projects_path - if ddp_path is None: - ddp_path = default_ddp_path - onnx_path = shared.cmd_opts.onnxtagger_path - if onnx_path is None: - onnx_path = default_onnx_path - os.makedirs(ddp_path, exist_ok=True) - os.makedirs(onnx_path, exist_ok=True) - - for path in os.scandir(ddp_path): - print(f"Scanning {path} as deepdanbooru project") - if not path.is_dir(): - print(f"Warning: {path} is not a directory, skipped") - continue - - if not Path(path, 'project.json').is_file(): - print(f"Warning: {path} has no project.json, skipped") - continue - - interrogators[path.name] = DeepDanbooruInterrogator(path.name, path) - # scan for onnx models as well - for path in os.scandir(onnx_path): - print(f"Scanning {path} as onnx model") - if not path.is_dir(): - print(f"Warning: {path} is not a directory, skipped") - continue - - onnx_files = [x for x in os.scandir(path) if x.name.endswith('.onnx')] - if len(onnx_files) != 1: - print(f"Warning: {path} requires exactly one .onnx model, skipped") - continue - local_path = Path(path, onnx_files[0].name) - - csv = [x for x in os.scandir(path) if x.name.endswith('.csv')] - if len(csv) == 0: - print(f"Warning: {path} has no selected tags .csv file, skipped") - continue - - def tag_select_csvs_up_front(k): - sum(-1 if t in k.name.lower() else 1 
for t in ["tag", "select"]) - - csv.sort(key=tag_select_csvs_up_front) - tags_path = Path(path, csv[0]) - - if path.name not in interrogators: - if path.name == 'wd-v1-4-convnextv2-tagger-v2': - interrogators[path.name] = WaifuDiffusionInterrogator( - path.name, - repo_id='SmilingWolf/SW-CV-ModelZoo', - is_hf=False - ) - elif path.name == 'Z3D-E621-Convnext': - interrogators[path.name] = WaifuDiffusionInterrogator( - 'Z3D-E621-Convnext', is_hf=False) - else: - raise NotImplementedError(f"Add {path.name} resolution similar" - "to above here") - - interrogators[path.name].local_model = str(local_path) - interrogators[path.name].local_tags = str(tags_path) - - return sorted(interrogators.keys()) - - -def split_str(string: str, separator=',') -> List[str]: - return [x.strip() for x in string.split(separator) if x]