Skip to content

Commit ea8ac4c

Browse files
feat: add inplace file deletion capability
Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent 9328845 commit ea8ac4c

1 file changed

Lines changed: 193 additions & 69 deletions

File tree

scripts/checkpoint_utils.py

Lines changed: 193 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
#!/usr/bin/env python3
2-
# Convert FP32 -> BF16 for .pt/.pth, single .safetensors, or a dir of .safetensors.
3-
# Skips optimizer states to preserve FP32.
2+
# Checkpoint utilities (unified --inplace):
3+
# - Default: copy INPUT -> OUTPUT unchanged
4+
# - --convert-model-to-bf16: convert model FP32 -> BF16 (optimizer tensors remain FP32)
5+
# - --no-optimizer: when writing outputs, drop optimizer files/dirs (defaults + --drop-files)
6+
# - --drop-files: comma-separated extra file/dir (used with --no-optimizer, and with --inplace)
7+
# - --inplace: perform conversion and/or dropping directly in INPUT (destructive)
8+
49

510
# Standard
611
from pathlib import Path
712
from typing import Any, Iterable, Set
813
import argparse
14+
import os
915
import shutil
1016

1117
# Third Party
@@ -21,8 +27,16 @@
2127

2228
OPTIM_ROOT_KEYS = {"optimizer", "optim", "opt_state"}
2329

30+
DEFAULT_OPTIM_DROPS = {"optimizer.pt", "optimizer", "optimizer_0", "optimizer_1"}
31+
32+
33+
def _atomic_replace(tmp: Path, dst: Path) -> None:
34+
dst.parent.mkdir(parents=True, exist_ok=True)
35+
os.replace(str(tmp), str(dst)) # atomic on POSIX
36+
2437

2538
def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
39+
"""Recursively cast float32 tensors to bfloat16, skipping optimizer subtrees."""
2640
if isinstance(x, torch.Tensor):
2741
return x if in_optim or x.dtype != torch.float32 else x.to(torch.bfloat16)
2842
if isinstance(x, dict):
@@ -39,6 +53,11 @@ def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
3953
return x
4054

4155

56+
def is_optim_tensor_name(name: str) -> bool:
57+
first = (name or "").lower().replace("/", ".").split(".")[0]
58+
return any(first.startswith(root) for root in OPTIM_ROOT_KEYS)
59+
60+
4261
def convert_pt_pth(inp: Path, out: Path) -> None:
4362
data = torch.load(inp, map_location="cpu")
4463
data = cast_fp32_to_bf16(data)
@@ -47,9 +66,11 @@ def convert_pt_pth(inp: Path, out: Path) -> None:
4766
print(f"[pt/pth] wrote: {out}")
4867

4968

50-
def is_optim_tensor_name(name: str) -> bool:
51-
first = (name or "").lower().replace("/", ".").split(".")[0]
52-
return any(first.startswith(root) for root in OPTIM_ROOT_KEYS)
69+
def convert_pt_pth_inplace(inp: Path) -> None:
70+
tmp = inp.with_suffix(inp.suffix + ".tmp")
71+
convert_pt_pth(inp, tmp)
72+
_atomic_replace(tmp, inp)
73+
print(f"[pt/pth][inplace] updated: {inp}")
5374

5475

5576
def convert_safetensors_file(inp: Path, out: Path) -> None:
@@ -67,119 +88,222 @@ def convert_safetensors_file(inp: Path, out: Path) -> None:
6788
print(f"[safetensors] wrote: {out}")
6889

6990

70-
def slim_copy_dir_skip_only(src: Path, dst: Path, skip_names: Iterable[str]) -> None:
71-
"""
72-
Copy everything from src -> dst EXCEPT files whose names are in skip_names.
73-
Directories are copied entirely unless their name is in skip_names.
74-
"""
91+
def convert_safetensors_file_inplace(inp: Path) -> None:
92+
tmp = inp.with_suffix(inp.suffix + ".tmp")
93+
convert_safetensors_file(inp, tmp)
94+
_atomic_replace(tmp, inp)
95+
print(f"[safetensors][inplace] updated: {inp}")
96+
97+
98+
def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
99+
"""Convert all .safetensors in a directory; copy other files as-is."""
75100
dst.mkdir(parents=True, exist_ok=True)
76-
skip_set: Set[str] = set(skip_names)
101+
for item in src.iterdir():
102+
if item.suffix == ".safetensors":
103+
convert_safetensors_file(item, dst / item.name)
104+
else:
105+
target = dst / item.name
106+
if item.is_file():
107+
shutil.copy2(item, target)
108+
elif item.is_dir():
109+
shutil.copytree(item, target, dirs_exist_ok=True)
110+
print(f"[dir] wrote: {dst}")
111+
77112

