From 5de4b5ec5ca2cde847cd73fdf66defb6a1a59813 Mon Sep 17 00:00:00 2001 From: dawnranger Date: Mon, 16 Mar 2026 16:27:02 +0800 Subject: [PATCH] optimize hidden_state generation by avoiding redundant GPU/CPU data transfer --- .../models/target/target_model_wrapper.py | 16 +++++------ angelslim/engine.py | 2 +- tools/generate_hidden_for_draft_model.py | 27 ++++++++++++++----- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py index 11d17151..72492388 100644 --- a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py +++ b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py @@ -369,9 +369,9 @@ def get_hidden_states_and_logits( def hook(module, args, kwargs): if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None: - inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu()) + inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach()) if "position_ids" in kwargs and kwargs["position_ids"] is not None: - position_ids_list.append(kwargs["position_ids"].clone().detach().cpu()) + position_ids_list.append(kwargs["position_ids"].clone().detach()) return args, kwargs if self.target_model_type == "qwen3_vl": @@ -439,9 +439,9 @@ def get_aux_and_target_hiddens( def hook(module, args, kwargs): if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None: - inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu()) + inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach()) if "position_ids" in kwargs and kwargs["position_ids"] is not None: - position_ids_list.append(kwargs["position_ids"].clone().detach().cpu()) + position_ids_list.append(kwargs["position_ids"].clone().detach()) return args, kwargs if self.target_model_type == "qwen3_vl": @@ -571,9 +571,9 @@ def get_hidden_states_and_logits( def 
hook(module, args, kwargs): if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None: - inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu()) + inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach()) if "position_ids" in kwargs and kwargs["position_ids"] is not None: - position_ids_list.append(kwargs["position_ids"].clone().detach().cpu()) + position_ids_list.append(kwargs["position_ids"].clone().detach()) return args, kwargs handle = self.model.language_model.register_forward_pre_hook(hook, with_kwargs=True) @@ -627,9 +627,9 @@ def get_aux_and_target_hiddens( def hook(module, args, kwargs): if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None: - inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu()) + inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach()) if "position_ids" in kwargs and kwargs["position_ids"] is not None: - position_ids_list.append(kwargs["position_ids"].clone().detach().cpu()) + position_ids_list.append(kwargs["position_ids"].clone().detach()) return args, kwargs handle = self.model.language_model.register_forward_pre_hook(hook, with_kwargs=True) diff --git a/angelslim/engine.py b/angelslim/engine.py index 12b97ea6..b4002a8f 100644 --- a/angelslim/engine.py +++ b/angelslim/engine.py @@ -532,7 +532,7 @@ def run( print_info("=" * 80) for i, output in enumerate(outputs[:5]): generated_text = output.outputs[0].text - print_info(f"[{i+1}] Output: {generated_text!r}") + print_info(f"[{i + 1}] Output: {generated_text!r}") print_info(f"\nTotal outputs generated: {len(outputs)}") # Collect and save statistics diff --git a/tools/generate_hidden_for_draft_model.py b/tools/generate_hidden_for_draft_model.py index 150b7233..7c5be74d 100644 --- a/tools/generate_hidden_for_draft_model.py +++ b/tools/generate_hidden_for_draft_model.py @@ -431,6 +431,11 @@ def main(): """Main execution function.""" # Setup distributed environment rank, world_size, local_rank = 
setup_distributed() + logger.info( + f"Distributed environment initialized: pid: {os.getpid()}, rank {rank}," + f" world_size {world_size}, local_rank {local_rank}", + extra={"rank": rank}, + ) # Parse arguments args = parse_arguments() @@ -452,7 +457,8 @@ def main(): f"Target model loaded: {args.target_model_name_or_path or args.model_name}", extra={"rank": rank}, ) - logger.info(f"tokenizer: {target_model.tokenizer}") + if rank == 0: + logger.info(f"tokenizer: {target_model.tokenizer}", extra={"rank": 0}) # Load dataset dataset = load_dataset(args, target_model.tokenizer, rank) @@ -469,9 +475,22 @@ def main(): generator = HiddenStateGenerator(target_model, output_dir, rank=rank) successful, failed = generator.generate(dataset_slice) + logger.info( + f"Rank {rank} - Successful: {successful}, Failed: {failed}", + extra={"rank": rank}, + ) + + except Exception as e: + logger.exception(f"Rank {rank} encountered error: {e}", extra={"rank": rank}) + + finally: # Synchronize all processes if world_size > 1: + logger.info( + f"Rank {rank} reached barrier, waiting for other ranks...", extra={"rank": rank} + ) dist.barrier() + logger.info(f"Rank {rank} passed barrier.", extra={"rank": rank}) # Log final statistics (only on rank 0) if rank == 0: @@ -483,12 +502,6 @@ def main(): ) logger.info("=" * 50, extra={"rank": rank}) - logger.info( - f"Rank {rank} - Successful: {successful}, Failed: {failed}", - extra={"rank": rank}, - ) - - finally: # Cleanup distributed environment cleanup_distributed()