This repository was archived by the owner on Mar 3, 2026. It is now read-only.

Jobs hang with FSDP-only (1D) sharding on >=64 TPU chips #363

@jialei777


Training works locally on 4 chips with fsdp=4: export PJRT_DEVICE=TPU; export TORCHPRIME_TPU_TYPE=v6e-4 && python torchprime/torch_xla_models/train.py model=flex-qwen-1b

MFU: 0.21

On a v5p-128 cluster, I ran the command below, with x and y set to the fsdp and tp (tensor) values in the following list: tp run --name jialei-0812-qwen-fsdp32tensor2 torchprime/torch_xla_models/train.py model=flex-qwen-1b task.global_batch_size=64 ici_mesh.fsdp=x ici_mesh.tensor=y

  • fsdp 64, tp 1: hang >.<
  • fsdp 32, tp 2: finished, MFU 0.22
  • fsdp 16, tp 4: finished, MFU 0.19
  • fsdp 8, tp 8: finished, MFU 0.11
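
For anyone reproducing the sweep, here is a minimal sketch that enumerates the (fsdp, tensor) splits and prints the matching `tp run` command for each. It assumes that fsdp × tensor must equal the 64 chips used above, which is inferred from the four reported runs rather than stated by torchprime. The helper names `mesh_shapes` and `sweep_command` and the `name_prefix` argument are hypothetical; only the flags that appear in the command above are used.

```python
def mesh_shapes(num_chips: int) -> list[tuple[int, int]]:
    """Return every (fsdp, tensor) pair with fsdp * tensor == num_chips.

    Pairs are ordered by increasing tensor size, so the FSDP-only
    (1D) configuration comes first.
    """
    return [
        (num_chips // tensor, tensor)
        for tensor in range(1, num_chips + 1)
        if num_chips % tensor == 0
    ]


def sweep_command(fsdp: int, tensor: int, name_prefix: str = "qwen-sweep") -> str:
    """Build the `tp run` command from this issue for one mesh shape.

    `name_prefix` is a hypothetical run-name prefix, not a torchprime default.
    """
    return (
        f"tp run --name {name_prefix}-fsdp{fsdp}tensor{tensor} "
        "torchprime/torch_xla_models/train.py model=flex-qwen-1b "
        "task.global_batch_size=64 "
        f"ici_mesh.fsdp={fsdp} ici_mesh.tensor={tensor}"
    )


if __name__ == "__main__":
    # Assumption: 64 chips in total, matching the runs reported above.
    for fsdp, tensor in mesh_shapes(64):
        if tensor <= 8:  # only the four shapes tried in this issue
            print(sweep_command(fsdp, tensor))
```

Running the script prints one command per configuration, starting with the hanging fsdp=64, tensor=1 case.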

Also see the cluster log here: https://b.corp.google.com/issues/436664633#comment40.
Runs for the Llama model: https://b.corp.google.com/issues/436664633#comment33
