Skip to content

Commit c880539

Browse files
committed
Allow --clip-on-cpu and --control-net-cpu
1 parent 108c750 commit c880539

4 files changed

Lines changed: 25 additions & 1 deletion

File tree

include/image_generator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ class ImageGenerator {
115115
bool vae_tiling_;
116116
bool offload_to_cpu_;
117117
bool diffusion_fa_;
118+
bool control_net_cpu_;
119+
bool clip_on_cpu_;
118120
};
119121

120122
#endif // __IMAGE_GENERATOR_H__

include/server.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ struct ServerParams {
1818
bool vae_tiling = false;
1919
bool offload_to_cpu = false;
2020
bool diffusion_fa = false;
21+
bool control_net_cpu = false;
22+
bool clip_on_cpu = false;
2123
};
2224

2325
class Server {

src/image_generator.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ ImageGenerator::ImageGenerator(std::shared_ptr<TaskStateManager> task_state_mana
8585
vae_on_cpu_(server_params.vae_on_cpu),
8686
vae_tiling_(server_params.vae_tiling),
8787
offload_to_cpu_(server_params.offload_to_cpu),
88-
diffusion_fa_(server_params.diffusion_fa) {
88+
diffusion_fa_(server_params.diffusion_fa),
89+
control_net_cpu_(server_params.control_net_cpu),
90+
clip_on_cpu_(server_params.clip_on_cpu) {
8991
LOG_INFO("ImageGenerator created");
9092
}
9193

@@ -546,6 +548,8 @@ bool ImageGenerator::ensureModelLoaded(const std::string& controlnet_model) {
546548
params.keep_vae_on_cpu = vae_on_cpu_;
547549
params.offload_params_to_cpu = offload_to_cpu_;
548550
params.diffusion_flash_attn = diffusion_fa_;
551+
params.keep_control_net_on_cpu = control_net_cpu_;
552+
params.keep_clip_on_cpu = clip_on_cpu_;
549553

550554
if (vae_on_cpu_) {
551555
LOG_INFO("VAE will be kept on CPU");
@@ -556,6 +560,12 @@ bool ImageGenerator::ensureModelLoaded(const std::string& controlnet_model) {
556560
if (diffusion_fa_) {
557561
LOG_INFO("Diffusion flash attention enabled");
558562
}
563+
if (control_net_cpu_) {
564+
LOG_INFO("ControlNet will be kept on CPU");
565+
}
566+
if (clip_on_cpu_) {
567+
LOG_INFO("CLIP will be kept on CPU");
568+
}
559569

560570
// Create SD context
561571
sd_ctx_ = new_sd_ctx(&params);

src/main.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ struct CommandLineArgs {
9393
bool vae_tiling = false;
9494
bool offload_to_cpu = false;
9595
bool diffusion_fa = false;
96+
bool control_net_cpu = false;
97+
bool clip_on_cpu = false;
9698
};
9799

98100
void print_usage(const char* program_name) {
@@ -117,6 +119,8 @@ void print_usage(const char* program_name) {
117119
std::cerr << " --vae-tiling Enable VAE tiling (default: false)" << std::endl;
118120
std::cerr << " --offload-to-cpu Offload parameters to CPU (default: false)" << std::endl;
119121
std::cerr << " --diffusion-fa Enable diffusion flash attention (default: false)" << std::endl;
122+
std::cerr << " --control-net-cpu Keep ControlNet on CPU (default: false)" << std::endl;
123+
std::cerr << " --clip-on-cpu Keep CLIP on CPU (default: false)" << std::endl;
120124
}
121125

122126
CommandLineArgs parse_args(int argc, char* argv[]) {
@@ -171,6 +175,10 @@ CommandLineArgs parse_args(int argc, char* argv[]) {
171175
args.offload_to_cpu = true;
172176
} else if (arg == "--diffusion-fa") {
173177
args.diffusion_fa = true;
178+
} else if (arg == "--control-net-cpu") {
179+
args.control_net_cpu = true;
180+
} else if (arg == "--clip-on-cpu") {
181+
args.clip_on_cpu = true;
174182
} else if (arg == "--help" || arg == "-h") {
175183
print_usage(argv[0]);
176184
exit(0);
@@ -249,6 +257,8 @@ int main(int argc, char* argv[]) {
249257
server_params.vae_tiling = args.vae_tiling;
250258
server_params.offload_to_cpu = args.offload_to_cpu;
251259
server_params.diffusion_fa = args.diffusion_fa;
260+
server_params.control_net_cpu = args.control_net_cpu;
261+
server_params.clip_on_cpu = args.clip_on_cpu;
252262

253263
// Create and start the server
254264
g_server = std::make_unique<Server>(server_params);

0 commit comments

Comments
 (0)