11#!/usr/bin/env python3
2- # Convert FP32 -> BF16 for .pt/.pth, single .safetensors, or a dir of .safetensors.
3- # Skips optimizer states to preserve FP32.
2+ # Checkpoint utilities (unified --inplace):
3+ # - Default: copy INPUT -> OUTPUT unchanged
4+ # - --convert-model-to-bf16: convert model FP32 -> BF16 (optimizer tensors remain FP32)
5+ # - --no-optimizer: when writing outputs, drop optimizer files/dirs (defaults + --drop-files)
6+ # - --drop-files: comma-separated extra file/dir (used with --no-optimizer, and with --inplace)
7+ # - --inplace: perform conversion and/or dropping directly in INPUT (destructive)
8+
49
510# Standard
611from pathlib import Path
712from typing import Any , Iterable , Set
813import argparse
14+ import os
915import shutil
1016
1117# Third Party
2127
2228OPTIM_ROOT_KEYS = {"optimizer" , "optim" , "opt_state" }
2329
30+ DEFAULT_OPTIM_DROPS = {"optimizer.pt" , "optimizer" , "optimizer_0" , "optimizer_1" }
31+
32+
33+ def _atomic_replace (tmp : Path , dst : Path ) -> None :
34+ dst .parent .mkdir (parents = True , exist_ok = True )
35+ os .replace (str (tmp ), str (dst )) # atomic on POSIX
36+
2437
2538def cast_fp32_to_bf16 (x : Any , * , in_optim : bool = False ) -> Any :
39+ """Recursively cast float32 tensors to bfloat16, skipping optimizer subtrees."""
2640 if isinstance (x , torch .Tensor ):
2741 return x if in_optim or x .dtype != torch .float32 else x .to (torch .bfloat16 )
2842 if isinstance (x , dict ):
@@ -39,6 +53,11 @@ def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
3953 return x
4054
4155
56+ def is_optim_tensor_name (name : str ) -> bool :
57+ first = (name or "" ).lower ().replace ("/" , "." ).split ("." )[0 ]
58+ return any (first .startswith (root ) for root in OPTIM_ROOT_KEYS )
59+
60+
4261def convert_pt_pth (inp : Path , out : Path ) -> None :
4362 data = torch .load (inp , map_location = "cpu" )
4463 data = cast_fp32_to_bf16 (data )
@@ -47,9 +66,11 @@ def convert_pt_pth(inp: Path, out: Path) -> None:
4766 print (f"[pt/pth] wrote: { out } " )
4867
4968
50- def is_optim_tensor_name (name : str ) -> bool :
51- first = (name or "" ).lower ().replace ("/" , "." ).split ("." )[0 ]
52- return any (first .startswith (root ) for root in OPTIM_ROOT_KEYS )
69+ def convert_pt_pth_inplace (inp : Path ) -> None :
70+ tmp = inp .with_suffix (inp .suffix + ".tmp" )
71+ convert_pt_pth (inp , tmp )
72+ _atomic_replace (tmp , inp )
73+ print (f"[pt/pth][inplace] updated: { inp } " )
5374
5475
5576def convert_safetensors_file (inp : Path , out : Path ) -> None :
@@ -67,119 +88,222 @@ def convert_safetensors_file(inp: Path, out: Path) -> None:
6788 print (f"[safetensors] wrote: { out } " )
6889
6990
70- def slim_copy_dir_skip_only (src : Path , dst : Path , skip_names : Iterable [str ]) -> None :
71- """
72- Copy everything from src -> dst EXCEPT files whose names are in skip_names.
73- Directories are copied entirely unless their name is in skip_names.
74- """
91+ def convert_safetensors_file_inplace (inp : Path ) -> None :
92+ tmp = inp .with_suffix (inp .suffix + ".tmp" )
93+ convert_safetensors_file (inp , tmp )
94+ _atomic_replace (tmp , inp )
95+ print (f"[safetensors][inplace] updated: { inp } " )
96+
97+
98+ def convert_dir_of_safetensors (src : Path , dst : Path ) -> None :
99+ """Convert all .safetensors in a directory; copy other files as-is."""
75100 dst .mkdir (parents = True , exist_ok = True )
76- skip_set : Set [str ] = set (skip_names )
101+ for item in src .iterdir ():
102+ if item .suffix == ".safetensors" :
103+ convert_safetensors_file (item , dst / item .name )
104+ else :
105+ target = dst / item .name
106+ if item .is_file ():
107+ shutil .copy2 (item , target )
108+ elif item .is_dir ():
109+ shutil .copytree (item , target , dirs_exist_ok = True )
110+ print (f"[dir] wrote: { dst } " )
111+
77112
113+ def convert_dir_of_safetensors_inplace (src : Path ) -> None :
114+ """Convert all .safetensors files in-place within `src`."""
115+ count = 0
78116 for item in src .iterdir ():
79- if item .name in skip_set :
117+ if item .suffix == ".safetensors" :
118+ convert_safetensors_file_inplace (item )
119+ count += 1
120+ if count == 0 :
121+ raise SystemExit ("Directory has no .safetensors files." )
122+ print (f"[dir][inplace] converted { count } shard(s) in: { src } " )
123+
124+
125+ def _name_matches (name : str , patterns : Set [str ]) -> bool :
126+ """Exact-name match (simple and predictable)."""
127+ return name in patterns
128+
129+
130+ def copy_dir_drop (src : Path , dst : Path , drop_names : Iterable [str ]) -> None :
131+ """Copy directory but drop certain files/dirs by exact name."""
132+ dst .mkdir (parents = True , exist_ok = True )
133+ drop_set : Set [str ] = set (drop_names )
134+ for item in src .iterdir ():
135+ if _name_matches (item .name , drop_set ):
80136 continue
81137 target = dst / item .name
82138 if item .is_file ():
83139 shutil .copy2 (item , target )
84140 elif item .is_dir ():
85141 shutil .copytree (item , target , dirs_exist_ok = True )
86142 print (
87- f"[slim ] wrote: { dst } (skipped : { ', ' .join (skip_set ) if skip_set else 'none' } )"
143+ f"[copy-drop ] wrote: { dst } (dropped : { ', ' .join (sorted ( drop_set )) if drop_set else 'none' } )"
88144 )
89145
90146
91- def convert_dir_of_safetensors (src : Path , dst : Path ) -> None :
92- """Convert all .safetensors in a directory; copy other files as-is."""
93- dst .mkdir (parents = True , exist_ok = True )
147+ def prune_dir_inplace (src : Path , drop_names : Iterable [str ]) -> None :
148+ """Delete top-level files/dirs in `src` whose names match `drop_names`. Destructive."""
149+ drop_set : Set [str ] = set (drop_names )
150+ removed = []
94151 for item in src .iterdir ():
95- if item .suffix == ".safetensors" :
96- convert_safetensors_file (item , dst / item .name )
97- else :
152+ if _name_matches (item .name , drop_set ):
153+ if item .is_file ():
154+ item .unlink ()
155+ elif item .is_dir ():
156+ shutil .rmtree (item )
157+ removed .append (item .name )
158+ print (
159+ f"[inplace-drop] removed: { ', ' .join (sorted (removed )) if removed else 'nothing' } "
160+ )
161+
162+
163+ def copy_any (src : Path , dst : Path ) -> None :
164+ """Pure copy (no dtype changes, no dropping)."""
165+ if src .is_file ():
166+ dst .parent .mkdir (parents = True , exist_ok = True )
167+ shutil .copy2 (src , dst if dst .suffix else dst / src .name )
168+ elif src .is_dir ():
169+ dst .mkdir (parents = True , exist_ok = True )
170+ for item in src .iterdir ():
98171 target = dst / item .name
99172 if item .is_file ():
100173 shutil .copy2 (item , target )
101174 elif item .is_dir ():
102175 shutil .copytree (item , target , dirs_exist_ok = True )
103- print (f"[dir] wrote: { dst } " )
176+ else :
177+ raise SystemExit (f"Not found: { src } " )
178+ print (f"[copy] wrote: { dst } " )
104179
105180
106181def main ():
107182 ap = argparse .ArgumentParser (
108- description = "Convert FP32 tensors to BF16 (skips optimizer states)."
109- )
110- ap .add_argument (
111- "input" ,
112- type = Path ,
113- help = "Input: .pt/.pth, .safetensors, or HF directory with .safetensors" ,
183+ description = "Checkpoint utilities: copy by default; \
184+ optionally convert FP32->BF16 and/or drop optimizer files. "
185+ "Use --inplace to modify INPUT directly."
114186 )
187+ ap .add_argument ("input" , type = Path , help = "Input file or directory" )
115188 ap .add_argument ("output" , type = Path , help = "Output file or directory" )
116189
117190 ap .add_argument (
118- "--slim " ,
191+ "--convert-model-to-bf16 " ,
119192 action = "store_true" ,
120- help = "For directory inputs: after conversion, copy everything \
121- except files listed in --skip (default: optimizer.pt)." ,
193+ help = "Convert FP32 -> BF16 for model tensors; optimizer tensors remain FP32." ,
122194 )
123195 ap .add_argument (
124- "--slim-only " ,
196+ "--no-optimizer " ,
125197 action = "store_true" ,
126- help = "For directory inputs: DO NOT convert; just copy everything \
127- except files in --skip." ,
198+ help = "When writing outputs, drop optimizer files/dirs (defaults + --drop-files)." ,
199+ )
200+ ap .add_argument (
201+ "--drop-files" ,
202+ default = "" ,
203+ help = "Comma-separated extra file/dir names to drop \
204+ (works with --no-optimizer and/or --inplace)." ,
128205 )
129206 ap .add_argument (
130- "--skip " ,
131- default = "optimizer.pt " ,
132- help = "Comma-separated file names to skip during slimming \
133- (applies to --slim or --slim-only). Default: optimizer.pt " ,
207+ "--inplace " ,
208+ action = "store_true " ,
209+ help = "Perform operations directly on INPUT (destructive). For files: overwrite in place; "
210+ "for directories: convert shards in-place and/ or delete dropped names. " ,
134211 )
135212
136213 args = ap .parse_args ()
137214
138- if args .slim and args .slim_only :
139- raise SystemExit ("Choose at most one of: --slim or --slim-only." )
215+ p = args .input
140216
141- skip_list = (
142- [s .strip () for s in args .skip .split ("," )]
143- if (args .slim or args .slim_only )
144- else []
145- )
217+ user_drops = {s .strip () for s in args .drop_files .split ("," ) if s .strip ()}
218+ if args .no_optimizer :
219+ drop_set = DEFAULT_OPTIM_DROPS | user_drops
220+ else :
221+ drop_set = user_drops
222+
223+ if args .inplace :
224+ if not p .exists ():
225+ raise SystemExit (f"Not found: { p } " )
226+
227+ if args .convert_model_to_bf16 :
228+ if p .is_file ():
229+ sfx = p .suffix .lower ()
230+ if sfx in {".pt" , ".pth" }:
231+ convert_pt_pth_inplace (p )
232+ elif sfx == ".safetensors" :
233+ convert_safetensors_file_inplace (p )
234+ else :
235+ raise SystemExit (
236+ f"Unsupported file type for inplace conversion: { p } "
237+ )
238+ elif p .is_dir ():
239+ convert_dir_of_safetensors_inplace (p )
240+ else :
241+ raise SystemExit (f"Not found: { p } " )
242+
243+ if drop_set :
244+ if not p .is_dir ():
245+ print (
246+ "[inplace] --drop-files applies to directories; skipping for file input."
247+ )
248+ else :
249+ prune_dir_inplace (p , drop_set )
250+
251+ print ("Done." )
252+ return
253+
254+ if not args .convert_model_to_bf16 and not args .no_optimizer and not drop_set :
255+ copy_any (p , args .output )
256+ print ("Done." )
257+ return
146258
147- p = args .input
148259 if p .is_file ():
149260 sfx = p .suffix .lower ()
150- if sfx in {".pt" , ".pth" }:
151- if args .output .is_dir ():
152- raise SystemExit (
153- "For .pt/.pth input, output must be a file path (not a directory)."
261+ if args .convert_model_to_bf16 :
262+ if sfx in {".pt" , ".pth" }:
263+ convert_pt_pth (p , args .output )
264+ elif sfx == ".safetensors" :
265+ out = (
266+ args .output
267+ if args .output .suffix == ".safetensors"
268+ else (args .output / p .name )
154269 )
155- convert_pt_pth (p , args .output )
156- elif sfx == ".safetensors" :
157- out = (
158- args .output
159- if args .output .suffix == ".safetensors"
160- else (args .output / p .name )
161- )
162- convert_safetensors_file (p , out )
163- else :
164- raise SystemExit (f"Unsupported file type: { p } " )
165- elif p .is_dir ():
166- if not any (x .suffix == ".safetensors" for x in p .iterdir ()):
167- raise SystemExit ("Directory has no .safetensors files." )
168- if args .slim_only :
169- slim_copy_dir_skip_only (p , args .output , skip_list )
270+ convert_safetensors_file (p , out )
271+ else :
272+ raise SystemExit (f"Unsupported file type: { p } " )
170273 else :
274+ copy_any (p , args .output )
275+ print ("Done." )
276+ return
277+
278+ if p .is_dir ():
279+ if args .convert_model_to_bf16 :
280+ if not any (x .suffix == ".safetensors" for x in p .iterdir ()):
281+ raise SystemExit ("Directory has no .safetensors files." )
171282 convert_dir_of_safetensors (p , args .output )
172- if args .slim :
173- tmp = args .output .parent / (args .output .name + "_tmp_slim " )
283+ if args .no_optimizer or drop_set :
284+ tmp = args .output .parent / (args .output .name + "_tmp_drop " )
174285 if tmp .exists ():
175286 shutil .rmtree (tmp )
176- slim_copy_dir_skip_only (args .output , tmp , skip_list )
287+ copy_dir_drop (
288+ args .output ,
289+ tmp ,
290+ DEFAULT_OPTIM_DROPS | drop_set if args .no_optimizer else drop_set ,
291+ )
177292 shutil .rmtree (args .output )
178293 tmp .rename (args .output )
179- else :
180- raise SystemExit (f"Not found: { p } " )
294+ else :
295+ if args .no_optimizer or drop_set :
296+ copy_dir_drop (
297+ p ,
298+ args .output ,
299+ DEFAULT_OPTIM_DROPS | drop_set if args .no_optimizer else drop_set ,
300+ )
301+ else :
302+ copy_any (p , args .output )
303+ print ("Done." )
304+ return
181305
182- print ( "Done. " )
306+ raise SystemExit ( f"Not found: { p } " )
183307
184308
185309if __name__ == "__main__" :
0 commit comments