113+
def convert_dir_of_safetensors_inplace(src: Path) -> None:
114+
"""Convert all .safetensors files in-place within `src`."""
115+
count = 0
78116
for item in src.iterdir():
79-
if item.name in skip_set:
117+
if item.suffix == ".safetensors":
118+
convert_safetensors_file_inplace(item)
119+
count += 1
120+
if count == 0:
121+
raise SystemExit("Directory has no .safetensors files.")
122+
print(f"[dir][inplace] converted {count} shard(s) in: {src}")
123+
124+
125+
def _name_matches(name: str, patterns: Set[str]) -> bool:
126+
"""Exact-name match (simple and predictable)."""
127+
return name in patterns
128+
129+
130+
def copy_dir_drop(src: Path, dst: Path, drop_names: Iterable[str]) -> None:
131+
"""Copy directory but drop certain files/dirs by exact name."""
132+
dst.mkdir(parents=True, exist_ok=True)
133+
drop_set: Set[str] = set(drop_names)
134+
for item in src.iterdir():
135+
if _name_matches(item.name, drop_set):
80136
continue
81137
target = dst / item.name
82138
if item.is_file():
83139
shutil.copy2(item, target)
84140
elif item.is_dir():
85141
shutil.copytree(item, target, dirs_exist_ok=True)
86142
print(
87-
f"[slim] wrote: {dst} (skipped: {', '.join(skip_set) if skip_set else 'none'})"
143+
f"[copy-drop] wrote: {dst} (dropped: {', '.join(sorted(drop_set)) if drop_set else 'none'})"
88144
)
89145

90146

91-
def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
92-
"""Convert all .safetensors in a directory; copy other files as-is."""
93-
dst.mkdir(parents=True, exist_ok=True)
147+
def prune_dir_inplace(src: Path, drop_names: Iterable[str]) -> None:
148+
"""Delete top-level files/dirs in `src` whose names match `drop_names`. Destructive."""
149+
drop_set: Set[str] = set(drop_names)
150+
removed = []
94151
for item in src.iterdir():
95-
if item.suffix == ".safetensors":
96-
convert_safetensors_file(item, dst / item.name)
97-
else:
152+
if _name_matches(item.name, drop_set):
153+
if item.is_file():
154+
item.unlink()
155+
elif item.is_dir():
156+
shutil.rmtree(item)
157+
removed.append(item.name)
158+
print(
159+
f"[inplace-drop] removed: {', '.join(sorted(removed)) if removed else 'nothing'}"
160+
)
161+
162+
163+
def copy_any(src: Path, dst: Path) -> None:
164+
"""Pure copy (no dtype changes, no dropping)."""
165+
if src.is_file():
166+
dst.parent.mkdir(parents=True, exist_ok=True)
167+
shutil.copy2(src, dst if dst.suffix else dst / src.name)
168+
elif src.is_dir():
169+
dst.mkdir(parents=True, exist_ok=True)
170+
for item in src.iterdir():
98171
target = dst / item.name
99172
if item.is_file():
100173
shutil.copy2(item, target)
101174
elif item.is_dir():
102175
shutil.copytree(item, target, dirs_exist_ok=True)
103-
print(f"[dir] wrote: {dst}")
176+
else:
177+
raise SystemExit(f"Not found: {src}")
178+
print(f"[copy] wrote: {dst}")
104179

105180

106181
def main():
107182
ap = argparse.ArgumentParser(
108-
description="Convert FP32 tensors to BF16 (skips optimizer states)."
109-
)
110-
ap.add_argument(
111-
"input",
112-
type=Path,
113-
help="Input: .pt/.pth, .safetensors, or HF directory with .safetensors",
183+
description="Checkpoint utilities: copy by default; \
184+
optionally convert FP32->BF16 and/or drop optimizer files. "
185+
"Use --inplace to modify INPUT directly."
114186
)
187+
ap.add_argument("input", type=Path, help="Input file or directory")
115188
ap.add_argument("output", type=Path, help="Output file or directory")
116189

117190
ap.add_argument(
118-
"--slim",
191+
"--convert-model-to-bf16",
119192
action="store_true",
120-
help="For directory inputs: after conversion, copy everything \
121-
except files listed in --skip (default: optimizer.pt).",
193+
help="Convert FP32 -> BF16 for model tensors; optimizer tensors remain FP32.",
122194
)
123195
ap.add_argument(
124-
"--slim-only",
196+
"--no-optimizer",
125197
action="store_true",
126-
help="For directory inputs: DO NOT convert; just copy everything \
127-
except files in --skip.",
198+
help="When writing outputs, drop optimizer files/dirs (defaults + --drop-files).",
199+
)
200+
ap.add_argument(
201+
"--drop-files",
202+
default="",
203+
help="Comma-separated extra file/dir names to drop \
204+
(works with --no-optimizer and/or --inplace).",
128205
)
129206
ap.add_argument(
130-
"--skip",
131-
default="optimizer.pt",
132-
help="Comma-separated file names to skip during slimming \
133-
(applies to --slim or --slim-only). Default: optimizer.pt",
207+
"--inplace",
208+
action="store_true",
209+
help="Perform operations directly on INPUT (destructive). For files: overwrite in place; "
210+
"for directories: convert shards in-place and/or delete dropped names.",
134211
)
135212

