Skip to content

Commit 5ece99b

Browse files
authored
Merge pull request #108 from newfla/feat_hidream_O1_image
feat: add HiDream O1 image presets and configuration functions
2 parents 445d610 + 5646d4e commit 5ece99b

5 files changed

Lines changed: 51 additions & 9 deletions

File tree

cli/src/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,8 @@ fn get_preset(args: &Args) -> Preset {
410410
.try_into()
411411
.unwrap(),
412412
),
413+
PresetDiscriminants::HiDreamO1ImageDev => Preset::HiDreamO1ImageDev,
414+
PresetDiscriminants::HiDreamO1Image => Preset::HiDreamO1Image,
413415
};
414416
preset
415417
}

src/api.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,13 @@ pub struct ModelConfig {
489489
#[builder(default = "None", private)]
490490
diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
491491

492+
/// Hires fix parameters and upscaler model.
492493
#[builder(default = "Self::hires_init()", setter(custom))]
493494
hires_params: (Upscaler, HiresParams, Option<CLibPath>),
495+
496+
/// Extra parameters for sampling, currently used for SDXL sample params, in json string format
497+
#[builder(default = "CLibString::default()")]
498+
extra_sample_params: CLibString,
494499
}
495500

496501
impl ModelConfigBuilder {
@@ -780,7 +785,8 @@ impl From<&ModelConfig> for ModelConfigBuilder {
780785
value.hires_params.0,
781786
value.hires_params.1.clone(),
782787
hires_path.as_deref(),
783-
);
788+
)
789+
.extra_sample_params(value.extra_sample_params.clone());
784790

785791
builder.lora_models_internal(value.lora_models.clone());
786792

@@ -1248,6 +1254,7 @@ fn gen_img_maybe_progress(
12481254
custom_sigmas: model_config.sigmas.as_mut_ptr(),
12491255
custom_sigmas_count: model_config.sigmas.len() as i32,
12501256
flow_shift: model_config.flow_shift,
1257+
extra_sample_args: model_config.extra_sample_params.as_ptr(),
12511258
};
12521259
let control_image = sd_image_t {
12531260
width: 0,

src/preset.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ use crate::{
99
anima, anima2, chroma, chroma_radiance, diff_instruct_star, dream_shaper_xl_2_1_turbo,
1010
ernie_image, ernie_image_turbo, flux_1_dev, flux_1_mini, flux_1_schnell, flux_2_dev,
1111
flux_2_klein_4b, flux_2_klein_9b, flux_2_klein_base_4b, flux_2_klein_base_9b,
12-
juggernaut_xl_11, nitro_sd_realism, nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo,
13-
sdxl_base_1_0, sdxl_turbo_1_0, sdxs512_dream_shaper, segmind_vega, ssd_1b,
14-
stable_diffusion_1_4, stable_diffusion_1_5, stable_diffusion_2_1,
15-
stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo, stable_diffusion_3_5_medium,
16-
stable_diffusion_3_medium, twinflow_z_image_turbo, z_image_turbo,
12+
hi_dream_o1_image, hi_dream_o1_image_dev, juggernaut_xl_11, nitro_sd_realism,
13+
nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0,
14+
sdxs512_dream_shaper, segmind_vega, ssd_1b, stable_diffusion_1_4, stable_diffusion_1_5,
15+
stable_diffusion_2_1, stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo,
16+
stable_diffusion_3_5_medium, stable_diffusion_3_medium, twinflow_z_image_turbo,
17+
z_image_turbo,
1718
},
1819
};
1920

@@ -300,14 +301,18 @@ pub enum Preset {
300301
Flux2KleinBase9B(Flux2KleinBase9BWeight),
301302
/// guidance_scale 9. 25 steps. 1024x1024
302303
SegmindVega,
303-
/// cfg__scale 4.0. 30 steps 1024x1024. Vae tiling enabled
304+
/// cfg_scale 4.0. 30 steps 1024x1024. Vae tiling enabled
304305
Anima(AnimaWeight),
305-
/// cfg__scale 4.0. 30 steps 1024x1024. Vae tiling enabled
306+
/// cfg_scale 4.0. 30 steps 1024x1024. Vae tiling enabled
306307
Anima2(Anima2Weight),
307308
/// cfg_scale 5.0. 20 steps 1024x1024. Vae tiling enabled. Flash attention enabled.
308309
ErnieImage(ErnieImageWeight),
309310
/// cfg_scale 1.0. 8 steps 1024x1024. Vae tiling enabled. Flash attention enabled.
310311
ErnieImageTurbo(ErnieImageWeight),
312+
/// cfg_scale 1.0. 20 steps 1024x1024.
313+
HiDreamO1ImageDev,
314+
/// cfg_scale 1.0. 20 steps 1024x1024.
315+
HiDreamO1Image,
311316
}
312317

313318
impl Preset {
@@ -349,6 +354,8 @@ impl Preset {
349354
Preset::Anima2(sd_type_t) => anima2(sd_type_t),
350355
Preset::ErnieImage(sd_type_t) => ernie_image(sd_type_t),
351356
Preset::ErnieImageTurbo(sd_type_t) => ernie_image_turbo(sd_type_t),
357+
Preset::HiDreamO1ImageDev => hi_dream_o1_image_dev(),
358+
Preset::HiDreamO1Image => hi_dream_o1_image(),
352359
}
353360
}
354361
}

src/preset_builder.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,3 +1521,29 @@ fn ernie_image_llm(sd_type: ErnieImageWeight) -> Result<PathBuf, ApiError> {
15211521
),
15221522
}
15231523
}
1524+
1525+
pub fn hi_dream_o1_image_dev() -> Result<ConfigsBuilder, ApiError> {
1526+
let model = download_file_hf_hub(
1527+
"Comfy-Org/HiDream-O1-Image",
1528+
"checkpoints/hidream_o1_image_dev_bf16.safetensors",
1529+
)?;
1530+
let mut config = ConfigBuilder::default();
1531+
let mut model_config = ModelConfigBuilder::default();
1532+
model_config.model(model);
1533+
config.cfg_scale(1.0).steps(20).height(1024).width(1024);
1534+
1535+
Ok((config, model_config))
1536+
}
1537+
1538+
pub fn hi_dream_o1_image() -> Result<ConfigsBuilder, ApiError> {
1539+
let model = download_file_hf_hub(
1540+
"Comfy-Org/HiDream-O1-Image",
1541+
"checkpoints/hidream_o1_image_bf16.safetensors",
1542+
)?;
1543+
let mut config = ConfigBuilder::default();
1544+
let mut model_config = ModelConfigBuilder::default();
1545+
model_config.model(model);
1546+
config.cfg_scale(1.0).steps(20).height(1024).width(1024);
1547+
1548+
Ok((config, model_config))
1549+
}

0 commit comments

Comments
 (0)