Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,9 @@ def get_hidden_states_and_logits(

def hook(module, args, kwargs):
if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu())
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach())
if "position_ids" in kwargs and kwargs["position_ids"] is not None:
position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
position_ids_list.append(kwargs["position_ids"].clone().detach())
return args, kwargs

if self.target_model_type == "qwen3_vl":
Expand Down Expand Up @@ -439,9 +439,9 @@ def get_aux_and_target_hiddens(

def hook(module, args, kwargs):
if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu())
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach())
if "position_ids" in kwargs and kwargs["position_ids"] is not None:
position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
position_ids_list.append(kwargs["position_ids"].clone().detach())
return args, kwargs

if self.target_model_type == "qwen3_vl":
Expand Down Expand Up @@ -571,9 +571,9 @@ def get_hidden_states_and_logits(

def hook(module, args, kwargs):
if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu())
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach())
if "position_ids" in kwargs and kwargs["position_ids"] is not None:
position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
position_ids_list.append(kwargs["position_ids"].clone().detach())
return args, kwargs

handle = self.model.language_model.register_forward_pre_hook(hook, with_kwargs=True)
Expand Down Expand Up @@ -627,9 +627,9 @@ def get_aux_and_target_hiddens(

def hook(module, args, kwargs):
if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu())
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach())
if "position_ids" in kwargs and kwargs["position_ids"] is not None:
position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
position_ids_list.append(kwargs["position_ids"].clone().detach())
return args, kwargs

handle = self.model.language_model.register_forward_pre_hook(hook, with_kwargs=True)
Expand Down
2 changes: 1 addition & 1 deletion angelslim/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def run(
print_info("=" * 80)
for i, output in enumerate(outputs[:5]):
generated_text = output.outputs[0].text
print_info(f"[{i+1}] Output: {generated_text!r}")
print_info(f"[{i + 1}] Output: {generated_text!r}")
print_info(f"\nTotal outputs generated: {len(outputs)}")

# Collect and save statistics
Expand Down
27 changes: 20 additions & 7 deletions tools/generate_hidden_for_draft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ def main():
"""Main execution function."""
# Setup distributed environment
rank, world_size, local_rank = setup_distributed()
logger.info(
f"Distributed environment initialized: pid: {os.getpid()}, rank {rank},"
"world_size {world_size}, local_rank {local_rank}",
extra={"rank": rank},
)

# Parse arguments
args = parse_arguments()
Expand All @@ -452,7 +457,8 @@ def main():
f"Target model loaded: {args.target_model_name_or_path or args.model_name}",
extra={"rank": rank},
)
logger.info(f"tokenizer: {target_model.tokenizer}")
if rank == 0:
logger.info(f"tokenizer: {target_model.tokenizer}", extra={"rank": 0})

# Load dataset
dataset = load_dataset(args, target_model.tokenizer, rank)
Expand All @@ -469,9 +475,22 @@ def main():
generator = HiddenStateGenerator(target_model, output_dir, rank=rank)
successful, failed = generator.generate(dataset_slice)

logger.info(
f"Rank {rank} - Successful: {successful}, Failed: {failed}",
extra={"rank": rank},
)

except Exception as e:
logger.error(f"Rank {rank} encountered error: {e}", extra={"rank": rank})

finally:
# Synchronize all processes
if world_size > 1:
logger.info(
f"Rank {rank} reached barrier, waiting for other ranks...", extra={"rank": rank}
)
dist.barrier()
logger.info(f"Rank {rank} passed barrier.", extra={"rank": rank})

# Log final statistics (only on rank 0)
if rank == 0:
Expand All @@ -483,12 +502,6 @@ def main():
)
logger.info("=" * 50, extra={"rank": rank})

logger.info(
f"Rank {rank} - Successful: {successful}, Failed: {failed}",
extra={"rank": rank},
)

finally:
# Cleanup distributed environment
cleanup_distributed()

Expand Down
Loading