Skip to content

Commit 092a437

Browse files
feat: add ckpt conversion script fp32-bf16
Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent f41eb2c commit 092a437

1 file changed

Lines changed: 125 additions & 0 deletions

File tree

scripts/convert_fp32_to_bf16.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#!/usr/bin/env python3
2+
# Convert FP32 -> BF16 for .pt/.pth, single .safetensors, or a dir of .safetensors.
3+
# Skips optimizer states to preserve FP32.
4+
5+
# Standard
6+
from pathlib import Path
7+
from typing import Any
8+
import argparse
9+
import shutil
10+
11+
# Third Party
12+
import torch
13+
14+
try:
15+
# Third Party
16+
from safetensors.torch import safe_open, save_file
17+
18+
HAS_SAFETENSORS = True
19+
except ImportError:
20+
HAS_SAFETENSORS = False
21+
22+
# Names that mark optimizer state. A dict key (at any nesting depth) or the
# first component of a tensor name that equals one of these, compared
# lower-cased, causes the tensors beneath it to be left unconverted.
OPTIM_ROOT_KEYS = {"optimizer", "optim", "opt_state"}
23+
24+
25+
def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
    """Recursively cast float32 tensors inside ``x`` to bfloat16.

    Tensors, dicts, lists and tuples (including namedtuples) are walked;
    any other object is returned unchanged. Tensors that are not float32,
    and every tensor beneath a dict key listed in ``OPTIM_ROOT_KEYS``
    (compared lower-cased), are returned as the same object.

    Args:
        x: Object to convert, e.g. a loaded checkpoint.
        in_optim: True when ``x`` lies inside an optimizer subtree, in
            which case tensors are left untouched.

    Returns:
        The converted structure. Dicts come back as plain ``dict``
        instances; lists and tuples keep their concrete type.
    """
    if isinstance(x, torch.Tensor):
        if in_optim or x.dtype != torch.float32:
            return x
        return x.to(torch.bfloat16)
    if isinstance(x, dict):
        out = {}
        for k, v in x.items():
            # Non-string keys can never match an optimizer key name.
            k_lower = k.lower() if isinstance(k, str) else ""
            child_in_optim = in_optim or (k_lower in OPTIM_ROOT_KEYS)
            out[k] = cast_fp32_to_bf16(v, in_optim=child_in_optim)
        return out
    if isinstance(x, (list, tuple)):
        items = [cast_fp32_to_bf16(v, in_optim=in_optim) for v in x]
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # namedtuple constructors take one positional argument per
            # field, not a single iterable, so the items must be unpacked.
            return type(x)(*items)
        return type(x)(items)
    return x
38+
39+
40+
def convert_pt_pth(inp: Path, out: Path) -> None:
    """Convert a ``.pt``/``.pth`` checkpoint and write the result to ``out``.

    The checkpoint is loaded onto the CPU, passed through
    ``cast_fp32_to_bf16``, and saved with ``torch.save``. The parent
    directory of ``out`` is created if it does not exist.

    Args:
        inp: Path of the checkpoint to read.
        out: Path of the checkpoint file to write.
    """
    # NOTE(review): torch.load unpickles the file, so this should only be
    # run on checkpoints from a trusted source.
    state = cast_fp32_to_bf16(torch.load(inp, map_location="cpu"))
    out.parent.mkdir(parents=True, exist_ok=True)
    torch.save(state, out)
    print(f"[pt/pth] wrote: {out}")
46+
47+
48+
def is_optim_tensor_name(name: str) -> bool:
    """Return True when the first component of ``name`` is an optimizer key.

    The name is lower-cased and ``/`` is treated like ``.`` before the
    leading component is compared against ``OPTIM_ROOT_KEYS``. An empty or
    ``None`` name yields False.
    """
    normalized = (name or "").lower().replace("/", ".")
    # Everything before the first "." (or the whole string if there is none).
    root, _, _ = normalized.partition(".")
    return root in OPTIM_ROOT_KEYS
