Skip to content

Commit f5282e1

Browse files
committed
allow ANY api field to have specified defaults, and to be overwritten by value specified at load time
1 parent 6548645 commit f5282e1

1 file changed

Lines changed: 41 additions & 47 deletions

File tree

koboldcpp.py

Lines changed: 41 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,59 +1907,42 @@ def sd_comfyui_tranform_params(genparams):
19071907
print("Warning: ComfyUI Payload Missing!")
19081908
return genparams
19091909

1910-
def sd_process_meta_fields(fields):
1911-
# aliases to match sd.cpp command-line options
1912-
aliases = {
1910+
# json with top-level dict
1911+
def gendefaults_parse_meta_field(input_str):
1912+
alias_map = {
19131913
'cfg-scale': 'cfg_scale',
19141914
'guidance': 'distilled_guidance',
19151915
'sampler': 'sampler_name',
19161916
'sampling-method': 'sampler_name',
19171917
'timestep-shift': 'shifted_timestep',
19181918
}
1919-
fields_dict = {aliases.get(k, k): v for k, v in fields}
1920-
# whitelist accepted parameters
1921-
whitelist = ['scheduler', 'shifted_timestep', 'distilled_guidance', 'sampler_name', 'cfg_scale', 'add_sd_step_limit', 'add_sd_cfg_limit', 'remove_limits']
1922-
fields_dict = {k: v for k, v in fields_dict.items() if k in whitelist}
1923-
return fields_dict
1924-
1925-
# json with top-level dict
1926-
def sd_parse_meta_field(prompt):
1927-
jfields = {}
1928-
kv_dict = {}
1929-
try:
1919+
if not isinstance(input_str, str) or not input_str.strip():
1920+
return {}
1921+
parsed = None
1922+
try: # Try parsing as-is
1923+
parsed = json.loads(input_str)
1924+
except json.JSONDecodeError:
1925+
# Try wrapping in braces for loose key/value strings
19301926
try:
1931-
jfields = json.loads(prompt)
1927+
parsed = json.loads(f"{{{input_str}}}")
19321928
except json.JSONDecodeError:
1933-
# accept "field":"value",... without {} (also empty strings)
1934-
try:
1935-
jfields = json.loads('{ ' + prompt + ' }')
1936-
except json.JSONDecodeError:
1937-
print("Warning: couldn't parse meta prompt; it should be valid JSON.")
1938-
if not isinstance(jfields, dict):
1939-
jfields = {}
1940-
kv_dict = sd_process_meta_fields(jfields.items())
1941-
except Exception:
1942-
pass
1943-
return kv_dict
1944-
1929+
print("Warning: couldn't parse gendefaults_parse_meta_field.")
1930+
return {}
1931+
if not isinstance(parsed, dict):
1932+
print("Warning: gendefaults_parse_meta_field - not a JSON object.")
1933+
return {}
1934+
result = {}
1935+
# First pass: apply aliases only if canonical key is not explicitly present
1936+
for key, value in parsed.items():
1937+
canonical = alias_map.get(key, key)
1938+
if canonical not in parsed:
1939+
result[canonical] = value
1940+
result.update(parsed) # Second pass: explicit keys override aliases
1941+
return result
19451942

19461943
def sd_generate(genparams):
19471944
global maxctx, args, currentusergenkey, totalgens, pendingabortkey, chatcompl_adapter
19481945

1949-
sdgendefaults = sd_parse_meta_field(args.sdgendefaults or '')
1950-
params = dict()
1951-
defparams = dict()
1952-
for k, v in sdgendefaults.items():
1953-
if k in ['sampler_name', 'scheduler']:
1954-
# these can be explicitely set to 'default'; process later
1955-
# TODO should we consider values like 'clip_skip=-1' as 'default' too?
1956-
defparams[k] = v
1957-
else:
1958-
params[k] = v
1959-
# apply most of the defaults
1960-
params.update(genparams)
1961-
genparams = params
1962-
19631946
default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
19641947
adapter_obj = genparams.get('adapter', default_adapter)
19651948
forced_negprompt = adapter_obj.get("add_sd_negative_prompt", "")
@@ -4520,6 +4503,10 @@ def do_POST(self):
45204503
}}).encode())
45214504
return
45224505

