Skip to content

Commit 9db7033

Browse files
committed
Add Vision Model Upload Functionality and Enhance Training Workflow
- Introduced new upload_vision.py script for comprehensive model upload capabilities
- Extended train_vision.py with additional upload methods for Hugging Face and Ollama
- Added support for saving merged models, pushing GGUF models, and creating Ollama models
- Implemented flexible upload targeting with configurable options
- Enhanced model preparation and deployment workflow for vision models
1 parent 308c69a commit 9db7033

File tree

2 files changed

+172
-0
lines changed

2 files changed

+172
-0
lines changed

praisonai/train_vision.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,45 @@ def save_model_gguf(self):
240240
quantization_method="q4_k_m"
241241
)
242242

243+
def prepare_modelfile_content(self):
    """Build the Ollama Modelfile text for the trained model.

    Uses the Llama 3.2 chat template (Go template syntax understood by
    Ollama) and the Hugging Face model name from the config as the base
    model reference.

    Returns:
        str: Complete Modelfile content ready to be written to disk.
    """
    output_model = self.config["hf_model_name"]

    # Llama 3.2 chat template in Ollama's Go-template syntax.
    template = '''{{- range $index, $_ := .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>

{{ .Content }}
{{- if gt (len (slice $.Messages $index)) 1 }}<|eot_id|>
{{- else if ne .Role "assistant" }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{{ end }}
{{- end }}'''

    # Ollama requires multi-line TEMPLATE values to be wrapped in triple
    # quotes; an unquoted multi-line value fails to parse.
    return f'''FROM {output_model}
TEMPLATE """{template}"""
PARAMETER temperature 0.6
PARAMETER top_p 0.9
'''
260+
261+
def create_and_push_ollama_model(self):
    """Write the Modelfile, then create and push the model via Ollama.

    `ollama serve` is a long-running server process that never exits, so
    it must be launched in the background with subprocess.Popen — the
    original `subprocess.run` call blocked forever and the create/push
    steps were never reached.
    """
    modelfile_content = self.prepare_modelfile_content()
    with open("Modelfile", "w") as file:
        file.write(modelfile_content)

    # Launch the Ollama server in the background.
    # NOTE(review): there is no readiness wait here; if `ollama create`
    # races the server startup, add a short poll/retry before it.
    subprocess.Popen(["ollama", "serve"])

    model_tag = f"{self.config['ollama_model']}:{self.config['model_parameters']}"
    # check=True surfaces CLI failures instead of silently continuing.
    subprocess.run(["ollama", "create", model_tag, "-f", "Modelfile"], check=True)
    subprocess.run(["ollama", "push", model_tag], check=True)
268+
243269
def run(self):
    """Execute the workflow: system checks, then optional training and uploads.

    Each stage is gated by a config flag. YAML parses bare true/false
    into Python bools, on which str.lower() would raise AttributeError,
    so every flag value is normalized through str() first.
    """
    self.print_system_info()
    self.check_gpu()
    self.check_ram()

    def _enabled(key):
        # Accepts both YAML booleans (True/False) and the strings
        # "true"/"false"; missing keys default to enabled.
        return str(self.config.get(key, "true")).lower() == "true"

    if _enabled("train"):
        self.prepare_model()
        self.train_model()
    if _enabled("huggingface_save"):
        self.save_model_merged()
    if _enabled("huggingface_save_gguf"):
        self.push_model_gguf()
    if _enabled("ollama_save"):
        self.create_and_push_ollama_model()
250282

251283

252284
def main():

praisonai/upload_vision.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
"""
4+
This script handles uploading trained vision models to Hugging Face and Ollama.
5+
It reads configuration from config.yaml and provides options to upload in different formats.
6+
"""
7+
8+
import os
9+
import yaml
10+
import torch
11+
import shutil
12+
import subprocess
13+
from unsloth import FastVisionModel
14+
15+
class UploadVisionModel:
    """Uploads a trained vision model to Hugging Face Hub and/or Ollama.

    Settings come from a YAML config file; the Hugging Face token is
    read from the HF_TOKEN environment variable.
    """

    def __init__(self, config_path="config.yaml"):
        self.load_config(config_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.hf_tokenizer = None

    def load_config(self, path):
        """Load configuration from yaml file."""
        with open(path, "r") as file:
            self.config = yaml.safe_load(file)
        print("DEBUG: Loaded config:", self.config)

    def prepare_model(self):
        """Load the trained model for uploading."""
        print("DEBUG: Loading trained model and tokenizer...")
        self.model, original_tokenizer = FastVisionModel.from_pretrained(
            model_name=self.config.get("output_dir", "lora_model"),
            load_in_4bit=self.config.get("load_in_4bit", True)
        )
        self.hf_tokenizer = original_tokenizer
        print("DEBUG: Model and tokenizer loaded successfully.")

    def save_model_merged(self):
        """Save merged model to Hugging Face Hub."""
        print(f"DEBUG: Saving merged model to Hugging Face Hub: {self.config['hf_model_name']}")
        # Remove any stale local directory of the same name so the
        # merged export starts from a clean slate.
        if os.path.exists(self.config["hf_model_name"]):
            shutil.rmtree(self.config["hf_model_name"])
        self.model.push_to_hub_merged(
            self.config["hf_model_name"],
            self.hf_tokenizer,
            save_method="merged_16bit",
            token=os.getenv("HF_TOKEN")
        )
        print("DEBUG: Model saved to Hugging Face Hub successfully.")

    def push_model_gguf(self):
        """Push model in GGUF format to Hugging Face Hub."""
        print(f"DEBUG: Pushing GGUF model to Hugging Face Hub: {self.config['hf_model_name']}")
        self.model.push_to_hub_gguf(
            self.config["hf_model_name"],
            self.hf_tokenizer,
            quantization_method=self.config.get("quantization_method", "q4_k_m"),
            token=os.getenv("HF_TOKEN")
        )
        print("DEBUG: GGUF model pushed to Hugging Face Hub successfully.")

    def prepare_modelfile_content(self):
        """Prepare Ollama modelfile content using Llama 3.2 vision template."""
        output_model = self.config["hf_model_name"]

        # Llama 3.2 chat template in Ollama's Go-template syntax.
        template = """{{- range $index, $_ := .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>

{{ .Content }}
{{- if gt (len (slice $.Messages $index)) 1 }}<|eot_id|>
{{- else if ne .Role "assistant" }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{{ end }}
{{- end }}"""

        # Ollama requires multi-line TEMPLATE values wrapped in triple
        # quotes. The previous quote-escaped concatenation collapsed to
        # a single double quote around the template, producing an
        # invalid Modelfile.
        modelfile = f"FROM {output_model}\n"
        modelfile += f'TEMPLATE """{template}"""\n'
        modelfile += "PARAMETER temperature 0.6\n"
        modelfile += "PARAMETER top_p 0.9\n"
        return modelfile

    def create_and_push_ollama_model(self):
        """Create and push model to Ollama."""
        model_tag = f"{self.config['ollama_model']}:{self.config['model_parameters']}"
        print(f"DEBUG: Creating Ollama model: {model_tag}")
        modelfile_content = self.prepare_modelfile_content()
        with open("Modelfile", "w") as file:
            file.write(modelfile_content)

        print("DEBUG: Starting Ollama server...")
        # `ollama serve` never exits; run it in the background so the
        # create/push commands below can actually execute. subprocess.run
        # here would block forever.
        # NOTE(review): no readiness wait — add a poll/retry if create
        # races the server startup.
        subprocess.Popen(["ollama", "serve"])

        print("DEBUG: Creating Ollama model...")
        subprocess.run(["ollama", "create", model_tag, "-f", "Modelfile"], check=True)

        print("DEBUG: Pushing model to Ollama...")
        subprocess.run(["ollama", "push", model_tag], check=True)
        print("DEBUG: Model pushed to Ollama successfully.")

    def upload(self, target="all"):
        """
        Upload the model to specified targets.
        Args:
            target (str): One of 'all', 'huggingface', 'huggingface_gguf', or 'ollama'
        """
        self.prepare_model()

        if target in ("all", "huggingface"):
            self.save_model_merged()

        if target in ("all", "huggingface_gguf"):
            self.push_model_gguf()

        if target in ("all", "ollama"):
            self.create_and_push_ollama_model()
123+
124+
def main():
    """Parse command-line options and run the model uploader."""
    import argparse

    arg_parser = argparse.ArgumentParser(
        description="Upload Vision Model to Various Platforms"
    )
    arg_parser.add_argument(
        "--config",
        default="config.yaml",
        help="Path to configuration file",
    )
    arg_parser.add_argument(
        "--target",
        default="all",
        choices=["all", "huggingface", "huggingface_gguf", "ollama"],
        help="Target platform to upload to",
    )
    opts = arg_parser.parse_args()

    UploadVisionModel(config_path=opts.config).upload(target=opts.target)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)