Skip to content

Commit 41e1003

Browse files
authored
avoid hardcode device in flux-control example (#13336)
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 85ffcf1 commit 41e1003

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/flux-control/train_control_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11051105

11061106
# text encoding.
11071107
captions = batch["captions"]
1108-
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
1108+
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
11091109
with torch.no_grad():
11101110
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
11111111
captions, prompt_2=None

examples/flux-control/train_control_lora_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1251,7 +1251,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
12511251

12521252
# text encoding.
12531253
captions = batch["captions"]
1254-
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
1254+
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
12551255
with torch.no_grad():
12561256
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
12571257
captions, prompt_2=None

0 commit comments

Comments
 (0)