4506+
gendefaults = gendefaults_parse_meta_field(args.gendefaults or '')
4507+
gen_new_keys = {k: v for k, v in gendefaults.items() if k not in genparams}
4508+
genparams.update(gendefaults if args.gendefaultsoverwrite else gen_new_keys)
4509+
45234510
trunc_len = 8000
45244511
if args.debugmode >= 1:
45254512
trunc_len = 32000
@@ -5285,7 +5272,8 @@ def hide_tooltip(event):
52855272
sd_clamped_soft_var = ctk.StringVar(value="0")
52865273
sd_threads_var = ctk.StringVar(value=str(default_threads))
52875274
sd_quant_var = ctk.StringVar(value=sd_quant_choices[0])
5288-
sd_gen_defaults_var = ctk.StringVar()
5275+
gen_defaults_var = ctk.StringVar()
5276+
gen_defaults_overwrite_var = ctk.IntVar(value=0)
52895277

52905278
whisper_model_var = ctk.StringVar()
52915279
tts_model_var = ctk.StringVar()
@@ -5920,6 +5908,8 @@ def changerunmode(a,b,c):
59205908
context_var.trace_add("write", changed_gpulayers_estimate)
59215909
makelabelentry(context_tab, "Default Gen Amt:", defaultgenamt_var, row=20, padx=(120), singleline=True, tooltip="How many tokens to generate by default, if not specified. Must be smaller than context size. Usually, your frontend GUI will override this.")
59225910
makelabelentry(context_tab, "Prompt Limit:", genlimit_var, row=20, padx=(300), singleline=True, tooltip="If set, restricts max output tokens to this limit regardless of API request. Set to 0 to disable.",labelpadx=(210))
5911+
makelabelentry(context_tab, "Default Params:", gen_defaults_var, row=21, width=200, padx=(110), singleline=True, tooltip='Set default generation parameters for incoming API payloads.\nSpecified as JSON fields: {"KEY1":"VALUE1", "KEY2":"VALUE2"...}')
5912+
makecheckbox(context_tab, "Override", gen_defaults_overwrite_var, row=21,padx=(330), tooltiptxt="Allow the gendefaults parameters to overwrite the original value in API payloads.")
59235913

59245914
nativectx_entry, nativectx_label = makelabelentry(context_tab, "Override Native Context:", customrope_nativectx, row=23, padx=(146), singleline=True, tooltip="Overrides the native trained context of the loaded model with a custom value to be used for Rope scaling.")
59255915
customrope_scale_entry, customrope_scale_label = makelabelentry(context_tab, "RoPE Scale:", customrope_scale, row=23, padx=(100), singleline=True, tooltip="For Linear RoPE scaling. RoPE frequency scale.")
@@ -6086,7 +6076,6 @@ def toggletaesd(a,b,c):
60866076
makecheckbox(images_tab, "Model CPU Offload", sd_offload_cpu_var, 50,padx=8, tooltiptxt="Offload image weights in RAM to save VRAM, swap into VRAM when needed.")
60876077
makecheckbox(images_tab, "VAE on CPU", sd_vae_cpu_var, 50,padx=(160), tooltiptxt="Force VAE to CPU only for image generation.")
60886078
makecheckbox(images_tab, "CLIP on GPU", sd_clip_gpu_var, 50,padx=(280), tooltiptxt="Put CLIP and T5 to GPU for image generation. Otherwise, CLIP will use CPU.")
6089-
makelabelentry(images_tab, "Default Params:", sd_gen_defaults_var, 52, 280, padx=(110), singleline=True, tooltip='Default image generation parameters when not specified by the UI or API.\nSpecified as JSON fields: {"KEY1":"VALUE1", "KEY2":"VALUE2"...}')
60906079

