Skip to content

Commit c2c2506

Browse files
feat: add inplace file deletion capability
Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent 9328845 commit c2c2506

1 file changed

Lines changed: 120 additions & 64 deletions

File tree

scripts/checkpoint_utils.py

Lines changed: 120 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
#!/usr/bin/env python3
2-
# Convert FP32 -> BF16 for .pt/.pth, single .safetensors, or a dir of .safetensors.
3-
# Skips optimizer states to preserve FP32.
2+
# Checkpoint utilities:
3+
# - Default: copy INPUT -> OUTPUT unchanged
4+
# - --convert-model-to-bf16: convert model FP32 -> BF16 (optimizer tensors remain FP32)
5+
# - --no-optimizer: copy INPUT -> OUTPUT, dropping optimizer/extra files
6+
# - --prune-inplace: delete optimizer/extra files inside INPUT (destructive)
7+
# - --drop-files: comma-separated names to drop (for --no-optimizer and --prune-inplace)
8+
#
49

510
# Standard
611
from pathlib import Path
@@ -21,6 +26,8 @@
2126

2227
OPTIM_ROOT_KEYS = {"optimizer", "optim", "opt_state"}
2328

29+
DEFAULT_OPTIM_DROPS = {"optimizer.pt", "optimizer", "optimizer_0", "optimizer_1"}
30+
2431

2532
def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
2633
if isinstance(x, torch.Tensor):
@@ -67,117 +74,166 @@ def convert_safetensors_file(inp: Path, out: Path) -> None:
6774
print(f"[safetensors] wrote: {out}")
6875

6976

70-
def slim_copy_dir_skip_only(src: Path, dst: Path, skip_names: Iterable[str]) -> None:
77+
def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
78+
"""Convert all .safetensors in a directory; copy other files as-is."""
79+
dst.mkdir(parents=True, exist_ok=True)
80+
for item in src.iterdir():
81+
if item.suffix == ".safetensors":
82+
convert_safetensors_file(item, dst / item.name)
83+
else:
84+
target = dst / item.name
85+
if item.is_file():
86+
shutil.copy2(item, target)
87+
elif item.is_dir():
88+
shutil.copytree(item, target, dirs_exist_ok=True)
89+
print(f"[dir] wrote: {dst}")
90+
91+
92+
def _name_matches(name: str, patterns: Set[str]) -> bool:
7193
"""
72-
Copy everything from src -> dst EXCEPT files whose names are in skip_names.
73-
Directories are copied entirely unless their name is in skip_names.
94+
Exact-name match. (Simple and predictable.)
95+
If you later want prefix matching, allow 'pat*' and check startswith(pat[:-1]).
7496
"""
75-
dst.mkdir(parents=True, exist_ok=True)
76-
skip_set: Set[str] = set(skip_names)
97+
return name in patterns
7798

99+
100+
def copy_dir_drop(src: Path, dst: Path, drop_names: Iterable[str]) -> None:
101+
"""Copy directory but drop certain files/dirs by exact name."""
102+
dst.mkdir(parents=True, exist_ok=True)
103+
drop_set: Set[str] = set(drop_names)
78104
for item in src.iterdir():
79-
if item.name in skip_set:
105+
if _name_matches(item.name, drop_set):
80106
continue
81107
target = dst / item.name
82108
if item.is_file():
83109
shutil.copy2(item, target)
84110
elif item.is_dir():
85111
shutil.copytree(item, target, dirs_exist_ok=True)
86112
print(
87-
f"[slim] wrote: {dst} (skipped: {', '.join(skip_set) if skip_set else 'none'})"
113+
f"[copy-drop] wrote: {dst} (dropped: {', '.join(sorted(drop_set)) if drop_set else 'none'})"
88114
)
89115

90116

91-
def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
92-
"""Convert all .safetensors in a directory; copy other files as-is."""
93-
dst.mkdir(parents=True, exist_ok=True)
117+
def prune_dir_inplace(src: Path, drop_names: Iterable[str]) -> None:
118+
"""Delete top-level files/dirs in `src` whose names match `drop_names`. Destructive."""
119+
drop_set: Set[str] = set(drop_names)
120+
removed = []
94121
for item in src.iterdir():
95-
if item.suffix == ".safetensors":
96-
convert_safetensors_file(item, dst / item.name)
97-
else:
122+
if _name_matches(item.name, drop_set):
123+
if item.is_file():
124+
item.unlink()
125+
elif item.is_dir():
126+
shutil.rmtree(item)
127+
removed.append(item.name)
128+
print(
129+
f"[prune-inplace] removed: {', '.join(sorted(removed)) if removed else 'nothing'}"
130+
)
131+
132+
133+
def copy_any(src: Path, dst: Path) -> None:
134+
"""Pure copy (no dtype changes, no dropping)."""
135+
if src.is_file():
136+
dst.parent.mkdir(parents=True, exist_ok=True)
137+
shutil.copy2(src, dst if dst.suffix else dst / src.name)
138+
elif src.is_dir():
139+
dst.mkdir(parents=True, exist_ok=True)
140+
for item in src.iterdir():
98141
target = dst / item.name
99142
if item.is_file():
100143
shutil.copy2(item, target)
101144
elif item.is_dir():
102145
shutil.copytree(item, target, dirs_exist_ok=True)
103-
print(f"[dir] wrote: {dst}")
146+
else:
147+
raise SystemExit(f"Not found: {src}")
148+
print(f"[copy] wrote: {dst}")
104149

105150

