Skip to content

Commit 1a7de0d

Browse files
committed
Add --chroma-disable-dit-mask
1 parent 73b3415 commit 1a7de0d

4 files changed

Lines changed: 14 additions & 1 deletion

File tree

include/image_generator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class ImageGenerator {
118118
bool diffusion_fa_;
119119
bool control_net_cpu_;
120120
bool clip_on_cpu_;
121+
bool chroma_disable_dit_mask_;
121122
};
122123

123124
#endif // __IMAGE_GENERATOR_H__

include/server.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct ServerParams {
2121
bool diffusion_fa = false;
2222
bool control_net_cpu = false;
2323
bool clip_on_cpu = false;
24+
bool chroma_disable_dit_mask = false;
2425
};
2526

2627
class Server {

src/image_generator.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ ImageGenerator::ImageGenerator(std::shared_ptr<TaskStateManager> task_state_mana
8989
offload_to_cpu_(server_params.offload_to_cpu),
9090
diffusion_fa_(server_params.diffusion_fa),
9191
control_net_cpu_(server_params.control_net_cpu),
92-
clip_on_cpu_(server_params.clip_on_cpu) {
92+
clip_on_cpu_(server_params.clip_on_cpu),
93+
chroma_disable_dit_mask_(server_params.chroma_disable_dit_mask) {
9394
LOG_INFO("ImageGenerator created");
9495
}
9596

@@ -578,6 +579,7 @@ bool ImageGenerator::ensureModelLoaded(const std::string& controlnet_model) {
578579
params.diffusion_flash_attn = diffusion_fa_;
579580
params.keep_control_net_on_cpu = control_net_cpu_;
580581
params.keep_clip_on_cpu = clip_on_cpu_;
582+
params.chroma_use_dit_mask = !chroma_disable_dit_mask_;
581583

582584
if (vae_on_cpu_) {
583585
LOG_INFO("VAE will be kept on CPU");
@@ -594,6 +596,9 @@ bool ImageGenerator::ensureModelLoaded(const std::string& controlnet_model) {
594596
if (clip_on_cpu_) {
595597
LOG_INFO("CLIP will be kept on CPU");
596598
}
599+
if (chroma_disable_dit_mask_) {
600+
LOG_INFO("DiT mask disabled for Chroma models");
601+
}
597602

598603
// Create SD context
599604
sd_ctx_ = new_sd_ctx(&params);

src/main.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ struct CommandLineArgs {
9696
bool diffusion_fa = false;
9797
bool control_net_cpu = false;
9898
bool clip_on_cpu = false;
99+
bool chroma_disable_dit_mask = false;
99100
};
100101

101102
void print_usage(const char* program_name) {
@@ -124,6 +125,8 @@ void print_usage(const char* program_name) {
124125
std::cerr << " --diffusion-fa Enable diffusion flash attention (default: false)" << std::endl;
125126
std::cerr << " --control-net-cpu Keep ControlNet on CPU (default: false)" << std::endl;
126127
std::cerr << " --clip-on-cpu Keep CLIP on CPU (default: false)" << std::endl;
128+
std::cerr << " --chroma-disable-dit-mask Disable DiT mask for Chroma models (default: false)"
129+
<< std::endl;
127130
}
128131

129132
CommandLineArgs parse_args(int argc, char* argv[]) {
@@ -184,6 +187,8 @@ CommandLineArgs parse_args(int argc, char* argv[]) {
184187
args.control_net_cpu = true;
185188
} else if (arg == "--clip-on-cpu") {
186189
args.clip_on_cpu = true;
190+
} else if (arg == "--chroma-disable-dit-mask") {
191+
args.chroma_disable_dit_mask = true;
187192
} else if (arg == "--help" || arg == "-h") {
188193
print_usage(argv[0]);
189194
exit(0);
@@ -265,6 +270,7 @@ int main(int argc, char* argv[]) {
265270
server_params.diffusion_fa = args.diffusion_fa;
266271
server_params.control_net_cpu = args.control_net_cpu;
267272
server_params.clip_on_cpu = args.clip_on_cpu;
273+
server_params.chroma_disable_dit_mask = args.chroma_disable_dit_mask;
268274

269275
// Create and start the server
270276
g_server = std::make_unique<Server>(server_params);

0 commit comments

Comments
 (0)