51+
52+
53+
def convert_safetensors_file(inp: Path, out: Path) -> None:
    """Convert one ``.safetensors`` file, casting float32 tensors to bfloat16.

    Tensors whose name starts with an optimizer key (see
    ``is_optim_tensor_name``) and tensors of any other dtype are written
    unchanged. The source file's header metadata is carried over to the
    output, with ``converted_to`` added and ``format`` defaulted to ``pt``.

    Args:
        inp: Path of the ``.safetensors`` file to read.
        out: Path of the ``.safetensors`` file to write; its parent
            directory is created if missing.

    Raises:
        RuntimeError: If the ``safetensors`` package is not installed.
    """
    if not HAS_SAFETENSORS:
        raise RuntimeError("safetensors not installed. pip install safetensors")
    tensors = {}
    with safe_open(str(inp), framework="pt", device="cpu") as f:
        # metadata() is None when the source header carries no metadata.
        metadata = dict(f.metadata() or {})
        for key in f.keys():
            t = f.get_tensor(key)
            if t.dtype == torch.float32 and not is_optim_tensor_name(key):
                t = t.to(torch.bfloat16)
            tensors[key] = t
    # Keep the existing header entries instead of replacing them: loaders
    # such as Hugging Face Transformers read the "format" entry, so dropping
    # it can make the converted file unloadable.
    metadata.setdefault("format", "pt")
    metadata["converted_to"] = "bfloat16"
    out.parent.mkdir(parents=True, exist_ok=True)
    save_file(tensors, str(out), metadata=metadata)
    print(f"[safetensors] wrote: {out}")
66+
67+
68+
def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
    """Convert all .safetensors in a directory; copy other files as-is.

    Entries directly inside ``src`` whose suffix is ``.safetensors`` are
    converted with ``convert_safetensors_file``; other files are copied with
    ``shutil.copy2`` and sub-directories with ``shutil.copytree``. Entries
    are processed in sorted order.

    Args:
        src: Directory to read from.
        dst: Directory to write to; created if missing. It may be located
            inside ``src``, in which case it is not copied into itself.
    """
    dst.mkdir(parents=True, exist_ok=True)
    dst_resolved = dst.resolve()
    for item in sorted(src.iterdir()):
        if item.resolve() == dst_resolved:
            # The output directory lives inside the input directory; copying
            # it would nest a copy of the output within itself.
            continue
        target = dst / item.name
        if item.suffix == ".safetensors":
            convert_safetensors_file(item, target)
        elif item.is_file():
            shutil.copy2(item, target)
        elif item.is_dir():
            shutil.copytree(item, target, dirs_exist_ok=True)
    print(f"[dir] wrote: {dst}")
81+
82+
83+
def main(argv=None) -> None:
    """Command-line entry point: convert the given input to BF16.

    Dispatches on the input path: a ``.pt``/``.pth`` file goes to
    ``convert_pt_pth``, a ``.safetensors`` file to
    ``convert_safetensors_file``, and a directory containing at least one
    ``.safetensors`` file to ``convert_dir_of_safetensors``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read the process's command line, so existing callers
            of ``main()`` are unaffected.

    Raises:
        SystemExit: If the input does not exist, has an unsupported suffix,
            is a directory without ``.safetensors`` files, or a
            ``.pt``/``.pth`` input is paired with a directory output.
    """
    ap = argparse.ArgumentParser(
        description="Convert FP32 tensors to BF16 (skip optimizer states)."
    )
    ap.add_argument(
        "input",
        type=Path,
        help="Input: .pt/.pth, .safetensors, or HF directory with .safetensors",
    )
    ap.add_argument("output", type=Path, help="Output file or directory")
    args = ap.parse_args(argv)

    p = args.input
    if p.is_file():
        sfx = p.suffix.lower()
        if sfx in {".pt", ".pth"}:
            if args.output.is_dir():
                raise SystemExit(
                    "For .pt/.pth input, output must be a file path (not a directory)."
                )
            convert_pt_pth(p, args.output)
        elif sfx == ".safetensors":
            # An output ending in .safetensors is the target file; anything
            # else is taken as a directory receiving a file named like the
            # input. The suffix is compared case-insensitively, matching the
            # handling of the input suffix above.
            if args.output.suffix.lower() == ".safetensors":
                out = args.output
            else:
                out = args.output / p.name
            convert_safetensors_file(p, out)
        else:
            raise SystemExit(f"Unsupported file type: {p}")
    elif p.is_dir():
        if any(x.suffix == ".safetensors" for x in p.iterdir()):
            convert_dir_of_safetensors(p, args.output)
        else:
            raise SystemExit("Directory has no .safetensors files.")
    else:
        raise SystemExit(f"Not found: {p}")

    print("Done.")
122+
123+
124+
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)