136213
args = ap.parse_args()
137214

138-
if args.slim and args.slim_only:
139-
raise SystemExit("Choose at most one of: --slim or --slim-only.")
215+
p = args.input
140216

141-
skip_list = (
142-
[s.strip() for s in args.skip.split(",")]
143-
if (args.slim or args.slim_only)
144-
else []
145-
)
217+
user_drops = {s.strip() for s in args.drop_files.split(",") if s.strip()}
218+
if args.no_optimizer:
219+
drop_set = DEFAULT_OPTIM_DROPS | user_drops
220+
else:
221+
drop_set = user_drops
222+
223+
if args.inplace:
224+
if not p.exists():
225+
raise SystemExit(f"Not found: {p}")
226+
227+
if args.convert_model_to_bf16:
228+
if p.is_file():
229+
sfx = p.suffix.lower()
230+
if sfx in {".pt", ".pth"}:
231+
convert_pt_pth_inplace(p)
232+
elif sfx == ".safetensors":
233+
convert_safetensors_file_inplace(p)
234+
else:
235+
raise SystemExit(
236+
f"Unsupported file type for inplace conversion: {p}"
237+
)
238+
elif p.is_dir():
239+
convert_dir_of_safetensors_inplace(p)
240+
else:
241+
raise SystemExit(f"Not found: {p}")
242+
243+
if drop_set:
244+
if not p.is_dir():
245+
print(
246+
"[inplace] --drop-files applies to directories; skipping for file input."
247+
)
248+
else:
249+
prune_dir_inplace(p, drop_set)
250+
251+
print("Done.")
252+
return
253+
254+
if not args.convert_model_to_bf16 and not args.no_optimizer and not drop_set:
255+
copy_any(p, args.output)
256+
print("Done.")
257+
return
146258

147-
p = args.input
148259
if p.is_file():
149260
sfx = p.suffix.lower()
150-
if sfx in {".pt", ".pth"}:
151-
if args.output.is_dir():
152-
raise SystemExit(
153-
"For .pt/.pth input, output must be a file path (not a directory)."
261+
if args.convert_model_to_bf16:
262+
if sfx in {".pt", ".pth"}:
263+
convert_pt_pth(p, args.output)
264+
elif sfx == ".safetensors":
265+
out = (
266+
args.output
267+
if args.output.suffix == ".safetensors"
268+
else (args.output / p.name)
154269
)
155-
convert_pt_pth(p, args.output)
156-
elif sfx == ".safetensors":
157-
out = (
158-
args.output
159-
if args.output.suffix == ".safetensors"
160-
else (args.output / p.name)
161-
)
162-
convert_safetensors_file(p, out)
163-
else:
164-
raise SystemExit(f"Unsupported file type: {p}")
165-
elif p.is_dir():
166-
if not any(x.suffix == ".safetensors" for x in p.iterdir()):
167-
raise SystemExit("Directory has no .safetensors files.")
168-
if args.slim_only:
169-
slim_copy_dir_skip_only(p, args.output, skip_list)
270+
convert_safetensors_file(p, out)
271+
else:
272+
raise SystemExit(f"Unsupported file type: {p}")
170273
else:
274+
copy_any(p, args.output)
275+
print("Done.")
276+
return
277+
278+
if p.is_dir():
279+
if args.convert_model_to_bf16:
280+
if not any(x.suffix == ".safetensors" for x in p.iterdir()):
281+
raise SystemExit("Directory has no .safetensors files.")
171282
convert_dir_of_safetensors(p, args.output)
172-
if args.slim:
173-
tmp = args.output.parent / (args.output.name + "_tmp_slim")
283+
if args.no_optimizer or drop_set:
284+
tmp = args.output.parent / (args.output.name + "_tmp_drop")
174285
if tmp.exists():
175286
shutil.rmtree(tmp)
176-
slim_copy_dir_skip_only(args.output, tmp, skip_list)
287+
copy_dir_drop(
288+
args.output,
289+
tmp,
290+
DEFAULT_OPTIM_DROPS | drop_set if args.no_optimizer else drop_set,
291+
)
177292
shutil.rmtree(args.output)
178293
tmp.rename(args.output)
179-
else:
180-
raise SystemExit(f"Not found: {p}")
294+
else:
295+
if args.no_optimizer or drop_set:
296+
copy_dir_drop(
297+
p,
298+
args.output,
299+
DEFAULT_OPTIM_DROPS | drop_set if args.no_optimizer else drop_set,
300+
)
301+
else:
302+
copy_any(p, args.output)
303+
print("Done.")
304+
return
181305

182-
print("Done.")
306+
raise SystemExit(f"Not found: {p}")
183307

184308

185309
if __name__ == "__main__":

0 commit comments

Comments
 (0)