saxml/tools/convert_llama_ckpt.py casts weights to float16, losing precision #28

@houeland

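The weights for models such as Meta-Llama-3.1-70B-Instruct are distributed in bfloat16. When converting them, saxml/tools/convert_llama_ckpt.py first casts the weights to float16, which is lossy: float16 has only 5 exponent bits against bfloat16's 8, so small-magnitude weights fall into float16's subnormal range and get rounded.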

E.g. for Meta-Llama-3.1-70B-Instruct:

>>> import torch
>>> example = torch.load('consolidated.01.pth', weights_only=True, map_location=torch.device('cpu'), mmap=True)['layers.79.feed_forward.w1.weight'][100][5685]
>>> example
tensor(-4.2617e-06, dtype=torch.bfloat16)
>>> example.type(torch.float16)
tensor(-4.2915e-06, dtype=torch.float16)

(This sounds similar to an issue HuggingFace had with weight conversion: huggingface/transformers#25446, which was acknowledged to degrade performance and was fixed.)
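
The rounding shown in the session above can be reproduced without downloading the checkpoint. What follows is a minimal sketch using only the Python standard library, not the saxml script itself. It emulates bfloat16 as the upper 16 bits of a float32 and relies on the `e` (float16) format of `struct`. The helper names `to_bfloat16`, `through_float16` and `through_float32` are made up for illustration.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round half to even).

    bfloat16 is the upper 16 bits of an IEEE 754 float32, so we round the
    float32 bit pattern and clear the lower 16 bits. NaN and inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even on the dropped bits
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def through_float16(x: float) -> float:
    """Cast x to IEEE 754 float16 and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def through_float32(x: float) -> float:
    """Cast x to IEEE 754 float32 and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


weight = to_bfloat16(-4.2617e-06)  # the bfloat16 value printed in the issue
as_f16 = through_float16(weight)
as_f32 = through_float32(weight)

print(f"bfloat16: {weight:.4e}")   # -4.2617e-06
print(f"float16:  {as_f16:.4e}")   # -4.2915e-06
print(f"float32 round trip lossless: {as_f32 == weight}")  # True
print(f"float16 round trip lossless: {as_f16 == weight}")  # False
```

The weight has magnitude of about 2^-18, which is below float16's smallest normal value of 2^-14. float16 therefore stores it as a subnormal with a spacing of 2^-24, and the exact value of 71.5 steps is rounded to 72. float32 keeps bfloat16's 8-bit exponent and represents every bfloat16 value exactly, so casting through float32 instead would avoid the loss. That is a suggestion under these assumptions, not necessarily the fix saxml adopted.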