106151
def main():
107152
ap = argparse.ArgumentParser(
108-
description="Convert FP32 tensors to BF16 (skips optimizer states)."
109-
)
110-
ap.add_argument(
111-
"input",
112-
type=Path,
113-
help="Input: .pt/.pth, .safetensors, or HF directory with .safetensors",
153+
description="Checkpoint utilities: copy by default; optionally convert FP32->BF16\
154+
and/or drop optimizer files."
114155
)
156+
ap.add_argument("input", type=Path, help="Input file or directory")
115157
ap.add_argument("output", type=Path, help="Output file or directory")
116158

117159
ap.add_argument(
118-
"--slim",
160+
"--convert-model-to-bf16",
161+
action="store_true",
162+
help="Convert FP32 -> BF16 for model tensors; optimizer tensors remain FP32.",
163+
)
164+
ap.add_argument(
165+
"--no-optimizer",
119166
action="store_true",
120-
help="For directory inputs: after conversion, copy everything \
121-
except files listed in --skip (default: optimizer.pt).",
167+
help="When writing directory outputs, drop optimizer files/dirs \
168+
(uses defaults + --drop-files).",
122169
)
123170
ap.add_argument(
124-
"--slim-only",
171+
"--prune-inplace",
125172
action="store_true",
126-
help="For directory inputs: DO NOT convert; just copy everything \
127-
except files in --skip.",
173+
help="DELETE matching files/dirs inside INPUT (destructive). Uses defaults + --drop-files.",
128174
)
129175
ap.add_argument(
130-
"--skip",
131-
default="optimizer.pt",
132-
help="Comma-separated file names to skip during slimming \
133-
(applies to --slim or --slim-only). Default: optimizer.pt",
176+
"--drop-files",
177+
default="",
178+
help="Comma-separated file/dir names to drop (for --no-optimizer and --prune-inplace). "
179+
"Merged with defaults: optimizer.pt, optimizer, optimizer_0, optimizer_1",
134180
)
135181

136182
args = ap.parse_args()
137183

138-
if args.slim and args.slim_only:
139-
raise SystemExit("Choose at most one of: --slim or --slim-only.")
184+
if args.prune_inplace and (args.convert_model_to_bf16 or args.no_optimizer):
185+
raise SystemExit(
186+
"--prune-inplace cannot be combined with --convert-model-to-bf16 or --no-optimizer."
187+
)
140188

141-
skip_list = (
142-
[s.strip() for s in args.skip.split(",")]
143-
if (args.slim or args.slim_only)
144-
else []
189+
user_drops = {s.strip() for s in args.drop_files.split(",") if s.strip()}
190+
drops = (
191+
DEFAULT_OPTIM_DROPS | user_drops
192+
if (args.no_optimizer or args.prune_inplace)
193+
else set()
145194
)
146195

147196
p = args.input
148-
if p.is_file():
149-
sfx = p.suffix.lower()
150-
if sfx in {".pt", ".pth"}:
151-
if args.output.is_dir():
152-
raise SystemExit(
153-
"For .pt/.pth input, output must be a file path (not a directory)."
197+
198+
if args.prune_inplace:
199+
if not p.is_dir():
200+
raise SystemExit("--prune-inplace requires a directory INPUT.")
201+
prune_dir_inplace(p, drops)
202+
print("Done.")
203+
return
204+
205+
if args.convert_model_to_bf16:
206+
if p.is_file():
207+
sfx = p.suffix.lower()
208+
if sfx in {".pt", ".pth"}:
209+
convert_pt_pth(p, args.output)
210+
elif sfx == ".safetensors":
211+
out = (
212+
args.output
213+
if args.output.suffix == ".safetensors"
214+
else (args.output / p.name)
154215
)
155-
convert_pt_pth(p, args.output)
156-
elif sfx == ".safetensors":
157-
out = (
158-
args.output
159-
if args.output.suffix == ".safetensors"
160-
else (args.output / p.name)
161-
)
162-
convert_safetensors_file(p, out)
163-
else:
164-
raise SystemExit(f"Unsupported file type: {p}")
165-
elif p.is_dir():
166-
if not any(x.suffix == ".safetensors" for x in p.iterdir()):
167-
raise SystemExit("Directory has no .safetensors files.")
168-
if args.slim_only:
169-
slim_copy_dir_skip_only(p, args.output, skip_list)
170-
else:
216+
convert_safetensors_file(p, out)
217+
else:
218+
raise SystemExit(f"Unsupported file type: {p}")
219+
elif p.is_dir():
220+
if not any(x.suffix == ".safetensors" for x in p.iterdir()):
221+
raise SystemExit("Directory has no .safetensors files.")
171222
convert_dir_of_safetensors(p, args.output)
172-
if args.slim:
173-
tmp = args.output.parent / (args.output.name + "_tmp_slim")
223+
if args.no_optimizer:
224+
tmp = args.output.parent / (args.output.name + "_tmp_drop")
174225
if tmp.exists():
175226
shutil.rmtree(tmp)
176-
slim_copy_dir_skip_only(args.output, tmp, skip_list)
227+
copy_dir_drop(args.output, tmp, drops)
177228
shutil.rmtree(args.output)
178229
tmp.rename(args.output)
230+
else:
231+
raise SystemExit(f"Not found: {p}")
179232
else:
180-
raise SystemExit(f"Not found: {p}")
233+
if p.is_dir() and args.no_optimizer:
234+
copy_dir_drop(p, args.output, drops)
235+
else:
236+
copy_any(p, args.output)
181237

182238
print("Done.")
183239

0 commit comments

Comments
 (0)