Skip to content

Commit f3696ce

Browse files
committed
Accept LoRA paths and alphas directly in the API; rather than extracting from the prompt
1 parent 6964fc8 commit f3696ce

3 files changed

Lines changed: 25 additions & 151 deletions

File tree

include/image_generator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ struct ServerParams;
1818
struct ImageGenerationParams {
1919
std::string prompt;
2020
std::string negative_prompt;
21+
std::vector<std::string> lora_paths;
22+
std::vector<float> lora_alphas;
2123
int width = 512;
2224
int height = 512;
2325
int steps = 20;

src/image_generator.cpp

Lines changed: 13 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <cstring>
44
#include <filesystem>
55
#include <map>
6-
#include <regex>
76
#include <stdexcept>
87
#include <string>
98

@@ -108,135 +107,6 @@ ImageGenerator::~ImageGenerator() {
108107
}
109108
}
110109

111-
// Helper function to find a file recursively by stem and extension
112-
static fs::path findFileRecursively(const fs::path& dir, const std::string& stem,
113-
const std::vector<std::string>& exts) {
114-
for (auto& p : fs::recursive_directory_iterator(dir)) {
115-
if (!p.is_regular_file()) continue;
116-
117-
auto path = p.path();
118-
// Use UTF-8 safe conversion for path comparison
119-
std::string path_stem = path_to_utf8(path.stem());
120-
if (path_stem == stem) {
121-
std::string ext = path_to_utf8(path.extension());
122-
for (auto& e : exts) {
123-
if (ext == e) {
124-
return path;
125-
}
126-
}
127-
}
128-
}
129-
return {};
130-
}
131-
132-
// Helper function to parse loras from prompt and return cleaned prompt
133-
static std::string parseLoras(const std::string& prompt, const std::string& lora_model_dir,
134-
std::map<std::string, float>& lora_map,
135-
std::map<std::string, float>& high_noise_lora_map) {
136-
if (lora_model_dir.empty()) {
137-
return prompt; // Return original prompt if no lora directory
138-
}
139-
140-
// Convert to wstring for proper UTF-8 handling
141-
std::wstring w_prompt = utf8_to_wstring(prompt);
142-
static const std::wregex re(L"(<lora:([^:>]+):([^>]+)>)");
143-
std::wsmatch m;
144-
std::wstring w_cleaned_prompt = w_prompt; // Start with original prompt
145-
146-
std::wstring w_tmp = w_prompt;
147-
148-
while (std::regex_search(w_tmp, m, re)) {
149-
std::wstring w_raw_path = m[2].str(); // Capture group 2 is the path
150-
const std::wstring w_raw_mul = m[3].str(); // Capture group 3 is the multiplier
151-
152-
std::string raw_path = wstring_to_utf8(w_raw_path);
153-
std::string raw_mul = wstring_to_utf8(w_raw_mul);
154-
155-
float mul = 0.f;
156-
try {
157-
mul = std::stof(raw_mul);
158-
} catch (...) {
159-
std::wstring w_suffix = m.suffix().str();
160-
w_tmp = w_suffix;
161-
w_cleaned_prompt = std::regex_replace(w_cleaned_prompt, re, L"", std::regex_constants::format_first_only);
162-
continue;
163-
}
164-
165-
bool is_high_noise = false;
166-
static const std::string prefix = "|high_noise|";
167-
if (raw_path.rfind(prefix, 0) == 0) {
168-
raw_path.erase(0, prefix.size());
169-
is_high_noise = true;
170-
}
171-
172-
// Build platform-correct paths from UTF-8 input to avoid mojibake on Windows
173-
#ifdef _WIN32
174-
fs::path raw_path_p = fs::path(utf8_to_wstring(raw_path));
175-
fs::path lora_dir_p = fs::path(utf8_to_wstring(lora_model_dir));
176-
#else
177-
fs::path raw_path_p = fs::path(reinterpret_cast<const char8_t*>(raw_path.c_str()));
178-
fs::path lora_dir_p = fs::path(reinterpret_cast<const char8_t*>(lora_model_dir.c_str()));
179-
#endif
180-
181-
fs::path final_path;
182-
if (raw_path_p.is_absolute()) {
183-
final_path = raw_path_p;
184-
} else {
185-
final_path = lora_dir_p / raw_path_p;
186-
}
187-
188-
// Log the resolved path for debugging
189-
LOG_DEBUG("Resolved LoRA path: %s", path_to_utf8(final_path.lexically_normal()).c_str());
190-
191-
if (!fs::exists(final_path)) {
192-
LOG_WARNING("LoRA file does not exist: %s", path_to_utf8(final_path).c_str());
193-
bool found = false;
194-
for (const auto& ext : {"gguf", "safetensors", "pt", "sft"}) {
195-
fs::path try_path = final_path;
196-
try_path += ".";
197-
try_path += ext;
198-
if (fs::exists(try_path)) {
199-
final_path = try_path;
200-
found = true;
201-
break;
202-
}
203-
}
204-
if (!found && !raw_path_p.is_absolute()) {
205-
std::string stem = path_to_utf8(raw_path_p.stem());
206-
fs::path found_path = findFileRecursively(lora_dir_p, stem, {"gguf", "safetensors", "pt", "sft"});
207-
if (!found_path.empty()) {
208-
final_path = found_path;
209-
found = true;
210-
}
211-
}
212-
213-
// Remove the matched tag from the cleaned prompt and advance the regex search to avoid infinite loop
214-
w_cleaned_prompt = std::regex_replace(w_cleaned_prompt, re, L"", std::regex_constants::format_first_only);
215-
std::wstring w_suffix = m.suffix().str();
216-
w_tmp = w_suffix;
217-
218-
if (!found) {
219-
LOG_ERROR("Failed to find LoRA file: %s", raw_path.c_str());
220-
continue;
221-
}
222-
}
223-
224-
const std::string key = path_to_utf8(final_path.lexically_normal());
225-
226-
if (is_high_noise)
227-
high_noise_lora_map[key] += mul;
228-
else
229-
lora_map[key] += mul;
230-
231-
w_cleaned_prompt = std::regex_replace(w_cleaned_prompt, re, L"", std::regex_constants::format_first_only);
232-
233-
std::wstring w_suffix = m.suffix().str();
234-
w_tmp = w_suffix;
235-
}
236-
237-
return wstring_to_utf8(w_cleaned_prompt);
238-
}
239-
240110
// Helper function to build embedding map from directory
241111
static void buildEmbeddingMap(const std::string& embedding_dir, std::map<std::string, std::string>& embedding_map) {
242112
static const std::vector<std::string> valid_ext = {".gguf", ".safetensors", ".pt", ".sft"};
@@ -332,35 +202,26 @@ std::vector<std::string> ImageGenerator::generateInternal(const ImageGenerationP
332202
LOG_INFO("Generating %s: prompt='%s', size=%dx%d, steps=%d, seed=%lld", is_img2img ? "img2img" : "txt2img",
333203
params.prompt.c_str(), params.width, params.height, params.steps, params.seed);
334204

335-
// Parse loras from prompt and get cleaned prompt
336-
std::map<std::string, float> lora_map;
337-
std::map<std::string, float> high_noise_lora_map;
338205
std::vector<sd_lora_t> lora_vec;
339-
std::string lora_dir_str = current_lora_model_dir_;
340-
341-
std::string cleaned_prompt = parseLoras(params.prompt, lora_dir_str, lora_map, high_noise_lora_map);
342-
343-
// Build lora vector from parsed maps
344-
for (const auto& kv : lora_map) {
345-
sd_lora_t item;
346-
item.is_high_noise = false;
347-
item.path = kv.first.c_str();
348-
item.multiplier = kv.second;
349-
lora_vec.push_back(item);
206+
size_t lora_count = params.lora_paths.size();
207+
if (params.lora_alphas.size() < lora_count) {
208+
LOG_WARNING("LoRA alpha count (%zu) does not match LoRA path count (%zu)", params.lora_alphas.size(),
209+
params.lora_paths.size());
210+
lora_count = params.lora_alphas.size();
350211
}
351212

352-
for (const auto& kv : high_noise_lora_map) {
213+
for (size_t i = 0; i < lora_count; i++) {
353214
sd_lora_t item;
354-
item.is_high_noise = true;
355-
item.path = kv.first.c_str();
356-
item.multiplier = kv.second;
215+
item.is_high_noise = false;
216+
item.path = params.lora_paths[i].c_str();
217+
item.multiplier = params.lora_alphas[i];
357218
lora_vec.push_back(item);
358219
}
359220

360221
if (!lora_vec.empty()) {
361222
LOG_INFO("Using %zu LoRA(s)", lora_vec.size());
362223
for (const auto& lora : lora_vec) {
363-
LOG_DEBUG(" LoRA: %s (%.2f) %s", lora.path, lora.multiplier, lora.is_high_noise ? "[high_noise]" : "");
224+
LOG_DEBUG(" LoRA: %s (%.2f)", lora.path, lora.multiplier);
364225
}
365226
}
366227

@@ -372,7 +233,7 @@ std::vector<std::string> ImageGenerator::generateInternal(const ImageGenerationP
372233
gen_params.loras = lora_vec.empty() ? nullptr : lora_vec.data();
373234
gen_params.lora_count = static_cast<uint32_t>(lora_vec.size());
374235

375-
gen_params.prompt = cleaned_prompt.c_str();
236+
gen_params.prompt = params.prompt.c_str();
376237
gen_params.negative_prompt = params.negative_prompt.empty() ? "" : params.negative_prompt.c_str();
377238
gen_params.clip_skip = params.clip_skip;
378239
gen_params.width = params.width;
@@ -630,7 +491,8 @@ std::vector<sd_image_t> ImageGenerator::createRefImages(const ImageGenerationPar
630491
}
631492
}
632493

633-
LOG_DEBUG("Loaded reference image %zu: %ux%u channels", i, ref_image.width, ref_image.height, ref_image.channel);
494+
LOG_DEBUG("Loaded reference image %zu: %ux%u channels", i, ref_image.width, ref_image.height,
495+
ref_image.channel);
634496
ref_images.push_back(ref_image);
635497
}
636498

src/server.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,16 @@ crow::response Server::generateImage(const crow::json::rvalue& json_body, bool i
275275
ImageGenerationParams params;
276276
params.prompt = json_body.has("prompt") ? std::string(json_body["prompt"].s()) : "";
277277
params.negative_prompt = json_body.has("negative_prompt") ? std::string(json_body["negative_prompt"].s()) : "";
278+
if (json_body.has("lora_paths") && json_body["lora_paths"].t() == crow::json::type::List) {
279+
for (size_t i = 0; i < json_body["lora_paths"].size(); i++) {
280+
params.lora_paths.push_back(std::string(json_body["lora_paths"][i].s()));
281+
}
282+
}
283+
if (json_body.has("lora_alphas") && json_body["lora_alphas"].t() == crow::json::type::List) {
284+
for (size_t i = 0; i < json_body["lora_alphas"].size(); i++) {
285+
params.lora_alphas.push_back(static_cast<float>(json_body["lora_alphas"][i].d()));
286+
}
287+
}
278288
params.width = json_body.has("width") ? json_body["width"].i() : 512;
279289
params.height = json_body.has("height") ? json_body["height"].i() : 512;
280290
params.steps = json_body.has("steps") ? json_body["steps"].i() : 20;

0 commit comments

Comments
 (0)