|
1 | 1 | #!/usr/bin/env python3 |
2 | | -# Convert FP32 -> BF16 for .pt/.pth, single .safetensors, or a dir of .safetensors. |
3 | | -# Skips optimizer states to preserve FP32. |
| 2 | +# Checkpoint utilities: |
| 3 | +# - Default: copy INPUT -> OUTPUT unchanged |
| 4 | +# - --convert-model-to-bf16: convert model FP32 -> BF16 (optimizer tensors remain FP32) |
| 5 | +# - --no-optimizer: copy INPUT -> OUTPUT, dropping optimizer/extra files |
| 6 | +# - --prune-inplace: delete optimizer/extra files inside INPUT (destructive) |
| 7 | +# - --drop-files: comma-separated names to drop (for --no-optimizer and --prune-inplace) |
| 8 | +# |
4 | 9 |
|
5 | 10 | # Standard |
6 | 11 | from pathlib import Path |
|
21 | 26 |
|
22 | 27 | OPTIM_ROOT_KEYS = {"optimizer", "optim", "opt_state"} |
23 | 28 |
|
| 29 | +DEFAULT_OPTIM_DROPS = {"optimizer.pt", "optimizer", "optimizer_0", "optimizer_1"} |
| 30 | + |
24 | 31 |
|
25 | 32 | def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any: |
26 | 33 | if isinstance(x, torch.Tensor): |
@@ -67,117 +74,166 @@ def convert_safetensors_file(inp: Path, out: Path) -> None: |
67 | 74 | print(f"[safetensors] wrote: {out}") |
68 | 75 |
|
69 | 76 |
|
70 | | -def slim_copy_dir_skip_only(src: Path, dst: Path, skip_names: Iterable[str]) -> None: |
| 77 | +def convert_dir_of_safetensors(src: Path, dst: Path) -> None: |
| 78 | + """Convert all .safetensors in a directory; copy other files as-is.""" |
| 79 | + dst.mkdir(parents=True, exist_ok=True) |
| 80 | + for item in src.iterdir(): |
| 81 | + if item.suffix == ".safetensors": |
| 82 | + convert_safetensors_file(item, dst / item.name) |
| 83 | + else: |
| 84 | + target = dst / item.name |
| 85 | + if item.is_file(): |
| 86 | + shutil.copy2(item, target) |
| 87 | + elif item.is_dir(): |
| 88 | + shutil.copytree(item, target, dirs_exist_ok=True) |
| 89 | + print(f"[dir] wrote: {dst}") |
| 90 | + |
| 91 | + |
| 92 | +def _name_matches(name: str, patterns: Set[str]) -> bool: |
71 | 93 | """ |
72 | | - Copy everything from src -> dst EXCEPT files whose names are in skip_names. |
73 | | - Directories are copied entirely unless their name is in skip_names. |
| 94 | + Exact-name match. (Simple and predictable.) |
| 95 | + If you later want prefix matching, allow 'pat*' and check startswith(pat[:-1]). |
74 | 96 | """ |
75 | | - dst.mkdir(parents=True, exist_ok=True) |
76 | | - skip_set: Set[str] = set(skip_names) |
| 97 | + return name in patterns |
77 | 98 |
|
| 99 | + |
| 100 | +def copy_dir_drop(src: Path, dst: Path, drop_names: Iterable[str]) -> None: |
| 101 | + """Copy directory but drop certain files/dirs by exact name.""" |
| 102 | + dst.mkdir(parents=True, exist_ok=True) |
| 103 | + drop_set: Set[str] = set(drop_names) |
78 | 104 | for item in src.iterdir(): |
79 | | - if item.name in skip_set: |
| 105 | + if _name_matches(item.name, drop_set): |
80 | 106 | continue |
81 | 107 | target = dst / item.name |
82 | 108 | if item.is_file(): |
83 | 109 | shutil.copy2(item, target) |
84 | 110 | elif item.is_dir(): |
85 | 111 | shutil.copytree(item, target, dirs_exist_ok=True) |
86 | 112 | print( |
87 | | - f"[slim] wrote: {dst} (skipped: {', '.join(skip_set) if skip_set else 'none'})" |
| 113 | + f"[copy-drop] wrote: {dst} (dropped: {', '.join(sorted(drop_set)) if drop_set else 'none'})" |
88 | 114 | ) |
89 | 115 |
|
90 | 116 |
|
91 | | -def convert_dir_of_safetensors(src: Path, dst: Path) -> None: |
92 | | - """Convert all .safetensors in a directory; copy other files as-is.""" |
93 | | - dst.mkdir(parents=True, exist_ok=True) |
| 117 | +def prune_dir_inplace(src: Path, drop_names: Iterable[str]) -> None: |
| 118 | + """Delete top-level files/dirs in `src` whose names match `drop_names`. Destructive.""" |
| 119 | + drop_set: Set[str] = set(drop_names) |
| 120 | + removed = [] |
94 | 121 | for item in src.iterdir(): |
95 | | - if item.suffix == ".safetensors": |
96 | | - convert_safetensors_file(item, dst / item.name) |
97 | | - else: |
| 122 | + if _name_matches(item.name, drop_set): |
| 123 | + if item.is_file(): |
| 124 | + item.unlink() |
| 125 | + elif item.is_dir(): |
| 126 | + shutil.rmtree(item) |
| 127 | + removed.append(item.name) |
| 128 | + print( |
| 129 | + f"[prune-inplace] removed: {', '.join(sorted(removed)) if removed else 'nothing'}" |
| 130 | + ) |
| 131 | + |
| 132 | + |
| 133 | +def copy_any(src: Path, dst: Path) -> None: |
| 134 | + """Pure copy (no dtype changes, no dropping).""" |
| 135 | + if src.is_file(): |
| 136 | + dst.parent.mkdir(parents=True, exist_ok=True) |
| 137 | + shutil.copy2(src, dst if dst.suffix else dst / src.name) |
| 138 | + elif src.is_dir(): |
| 139 | + dst.mkdir(parents=True, exist_ok=True) |
| 140 | + for item in src.iterdir(): |
98 | 141 | target = dst / item.name |
99 | 142 | if item.is_file(): |
100 | 143 | shutil.copy2(item, target) |
101 | 144 | elif item.is_dir(): |
102 | 145 | shutil.copytree(item, target, dirs_exist_ok=True) |
103 | | - print(f"[dir] wrote: {dst}") |
| 146 | + else: |
| 147 | + raise SystemExit(f"Not found: {src}") |
| 148 | + print(f"[copy] wrote: {dst}") |
104 | 149 |
|
105 | 150 |
|
106 | 151 | def main(): |
107 | 152 | ap = argparse.ArgumentParser( |
108 | | - description="Convert FP32 tensors to BF16 (skips optimizer states)." |
109 | | - ) |
110 | | - ap.add_argument( |
111 | | - "input", |
112 | | - type=Path, |
113 | | - help="Input: .pt/.pth, .safetensors, or HF directory with .safetensors", |
| 153 | + description="Checkpoint utilities: copy by default; optionally convert FP32->BF16\ |
| 154 | + and/or drop optimizer files." |
114 | 155 | ) |
| 156 | + ap.add_argument("input", type=Path, help="Input file or directory") |
115 | 157 | ap.add_argument("output", type=Path, help="Output file or directory") |
116 | 158 |
|
117 | 159 | ap.add_argument( |
118 | | - "--slim", |
| 160 | + "--convert-model-to-bf16", |
| 161 | + action="store_true", |
| 162 | + help="Convert FP32 -> BF16 for model tensors; optimizer tensors remain FP32.", |
| 163 | + ) |
| 164 | + ap.add_argument( |
| 165 | + "--no-optimizer", |
119 | 166 | action="store_true", |
120 | | - help="For directory inputs: after conversion, copy everything \ |
121 | | - except files listed in --skip (default: optimizer.pt).", |
| 167 | + help="When writing directory outputs, drop optimizer files/dirs \ |
| 168 | + (uses defaults + --drop-files).", |
122 | 169 | ) |
123 | 170 | ap.add_argument( |
124 | | - "--slim-only", |
| 171 | + "--prune-inplace", |
125 | 172 | action="store_true", |
126 | | - help="For directory inputs: DO NOT convert; just copy everything \ |
127 | | - except files in --skip.", |
| 173 | + help="DELETE matching files/dirs inside INPUT (destructive). Uses defaults + --drop-files.", |
128 | 174 | ) |
129 | 175 | ap.add_argument( |
130 | | - "--skip", |
131 | | - default="optimizer.pt", |
132 | | - help="Comma-separated file names to skip during slimming \ |
133 | | - (applies to --slim or --slim-only). Default: optimizer.pt", |
| 176 | + "--drop-files", |
| 177 | + default="", |
| 178 | + help="Comma-separated file/dir names to drop (for --no-optimizer and --prune-inplace). " |
| 179 | + "Merged with defaults: optimizer.pt, optimizer, optimizer_0, optimizer_1", |
134 | 180 | ) |
135 | 181 |
|
136 | 182 | args = ap.parse_args() |
137 | 183 |
|
138 | | - if args.slim and args.slim_only: |
139 | | - raise SystemExit("Choose at most one of: --slim or --slim-only.") |
| 184 | + if args.prune_inplace and (args.convert_model_to_bf16 or args.no_optimizer): |
| 185 | + raise SystemExit( |
| 186 | + "--prune-inplace cannot be combined with --convert-model-to-bf16 or --no-optimizer." |
| 187 | + ) |
140 | 188 |
|
141 | | - skip_list = ( |
142 | | - [s.strip() for s in args.skip.split(",")] |
143 | | - if (args.slim or args.slim_only) |
144 | | - else [] |
| 189 | + user_drops = {s.strip() for s in args.drop_files.split(",") if s.strip()} |
| 190 | + drops = ( |
| 191 | + DEFAULT_OPTIM_DROPS | user_drops |
| 192 | + if (args.no_optimizer or args.prune_inplace) |
| 193 | + else set() |
145 | 194 | ) |
146 | 195 |
|
147 | 196 | p = args.input |
148 | | - if p.is_file(): |
149 | | - sfx = p.suffix.lower() |
150 | | - if sfx in {".pt", ".pth"}: |
151 | | - if args.output.is_dir(): |
152 | | - raise SystemExit( |
153 | | - "For .pt/.pth input, output must be a file path (not a directory)." |
| 197 | + |
| 198 | + if args.prune_inplace: |
| 199 | + if not p.is_dir(): |
| 200 | + raise SystemExit("--prune-inplace requires a directory INPUT.") |
| 201 | + prune_dir_inplace(p, drops) |
| 202 | + print("Done.") |
| 203 | + return |
| 204 | + |
| 205 | + if args.convert_model_to_bf16: |
| 206 | + if p.is_file(): |
| 207 | + sfx = p.suffix.lower() |
| 208 | + if sfx in {".pt", ".pth"}: |
| 209 | + convert_pt_pth(p, args.output) |
| 210 | + elif sfx == ".safetensors": |
| 211 | + out = ( |
| 212 | + args.output |
| 213 | + if args.output.suffix == ".safetensors" |
| 214 | + else (args.output / p.name) |
154 | 215 | ) |
155 | | - convert_pt_pth(p, args.output) |
156 | | - elif sfx == ".safetensors": |
157 | | - out = ( |
158 | | - args.output |
159 | | - if args.output.suffix == ".safetensors" |
160 | | - else (args.output / p.name) |
161 | | - ) |
162 | | - convert_safetensors_file(p, out) |
163 | | - else: |
164 | | - raise SystemExit(f"Unsupported file type: {p}") |
165 | | - elif p.is_dir(): |
166 | | - if not any(x.suffix == ".safetensors" for x in p.iterdir()): |
167 | | - raise SystemExit("Directory has no .safetensors files.") |
168 | | - if args.slim_only: |
169 | | - slim_copy_dir_skip_only(p, args.output, skip_list) |
170 | | - else: |
| 216 | + convert_safetensors_file(p, out) |
| 217 | + else: |
| 218 | + raise SystemExit(f"Unsupported file type: {p}") |
| 219 | + elif p.is_dir(): |
| 220 | + if not any(x.suffix == ".safetensors" for x in p.iterdir()): |
| 221 | + raise SystemExit("Directory has no .safetensors files.") |
171 | 222 | convert_dir_of_safetensors(p, args.output) |
172 | | - if args.slim: |
173 | | - tmp = args.output.parent / (args.output.name + "_tmp_slim") |
| 223 | + if args.no_optimizer: |
| 224 | + tmp = args.output.parent / (args.output.name + "_tmp_drop") |
174 | 225 | if tmp.exists(): |
175 | 226 | shutil.rmtree(tmp) |
176 | | - slim_copy_dir_skip_only(args.output, tmp, skip_list) |
| 227 | + copy_dir_drop(args.output, tmp, drops) |
177 | 228 | shutil.rmtree(args.output) |
178 | 229 | tmp.rename(args.output) |
| 230 | + else: |
| 231 | + raise SystemExit(f"Not found: {p}") |
179 | 232 | else: |
180 | | - raise SystemExit(f"Not found: {p}") |
| 233 | + if p.is_dir() and args.no_optimizer: |
| 234 | + copy_dir_drop(p, args.output, drops) |
| 235 | + else: |
| 236 | + copy_any(p, args.output) |
181 | 237 |
|
182 | 238 | print("Done.") |
183 | 239 |
|
|
0 commit comments