Skip to content
This repository was archived by the owner on Aug 5, 2025. It is now read-only.
This repository was archived by the owner on Aug 5, 2025. It is now read-only.

cannot use hf models #1147

@LYMDLUT

Description

@LYMDLUT

$ torchrun --nproc-per-node 4 pippy_llama.py

import os
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.pipelining import ScheduleGPipe, PipelineStage

Grab the model

whole_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct", device_map="meta"
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
tokenizer.pad_token = tokenizer.eos_token

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.distributed.init_process_group(rank=rank, world_size=world_size)

Cut model by equal number of layers per rank

layers_per_rank = whole_model.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")

stage_idx = rank
num_stages = world_size

def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
model = copy.deepcopy(whole_model)
model = whole_model
if not is_first:
model.model.embed_tokens = None

drop_layers = stop_layer is not None
num_layers = len(model.model.layers) - 1
for idx in range(num_layers, -1, -1):
    if f"layers.{idx}" == stop_layer:
        drop_layers = False
    if f"layers.{idx}" == start_layer:
        drop_layers = True
    if drop_layers:
        del model.model.layers[idx]
# drop_layers = start_layer is not None
# for name in list(model.model.layers.keys()):
#     # we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
#     if f"layers.{name}" == start_layer:
#         drop_layers = False
#     if f"layers.{name}" == stop_layer:
#         drop_layers = True
#     if drop_layers:
#         del model.model.layers[name]

if not is_last:
    model.model.norm = None
    model.lm_head = None

stage = PipelineStage(
    model,
    stage_idx,
    num_stages,
    device,
    #group=pp_mesh.get_group("pp"),
)
return stage, model

base_interval = whole_model.config.num_hidden_layers // num_stages
extra_layers = whole_model.config.num_hidden_layers % num_stages

splits = []
current_layer = 0
for i in range(num_stages - 1):
if i == 0:
current_layer += base_interval
else:
# Middle stages get an extra layer if there are any remaining
if extra_layers > 0:
current_layer += base_interval + 1
extra_layers -= 1
else:
current_layer += base_interval
splits.append("layers." + str(current_layer))

start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
stage, model_chunk = _build_stage(
stage_idx,
start_layer,
stop_layer,
is_first=stage_idx == 0,
is_last=stage_idx == num_stages - 1,
)
model_chunk.to_empty(device=device)

Run time inputs

full_batch_prompts = (
"How do you", "I like to", "Can I help", "You need to",
"The weather is", "I found a", "What is your", "You are so",
) # full batch size = 8
inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=True)["input_ids"].to(device)

schedule = ScheduleGPipe(stage, num_stages)

Run

if rank == 0:
schedule.step(inputs)
elif rank == world_size - 1:

output = schedule.step()
if output is not None:
    next_token_logits = output[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    print(tokenizer.batch_decode(next_token))

else:
schedule.step()

16d6183f8d29eb1acc7ba798dc1833d

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions