forked from maybleMyers/chromaforge
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvertpth.py
More file actions
26 lines (20 loc) · 825 Bytes
/
convertpth.py
File metadata and controls
26 lines (20 loc) · 825 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from safetensors.torch import save_file
def convert_pth_to_safetensors(
    pth_file: str, safetensors_file: str, *, weights_only: bool = True
) -> None:
    """
    Converts a PyTorch .pth file to a .safetensors file.

    The checkpoint must deserialize to a flat dictionary mapping string names
    to tensors (a plain ``state_dict``). Everything is loaded onto the CPU.

    Args:
        pth_file (str): Path to the .pth file to convert.
        safetensors_file (str): Path to save the .safetensors file.
        weights_only (bool): Forwarded to ``torch.load``. ``True`` (the
            default) restricts unpickling to tensors and primitive
            containers, so a malicious checkpoint cannot execute code while
            loading. Pass ``False`` only for trusted checkpoints that embed
            other Python objects.

    Returns:
        None

    Raises:
        ValueError: If the file does not contain a dictionary, or if any
            entry has a non-string key or a non-tensor value.
    """
    # SECURITY: .pth files are pickles; with weights_only=False, torch.load
    # can run arbitrary code embedded in the file.
    state_dict = torch.load(pth_file, map_location="cpu", weights_only=weights_only)

    # Ensure the state_dict is a dictionary
    if not isinstance(state_dict, dict):
        raise ValueError("The .pth file must contain a dictionary-like object.")

    # safetensors can only store str -> Tensor. Report every offending key at
    # once (e.g. "epoch"/"optimizer" entries of a training checkpoint) so the
    # user can see what needs to be stripped or unwrapped.
    invalid = [
        key
        for key, value in state_dict.items()
        if not isinstance(key, str) or not isinstance(value, torch.Tensor)
    ]
    if invalid:
        raise ValueError(
            "The .pth file must map string keys to tensors; "
            f"invalid entries: {invalid!r}"
        )

    tensors = {}
    seen_storages = set()
    for name, tensor in state_dict.items():
        storage_ptr = tensor.untyped_storage().data_ptr()
        if storage_ptr in seen_storages:
            # safetensors refuses tensors that share memory (e.g. tied
            # weights); give each alias its own copy of the data.
            tensor = tensor.clone()
        else:
            seen_storages.add(storage_ptr)
        # safetensors also requires contiguous tensors; .contiguous() returns
        # the tensor itself when it already is, so this costs nothing then.
        tensors[name] = tensor.contiguous()

    # Save the tensors as a .safetensors file
    save_file(tensors, safetensors_file)
    print(f"Converted {pth_file} to {safetensors_file}")
if __name__ == "__main__":
    # Only convert when run as a script, not as a side effect of importing
    # this module. Both paths are optional and default to the original
    # hard-coded names, so `python convertpth.py` behaves as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a PyTorch .pth checkpoint to a .safetensors file."
    )
    parser.add_argument(
        "pth_file",
        nargs="?",
        default="source.pth",
        help="input .pth file (default: source.pth)",
    )
    parser.add_argument(
        "safetensors_file",
        nargs="?",
        default="output.safetensors",
        help="output .safetensors file (default: output.safetensors)",
    )
    args = parser.parse_args()
    convert_pth_to_safetensors(args.pth_file, args.safetensors_file)