60916080
# audio tab
60926081
audio_tab = tabcontent["Audio"]
@@ -6390,8 +6379,9 @@ def export_vars():
63906379
else:
63916380
args.sdlora = ""
63926381

6393-
if sd_gen_defaults_var.get() != "":
6394-
args.sdgendefaults = sd_gen_defaults_var.get()
6382+
if gen_defaults_var.get() != "":
6383+
args.gendefaults = gen_defaults_var.get()
6384+
args.gendefaultsoverwrite = (gen_defaults_overwrite_var.get()==1)
63956385

63966386
if whisper_model_var.get() != "":
63976387
args.whispermodel = whisper_model_var.get()
@@ -6623,7 +6613,8 @@ def import_vars(dict):
66236613

66246614
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
66256615
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
6626-
sd_gen_defaults_var.set(dict["sdgendefaults"] if ("sdgendefaults" in dict and dict["sdgendefaults"]) else "")
6616+
gen_defaults_var.set(dict["gendefaults"] if ("gendefaults" in dict and dict["gendefaults"]) else "")
6617+
gen_defaults_overwrite_var.set(1 if "gendefaultsoverwrite" in dict and dict["gendefaultsoverwrite"] else 0)
66276618

66286619
whisper_model_var.set(dict["whispermodel"] if ("whispermodel" in dict and dict["whispermodel"]) else "")
66296620

@@ -7008,6 +6999,8 @@ def convert_invalid_args(args):
70086999
dict["sdclip2"] = dict["sdclipg"]
70097000
if "jinja_tools" in dict and dict["jinja_tools"]:
70107001
dict["jinja"] = True
7002+
if "sdgendefaults" in dict and "gendefaults" not in dict:
7003+
dict["gendefaults"] = dict["sdgendefaults"]
70117004
return args
70127005

70137006
def setuptunnel(global_memory, has_sd):
@@ -8473,6 +8466,8 @@ def range_checker(arg: str):
84738466
compatgroup2.add_argument("--skiplauncher", help="Doesn't display or use the GUI launcher. Overrides showgui.", action='store_true')
84748467
advparser.add_argument("--singleinstance", help="Allows this KoboldCpp instance to be shut down by any new instance requesting the same port, preventing duplicate servers from clashing on a port.", action='store_true')
84758468
advparser.add_argument("--pipelineparallel", help="Enable Pipeline Parallelism for faster multigpu speeds but using more memory, only active for multigpu.", action='store_true')
8469+
advparser.add_argument("--gendefaults", metavar=('{"parameter":"value",...}'), help="Sets extra default parameters for some fields in API requests, as a JSON string.", default="")
8470+
advparser.add_argument("--gendefaultsoverwrite", help="Allow the gendefaults parameters to overwrite the original value in API payloads.", action='store_true')
84768471

84778472
hordeparsergroup = parser.add_argument_group('Horde Worker Commands')
84788473
hordeparsergroup.add_argument("--hordemodelname", metavar=('[name]'), help="Sets your AI Horde display model name.", default="")
@@ -8503,7 +8498,6 @@ def range_checker(arg: str):
85038498
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify an image generation LORA safetensors model to be applied.", default="")
85048499
sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the image LORA model to be applied.", type=float, default=1.0)
85058500
sdparsergroup.add_argument("--sdtiledvae", metavar=('[maxres]'), help="Adjust the automatic VAE tiling trigger for images above this size. 0 disables vae tiling.", type=int, default=default_vae_tile_threshold)
8506-
sdparsergroup.add_argument("--sdgendefaults", metavar=('{"parameter":"value",...}'), help="Sets default parameters for image generation, as a JSON string.", default="")
85078501
whisperparsergroup = parser.add_argument_group('Whisper Transcription Commands')
85088502
whisperparsergroup.add_argument("--whispermodel", metavar=('[filename]'), help="Specify a Whisper .bin model to enable Speech-To-Text transcription.", default="")
85098503

0 commit comments

Comments
 (0)