From 60fb1633eb0fce5732092ed8f431b2a20814ddaa Mon Sep 17 00:00:00 2001 From: Weiyang Jin <137654456+WayneJin0918@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:11:06 +0800 Subject: [PATCH 1/5] support only text tuning --- modeling/bagel/bagel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modeling/bagel/bagel.py b/modeling/bagel/bagel.py index 7ad8b5e8..55ae89fb 100644 --- a/modeling/bagel/bagel.py +++ b/modeling/bagel/bagel.py @@ -162,7 +162,7 @@ def forward( else: attention_mask = nested_attention_masks - if self.config.visual_und: + if self.config.visual_und and vit_token_seqlens is not None and len(vit_token_seqlens) > 0: cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0)) cu_seqlens = cu_seqlens.to(torch.int32) max_seqlen = torch.max(vit_token_seqlens).item() @@ -1029,4 +1029,4 @@ def chat( output = tokenizer.decode(unpacked_latent[:,0]) output = output.split('<|im_end|>')[0].split('<|im_start|>')[1] - return output \ No newline at end of file + return output From d667518918f6ff69f79e5c68f53ac49246fba736 Mon Sep 17 00:00:00 2001 From: Weiyang Jin <137654456+WayneJin0918@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:11:45 +0800 Subject: [PATCH 2/5] support batch inference --- inferencer.py | 177 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 159 insertions(+), 18 deletions(-) diff --git a/inferencer.py b/inferencer.py index 7f19d1ad..bbf4c849 100644 --- a/inferencer.py +++ b/inferencer.py @@ -36,6 +36,116 @@ def init_gen_context(self): } return gen_context + @torch.no_grad() + # MODIFIED: This method now processes a batch of texts. + def update_context_text_batch(self, texts: List[str], gen_context): + past_key_values = gen_context['past_key_values'] + kv_lens = gen_context['kv_lens'] + ropes = gen_context['ropes'] + + # The underlying prepare_prompts can handle multiple prompts. 
+ generation_input, kv_lens, ropes = self.model.prepare_prompts( + curr_kvlens=kv_lens, + curr_rope=ropes, + prompts=texts, # Pass the list of texts directly + tokenizer=self.tokenizer, + new_token_ids=self.new_token_ids, + ) + + # Move all tensors in generation_input to the correct device + device = next(self.model.parameters()).device + for k, v in generation_input.items(): + if torch.is_tensor(v): + generation_input[k] = v.to(device) + + past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input) + gen_context['kv_lens'] = kv_lens + gen_context['ropes'] = ropes + gen_context['past_key_values'] = past_key_values + + return gen_context + + @torch.no_grad() + # ADDED: New method dedicated to batched image generation. + def batch_gen_image( + self, + image_shapes: List[Tuple[int, int]], + gen_context, + cfg_text_scale=4.0, + num_timesteps=50, + timestep_shift=3.0, + **kwargs # Pass other hyperparams + ): + batch_size = len(image_shapes) + past_key_values = gen_context['past_key_values'] + kv_lens = gen_context['kv_lens'] + ropes = gen_context['ropes'] + + generation_input = self.model.prepare_vae_latent( + curr_kvlens=kv_lens, + curr_rope=ropes, + image_sizes=image_shapes, + new_token_ids=self.new_token_ids, + ) + + # For batched inference, we'll simplify CFG to only use the text condition. + # A deepcopy is essential to not corrupt the main context. 
+ cfg_text_context = deepcopy(gen_context) + cfg_text_past_key_values = cfg_text_context['past_key_values'] + kv_lens_cfg = cfg_text_context['kv_lens'] + ropes_cfg = cfg_text_context['ropes'] + + generation_input_cfg_text = self.model.prepare_vae_latent_cfg( + curr_kvlens=kv_lens_cfg, + curr_rope=ropes_cfg, + image_sizes=image_shapes, + ) + + # Move all tensors to the correct device + device = next(self.model.parameters()).device + for k, v in generation_input.items(): + if torch.is_tensor(v): + generation_input[k] = v.to(device) + for k, v in generation_input_cfg_text.items(): + if torch.is_tensor(v): + generation_input_cfg_text[k] = v.to(device) + + + unpacked_latent = self.model.generate_image( + past_key_values=past_key_values, + cfg_text_past_key_values=cfg_text_past_key_values, + cfg_img_past_key_values=None, # Simplified for batch T2I + num_timesteps=num_timesteps, + cfg_text_scale=cfg_text_scale, + cfg_img_scale=1.0, # Simplified for batch T2I + timestep_shift=timestep_shift, + **generation_input, + **generation_input_cfg_text, + # Pass through any other relevant kwargs from the call + **kwargs + ) + + # MODIFIED: Decode the entire batch of latents. + images = self.decode_image_batch(unpacked_latent, image_shapes) + return images + + # MODIFIED: Renamed from decode_image and now handles a batch. 
+ def decode_image_batch(self, latents: List[torch.Tensor], image_shapes: List[Tuple[int,int]]): + images = [] + for i, latent in enumerate(latents): + H, W = image_shapes[i] + h, w = H // self.model.latent_downsample, W // self.model.latent_downsample + + latent = latent.reshape(1, h, w, self.model.latent_patch_size, self.model.latent_patch_size, self.model.latent_channel) + latent = torch.einsum("nhwpqc->nchpwq", latent) + latent = latent.reshape(1, self.model.latent_channel, h * self.model.latent_patch_size, w * self.model.latent_patch_size) + + image = self.vae_model.decode(latent) + image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255 + image = Image.fromarray((image).to(torch.uint8).cpu().numpy()) + images.append(image) + return images + @torch.no_grad() def update_context_text(self, text, gen_context): # used for interleave data, currently only support 1 data inference, @@ -281,29 +391,60 @@ def interleave_inference( return output_list + @torch.no_grad() + # MODIFIED: The __call__ method is now the main entry point for batching. 
def __call__( self, - image: Optional[Image.Image] = None, - text: Optional[str] = None, - **kargs + texts: Optional[List[str]] = None, + images: Optional[List[Image.Image]] = None, # Not fully implemented for batching, focus on T2I + image_shapes: Optional[List[Tuple[int, int]]] = None, + **kwargs ) -> Dict[str, Any]: - output_dict = {'image': None, 'text': None} + + output_dict = {'images': [], 'texts': []} - if image is None and text is None: - print('Please provide at least one input: either an image or text.') + if texts is None and images is None: + print('Please provide at least one input: either texts or images.') return output_dict - input_list = [] - if image is not None: - input_list.append(image) - if text is not None: - input_list.append(text) - - output_list = self.interleave_inference(input_list, **kargs) + # --- This example focuses on Batch Text-to-Image --- + if texts is not None: + batch_size = len(texts) + + # Set default image shapes if not provided + if image_shapes is None: + image_shapes = [(1024, 1024)] * batch_size + + assert batch_size == len(image_shapes), "Number of texts and image_shapes must be equal for batching." + + # 1. Initialize context for the batch + gen_context = self.init_gen_context(batch_size=batch_size) + + with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16): + # 2. Update context with the batch of text prompts + gen_context = self.update_context_text_batch(texts, gen_context) + + # 3. Generate images from the context + generated_images = self.batch_gen_image( + image_shapes=image_shapes, + gen_context=gen_context, + **kwargs + ) + output_dict['images'] = generated_images + + # Note: Batched image editing is more complex and would require further refactoring. + # The logic below is for the original single-stream inference. + elif images is not None: + print("Warning: Batched image editing not implemented. 
Processing first image only.") + # Fallback to original logic for single image + input_list = [images[0]] + if texts: + input_list.append(texts[0]) + output_list = self.interleave_inference(input_list, **kwargs) + for i in output_list: + if isinstance(i, Image.Image): + output_dict['images'].append(i) + elif isinstance(i, str): + output_dict['texts'].append(i) - for i in output_list: - if isinstance(i, Image.Image): - output_dict['image'] = i - elif isinstance(i, str): - output_dict['text'] = i return output_dict From 19b981bbadee170a7eccf4dab559ec541b2ed26c Mon Sep 17 00:00:00 2001 From: Weiyang Jin <137654456+WayneJin0918@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:14:44 +0800 Subject: [PATCH 3/5] support batch infer in jupyter demo --- inference.ipynb | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/inference.ipynb b/inference.ipynb index 0be9bf2d..86cc0bb9 100644 --- a/inference.ipynb +++ b/inference.ipynb @@ -301,6 +301,35 @@ "display(output_dict['image'])" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompts = [\n", + " \"A majestic lion with a fiery mane, standing on a rocky cliff at sunset.\",\n", + " \"An astronaut playing a guitar on the surface of the moon, with Earth in the background.\",\n", + " \"A futuristic cityscape at night, with flying cars and holographic advertisements, in the style of cyberpunk.\",\n", + " \"A tranquil Japanese zen garden with a koi pond, cherry blossom trees, and a small wooden bridge.\"\n", + "]\n", + "\n", + "# You can specify different shapes per image, or they will all default to (1024, 1024)\n", + "# image_shapes = [(1024, 1024), (1280, 768), (768, 1280), (1024, 1024)]\n", + "\n", + "print(f\"Generating a batch of {len(prompts)} images...\")\n", + "print('-' * 10)\n", + "\n", + "# The __call__ method now accepts a list of texts\n", + "output_dict = inferencer(texts=prompts, 
**inference_hyper)\n", + "\n", + "# The output contains a list of images\n", + "for i, image in enumerate(output_dict['images']):\n", + " print(f\"Image {i+1} for prompt: '{prompts[i]}'\")\n", + " display(image)\n", + " print('-' * 10)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -349,13 +378,6 @@ "display(output_dict['image'])" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, From d0016cbf1789911e149511d95feca7bef8726db6 Mon Sep 17 00:00:00 2001 From: Weiyang Jin <137654456+WayneJin0918@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:19:23 +0800 Subject: [PATCH 4/5] support batch infer in wise --- eval/gen/gen_images_mp_wise.py | 172 ++++++++++++++++++++------------- scripts/eval/run_wise.sh | 1 + 2 files changed, 106 insertions(+), 67 deletions(-) diff --git a/eval/gen/gen_images_mp_wise.py b/eval/gen/gen_images_mp_wise.py index 2b652712..6ec3c26a 100644 --- a/eval/gen/gen_images_mp_wise.py +++ b/eval/gen/gen_images_mp_wise.py @@ -37,33 +37,36 @@ def move_generation_input_to_device(generation_input, device): return generation_input -def generate_image_with_think( - prompt, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0], cfg_renorm_min=0., timestep_shift=4.0, resolution=1024, +def generate_images_with_think( + prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0], cfg_renorm_min=0., timestep_shift=4.0, resolution=1024, max_length=2048, simple_think=False, device=None ): + batch_size = len(prompts) h, w = resolution, resolution + image_sizes = [(h, w)] * batch_size past_key_values = NaiveCache(model.config.llm_config.num_hidden_layers) - newlens = [0] - new_rope = [0] + newlens = [0] * batch_size + new_rope = [0] * batch_size # system prompt + system_prompts = [SYSTEM_PROMPT] * batch_size generation_input, newlens, new_rope = model.prepare_prompts( curr_kvlens=newlens, curr_rope=new_rope, - prompts=[SYSTEM_PROMPT], + 
prompts=system_prompts, tokenizer=tokenizer, new_token_ids=new_token_ids, ) generation_input = move_generation_input_to_device(generation_input, device) with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): - past_key_values = model.forward_cache_update_text(past_key_values, **generation_input) + past_key_values = model.forward_cache_update_text(past_key_values, **generation_input) ########## cfg generation_input_cfg = model.prepare_vae_latent_cfg( curr_kvlens=newlens, curr_rope=new_rope, - image_sizes=[(h, w)], + image_sizes=image_sizes, ) generation_input_cfg = move_generation_input_to_device(generation_input_cfg, device) ########## cfg @@ -71,13 +74,13 @@ def generate_image_with_think( generation_input, newlens, new_rope = model.prepare_prompts( curr_kvlens=newlens, curr_rope=new_rope, - prompts=[prompt], + prompts=prompts, tokenizer=tokenizer, new_token_ids=new_token_ids, ) generation_input = move_generation_input_to_device(generation_input, device) with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): - past_key_values = model.forward_cache_update_text(past_key_values, **generation_input) + past_key_values = model.forward_cache_update_text(past_key_values, **generation_input) ########## think tmp_past_key_values = copy.deepcopy(past_key_values) @@ -92,23 +95,35 @@ def generate_image_with_think( end_token_id=new_token_ids['eos_token_id'], **tmp_generation_input, ) - output = tokenizer.decode(unpacked_latent[:,0]) - think_output = output.split('<|im_end|>')[0].split('<|im_start|>')[1] - print("="*30, "original think", "="*30) - print(think_output) - if simple_think: - think_output_list = think_output.split("</think>") - if think_output_list[1] != "": - think_output = think_output_list[1].strip() - print("="*30, "processed think", "="*30) - print(think_output) + raw_outputs = tokenizer.batch_decode(unpacked_latent, skip_special_tokens=False) + original_think_outputs = [] + processed_think_outputs = [] + + for i, raw_output in 
enumerate(raw_outputs): + try: + think_output = raw_output.split('<|im_end|>')[0].split('<|im_start|>')[1] + except IndexError: + think_output = "" # Fallback for failed generation + original_think_outputs.append(think_output) + + processed_think = think_output + if simple_think: + think_output_list = think_output.split("</think>") + if len(think_output_list) > 1 and think_output_list[1] != "": + processed_think = think_output_list[1].strip() + processed_think_outputs.append(processed_think) + + # print("="*30, "original think", "="*30) + # print(original_think_outputs) + # print("="*30, "processed think", "="*30) + # print(processed_think_outputs) ########## think generation_input, newlens, new_rope = model.prepare_prompts( curr_kvlens=newlens, curr_rope=new_rope, - prompts=[think_output], + prompts=processed_think_outputs, tokenizer=tokenizer, new_token_ids=new_token_ids, ) @@ -119,7 +134,7 @@ def generate_image_with_think( generation_input = model.prepare_vae_latent( curr_kvlens=newlens, curr_rope=new_rope, - image_sizes=[(h, w)], + image_sizes=image_sizes, new_token_ids=new_token_ids, ) generation_input = move_generation_input_to_device(generation_input, device) @@ -142,26 +157,32 @@ def generate_image_with_think( **generation_input, ) - latent0 = unpacked_latent[0] - latent0 = latent0.reshape(1, h//16, w//16, 2, 2, 16) - latent0 = torch.einsum("nhwpqc->nchpwq", latent0) - latent0 = latent0.reshape(1, 16, h//8, w//8) - image = vae_model.decode(latent0.to(device)) - tmpimage = ((image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 
255).to(torch.uint8).cpu().numpy() + tmpimage = Image.fromarray(tmpimage) + images.append(tmpimage) - return tmpimage, think_output + return images, original_think_outputs + +def generate_images(prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0], cfg_renorm_min=0., timestep_shift=1.0, resolution=1024, device=None): + batch_size = len(prompts) + image_sizes = [(resolution, resolution)] * batch_size -def generate_image(prompt, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0], cfg_renorm_min=0., timestep_shift=1.0, resolution=1024, device=None): past_key_values = NaiveCache(gen_model.config.llm_config.num_hidden_layers) - newlens = [0] - new_rope = [0] + newlens = [0] * batch_size + new_rope = [0] * batch_size generation_input, newlens, new_rope = gen_model.prepare_prompts( curr_kvlens=newlens, curr_rope=new_rope, - prompts=[prompt], + prompts=prompts, tokenizer=tokenizer, new_token_ids=new_token_ids, ) @@ -174,19 +195,19 @@ def generate_image(prompt, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0 generation_input = gen_model.prepare_vae_latent( curr_kvlens=newlens, curr_rope=new_rope, - image_sizes=[(resolution, resolution)], + image_sizes=image_sizes, new_token_ids=new_token_ids, ) generation_input = move_generation_input_to_device(generation_input, device) cfg_past_key_values = NaiveCache(gen_model.config.llm_config.num_hidden_layers) - cfg_newlens = [0] - cfg_new_rope = [0] + cfg_newlens = [0] * batch_size + cfg_new_rope = [0] * batch_size generation_input_cfg = model.prepare_vae_latent_cfg( curr_kvlens=cfg_newlens, curr_rope=cfg_new_rope, - image_sizes=[(resolution, resolution)], + image_sizes=image_sizes, ) generation_input_cfg = move_generation_input_to_device(generation_input_cfg, device) with torch.no_grad(): @@ -206,15 +227,18 @@ def generate_image(prompt, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0 **generation_input, ) - latent = unpacked_latent[0] - latent = latent.reshape(1, resolution//16, resolution//16, 2, 2, 16) - 
latent = torch.einsum("nhwpqc->nchpwq", latent) - latent = latent.reshape(1, 16, resolution//8, resolution//8) - image = vae_model.decode(latent.to(device)) - tmpimage = ((image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy() - tmpimage = Image.fromarray(tmpimage) + images = [] + for i in range(batch_size): + latent = unpacked_latent[i] + latent = latent.reshape(1, resolution//16, resolution//16, 2, 2, 16) + latent = torch.einsum("nhwpqc->nchpwq", latent) + latent = latent.reshape(1, 16, resolution//8, resolution//8) + image = vae_model.decode(latent.to(device)) + tmpimage = ((image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy() + tmpimage = Image.fromarray(tmpimage) + images.append(tmpimage) - return tmpimage + return images if __name__ == "__main__": @@ -223,6 +247,7 @@ def generate_image(prompt, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0 parser.add_argument("--metadata_file", type=str, required=True, help="JSON file containing lines of metadata for each prompt.") parser.add_argument("--cfg_scale", type=float, default=4) parser.add_argument("--resolution", type=int, default=1024) + parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference.") parser.add_argument("--max_latent_size", type=int, default=64) parser.add_argument("--think", action="store_true") parser.add_argument('--model-path', type=str, default='hf/BAGEL-7B-MoT/') @@ -303,24 +328,26 @@ def generate_image(prompt, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0 total_metadatas = len(metadatas) prompts_per_gpu = (total_metadatas + world_size - 1) // world_size - start = rank * prompts_per_gpu - end = min(start + prompts_per_gpu, total_metadatas) - print(f"GPU {rank}: Processing {end - start} prompts (indices {start} to {end - 1})") - - for idx in range(start, end): - metadata = metadatas[idx] - prompt = metadata['Prompt'] - prompt_id = metadata['prompt_id'] - outpath = 
os.path.join(output_dir, f"{prompt_id}.png") - print(f"GPU {rank} processing prompt {idx - start + 1}/{end - start}: '{prompt}'") - - if os.path.exists(outpath): - print(f"GPU {rank} skipping generation for prompt: {prompt}") + start_idx = rank * prompts_per_gpu + end_idx = min(start_idx + prompts_per_gpu, total_metadatas) + print(f"GPU {rank}: Processing {end_idx - start_idx} prompts (indices {start_idx} to {end_idx - 1})") + + for i in range(start_idx, end_idx, args.batch_size): + batch_start = i + batch_end = min(i + args.batch_size, end_idx) + current_batch_size = batch_end - batch_start + if current_batch_size == 0: continue + batch_metadatas = metadatas[batch_start:batch_end] + batch_prompts = [m['Prompt'] for m in batch_metadatas] + output_paths = [os.path.join(output_dir, f"{m['prompt_id']}.png") for m in batch_metadatas] + + print(f"GPU {rank} processing batch of size {current_batch_size}, starting with prompt: '{batch_prompts[0]}'") + if args.think: - tmpimage, think_output = generate_image_with_think( - prompt=prompt, + tmpimages, think_outputs = generate_images_with_think( + prompts=batch_prompts, cfg_scale=cfg_scale, cfg_interval=cfg_interval, cfg_renorm_min=cfg_renorm_min, @@ -331,11 +358,18 @@ def generate_image(prompt, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0 simple_think=False, device=device, ) - with open(outpath.replace(".png", ".txt"), "w") as f: - f.write(think_output) + for j in range(current_batch_size): + if os.path.exists(output_paths[j]): + continue + tmpimage = tmpimages[j] + think_output = think_outputs[j] + tmpimage = tmpimage.crop(tmpimage.getbbox()) + tmpimage.save(output_paths[j]) + with open(output_paths[j].replace(".png", ".txt"), "w") as f: + f.write(think_output) else: - tmpimage = generate_image( - prompt=prompt, + tmpimages = generate_images( + prompts=batch_prompts, cfg_scale=cfg_scale, cfg_interval=cfg_interval, cfg_renorm_min=cfg_renorm_min, @@ -344,9 +378,13 @@ def generate_image(prompt, num_timesteps=50, 
cfg_scale=4.0, cfg_interval=[0, 1.0 resolution=args.resolution, device=device, ) + for j in range(current_batch_size): + if os.path.exists(output_paths[j]): + continue + tmpimage = tmpimages[j] + tmpimage = tmpimage.crop(tmpimage.getbbox()) + tmpimage.save(output_paths[j]) - tmpimage = tmpimage.crop(tmpimage.getbbox()) - tmpimage.save(outpath) print(f"GPU {rank} has completed all tasks") - dist.barrier() + dist.barrier() \ No newline at end of file diff --git a/scripts/eval/run_wise.sh b/scripts/eval/run_wise.sh index cdbced1b..90145e87 100644 --- a/scripts/eval/run_wise.sh +++ b/scripts/eval/run_wise.sh @@ -21,6 +21,7 @@ torchrun \ --resolution 1024 \ --max-latent_size 64 \ --model-path $model_path \ + --batch_size $BATCH_SIZE \ --think From 7ea534f35e050df12ef78d0e8c421f6ce45eeb5b Mon Sep 17 00:00:00 2001 From: Weiyang Jin <137654456+WayneJin0918@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:24:58 +0800 Subject: [PATCH 5/5] support bf16/fp16/fp32 --- eval/gen/gen_images_mp_wise.py | 49 ++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/eval/gen/gen_images_mp_wise.py b/eval/gen/gen_images_mp_wise.py index 6ec3c26a..c4d08ca6 100644 --- a/eval/gen/gen_images_mp_wise.py +++ b/eval/gen/gen_images_mp_wise.py @@ -39,7 +39,7 @@ def move_generation_input_to_device(generation_input, device): def generate_images_with_think( prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0], cfg_renorm_min=0., timestep_shift=4.0, resolution=1024, - max_length=2048, simple_think=False, device=None + max_length=2048, simple_think=False, device=None, inference_dtype=torch.float16, autocast_enabled=True ): batch_size = len(prompts) h, w = resolution, resolution @@ -59,7 +59,7 @@ def generate_images_with_think( new_token_ids=new_token_ids, ) generation_input = move_generation_input_to_device(generation_input, device) - with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): + with torch.amp.autocast("cuda", 
enabled=autocast_enabled, dtype=inference_dtype): past_key_values = model.forward_cache_update_text(past_key_values, **generation_input) ########## cfg @@ -79,14 +79,14 @@ def generate_images_with_think( new_token_ids=new_token_ids, ) generation_input = move_generation_input_to_device(generation_input, device) - with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): + with torch.amp.autocast("cuda", enabled=autocast_enabled, dtype=inference_dtype): past_key_values = model.forward_cache_update_text(past_key_values, **generation_input) ########## think tmp_past_key_values = copy.deepcopy(past_key_values) tmp_generation_input = model.prepare_start_tokens(newlens, new_rope, new_token_ids) tmp_generation_input = move_generation_input_to_device(tmp_generation_input, device) - with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): + with torch.amp.autocast("cuda", enabled=autocast_enabled, dtype=inference_dtype): unpacked_latent = model.generate_text( past_key_values=tmp_past_key_values, max_length=max_length, @@ -113,11 +113,6 @@ def generate_images_with_think( if len(think_output_list) > 1 and think_output_list[1] != "": processed_think = think_output_list[1].strip() processed_think_outputs.append(processed_think) - - # print("="*30, "original think", "="*30) - # print(original_think_outputs) - # print("="*30, "processed think", "="*30) - # print(processed_think_outputs) ########## think generation_input, newlens, new_rope = model.prepare_prompts( @@ -128,7 +123,7 @@ def generate_images_with_think( new_token_ids=new_token_ids, ) generation_input = move_generation_input_to_device(generation_input, device) - with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): + with torch.amp.autocast("cuda", enabled=autocast_enabled, dtype=inference_dtype): past_key_values = model.forward_cache_update_text(past_key_values, **generation_input) generation_input = model.prepare_vae_latent( @@ -140,7 +135,7 @@ def generate_images_with_think( 
generation_input = move_generation_input_to_device(generation_input, device) ########## generate image - with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): + with torch.amp.autocast("cuda", enabled=autocast_enabled, dtype=inference_dtype): unpacked_latent = model.generate_image( past_key_values=past_key_values, num_timesteps=num_timesteps, @@ -171,7 +166,7 @@ def generate_images_with_think( return images, original_think_outputs -def generate_images(prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0], cfg_renorm_min=0., timestep_shift=1.0, resolution=1024, device=None): +def generate_images(prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1.0], cfg_renorm_min=0., timestep_shift=1.0, resolution=1024, device=None, inference_dtype=torch.float16, autocast_enabled=True): batch_size = len(prompts) image_sizes = [(resolution, resolution)] * batch_size @@ -189,7 +184,7 @@ def generate_images(prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1 generation_input = move_generation_input_to_device(generation_input, device) with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=True, dtype=torch.float16): + with torch.amp.autocast("cuda", enabled=autocast_enabled, dtype=inference_dtype): past_key_values = gen_model.forward_cache_update_text(past_key_values, **generation_input) generation_input = gen_model.prepare_vae_latent( @@ -211,7 +206,7 @@ def generate_images(prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1 ) generation_input_cfg = move_generation_input_to_device(generation_input_cfg, device) with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): + with torch.amp.autocast("cuda", enabled=autocast_enabled, dtype=inference_dtype): unpacked_latent = gen_model.generate_image( past_key_values=past_key_values, num_timesteps=num_timesteps, @@ -251,6 +246,7 @@ def generate_images(prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1 
parser.add_argument("--max_latent_size", type=int, default=64) parser.add_argument("--think", action="store_true") parser.add_argument('--model-path', type=str, default='hf/BAGEL-7B-MoT/') + parser.add_argument("--precision", type=str, default="bf16", choices=["auto", "bf16", "fp16", "fp32"], help="Inference precision. 'auto' detects bf16 support automatically.") args = parser.parse_args() seed = 42 @@ -271,6 +267,27 @@ def generate_images(prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1 world_size = dist.get_world_size() device = f"cuda:{rank}" + # Determine the correct inference data type and autocast state + autocast_enabled = True + if args.precision == "auto": + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + inference_dtype = torch.bfloat16 + else: + inference_dtype = torch.float16 + elif args.precision == "bf16": + inference_dtype = torch.bfloat16 + elif args.precision == "fp16": + inference_dtype = torch.float16 + else: # fp32 + inference_dtype = torch.float32 + autocast_enabled = False + + if rank == 0: + if not autocast_enabled: + print("Using fp32 for inference. Mixed precision autocast is disabled.") + else: + print(f"Using {str(inference_dtype).split('.')[-1]} for inference with mixed precision.") + output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) if rank == 0: @@ -357,6 +374,8 @@ def generate_images(prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1 max_length=2048, simple_think=False, device=device, + inference_dtype=inference_dtype, + autocast_enabled=autocast_enabled, ) for j in range(current_batch_size): if os.path.exists(output_paths[j]): @@ -377,6 +396,8 @@ def generate_images(prompts, num_timesteps=50, cfg_scale=4.0, cfg_interval=[0, 1 num_timesteps=num_timesteps, resolution=args.resolution, device=device, + inference_dtype=inference_dtype, + autocast_enabled=autocast_enabled, ) for j in range(current_batch_size): if os.path.exists(output_paths[j]):