-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclip_train.py
More file actions
120 lines (93 loc) · 4.3 KB
/
clip_train.py
File metadata and controls
120 lines (93 loc) · 4.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from transformers import CLIPModel, CLIPProcessor, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, Dataset, DatasetDict
import torch
import evaluate # Hugging Face's metric library
import numpy as np
from data.utils import remove_unused_columns
import torch
import numpy as np
def compute_iou_metric(predictions, labels=None):
    """Sweep decision thresholds and report the one with the best mean IoU.

    The function can be called in two ways:

    * ``compute_iou_metric(predictions, labels)`` with two array-likes, or
    * ``compute_iou_metric(bundle)`` with a single argument, which is how
      ``Trainer`` invokes ``compute_metrics``. The bundle is either an object
      exposing ``.predictions`` and ``.label_ids`` or a 2-item
      ``(predictions, labels)`` sequence.

    Args:
        predictions: Scores indexed as ``[sample, label]`` (tensor, ndarray or
            nested list), or the single bundle described above.
        labels: Targets with the same layout as ``predictions``; non-zero
            entries count as positive. Omit when passing a bundle.

    Returns:
        dict: ``"best_threshold"`` (the first threshold reaching the highest
        score) and ``"best_iou"`` (mean per-sample IoU at that threshold).
        Both stay ``0.0`` if no threshold yields an IoU above zero.
    """
    if labels is None:
        # Single-argument call: unpack the evaluation bundle.
        if hasattr(predictions, "predictions") and hasattr(predictions, "label_ids"):
            predictions, labels = predictions.predictions, predictions.label_ids
        else:
            predictions, labels = predictions

    # Evaluation loops hand over numpy arrays; as_tensor accepts those as
    # well as tensors (without copying when it can).
    scores = torch.as_tensor(predictions)
    target_mask = torch.as_tensor(labels).bool()  # loop-invariant, hoisted

    # 0.0 to 1.0 in 0.05 increments
    thresholds = np.arange(0.0, 1.05, 0.05).tolist()

    best_iou = 0.0
    best_threshold = 0.0
    for threshold in thresholds:
        predicted_mask = scores > threshold

        # Per-sample counts, summed across the label axis.
        intersection = (predicted_mask & target_mask).sum(1).float()
        union = (predicted_mask | target_mask).sum(1).float()

        # An empty union means an empty intersection too; dividing by 1
        # gives that sample an IoU of 0 instead of NaN.
        union[union == 0] = 1

        iou = (intersection / union).mean().item()

        # Strict comparison keeps the lowest threshold among ties.
        if iou > best_iou:
            best_iou = iou
            best_threshold = threshold

    return {
        "best_threshold": best_threshold,
        "best_iou": best_iou,
    }
class CustomCLIPDataset(torch.utils.data.Dataset):
    """Map-style dataset that runs each row's image through a processor.

    The base class is ``torch.utils.data.Dataset`` rather than the Hugging
    Face ``datasets.Dataset``: the latter exposes ``data`` as a read-only
    property backed by an Arrow table, so assigning ``self.data`` in
    ``__init__`` fails on it.

    Args:
        data: Indexable collection whose rows provide ``"pixel_values"`` and
            ``"labels"`` entries.
        processor: Callable invoked as
            ``processor(images=..., return_tensors="pt")`` that returns a
            mutable mapping.
    """

    def __init__(self, data, processor):
        self.data = data
        self.processor = processor

    def __len__(self):
        """Return the number of rows in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the processor output for row ``idx`` with ``"labels"`` added."""
        # Fetch the row once; indexing the backing store may be expensive.
        row = self.data[idx]
        image = row["pixel_values"]
        labels = row["labels"]

        # NOTE(review): with return_tensors="pt" the processor presumably
        # keeps a leading batch dimension of size 1 on its outputs — confirm
        # that the collator used downstream expects that.
        inputs = self.processor(images=image, return_tensors="pt")

        # Float labels, assuming multi-label classification targets.
        inputs["labels"] = torch.tensor(labels, dtype=torch.float)
        return inputs
def main():
    """Fine-tune a pre-trained CLIP checkpoint and evaluate it on the test split.

    Loads the model and processor, reads the cached dataset from disk, wraps
    the ``train`` / ``eval`` / ``test`` splits in ``CustomCLIPDataset``, trains
    with early stopping keyed on ``best_iou``, saves the resulting model and
    prints the test-split evaluation result.
    """
    # Load the pre-trained CLIP model and its matching processor.
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    cache_path = '/home/donghee/huggingface_data_cache/reciep1m'
    # The loader is `load_from_disk`; the previous `load_from_dist` does not
    # exist on DatasetDict and raised AttributeError.
    datasets = DatasetDict.load_from_disk(cache_path)
    datasets = remove_unused_columns(datasets, generate_mode=False)

    # Prepare the dataset splits for training and evaluation.
    train_dataset = CustomCLIPDataset(datasets['train'], clip_processor)
    eval_dataset = CustomCLIPDataset(datasets['eval'], clip_processor)
    test_dataset = CustomCLIPDataset(datasets['test'], clip_processor)

    training_args = TrainingArguments(
        output_dir="./outputs/CLIP",       # Directory for checkpoints and outputs
        num_train_epochs=10,               # Number of training epochs
        per_device_train_batch_size=64,    # Batch size
        evaluation_strategy="steps",       # When to evaluate during training
        save_steps=500,                    # How often to save checkpoints
        logging_steps=100,                 # How often to log training information
        # learning_rate=5e-5,              # Learning rate for the optimizer
        fp16=True,                         # Enable mixed precision (if supported)
        load_best_model_at_end=True,       # Load best model based on validation
        metric_for_best_model="best_iou",  # Metric to determine best model
        greater_is_better=True,
        save_total_limit=3,
        ddp_find_unused_parameters=False,
        save_safetensors=False,
        eval_accumulation_steps=4,
    )

    # Create the `Trainer` instance.
    trainer = Trainer(
        model=clip_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_iou_metric,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    )

    # Start training; with load_best_model_at_end=True the trainer holds the
    # best checkpoint afterwards, which is what gets saved below.
    trainer.train()
    trainer.model.save_pretrained('outputs/CLIP/best')

    test_result = trainer.evaluate(test_dataset)
    print(test_result)
    print("** DONE **")


if __name__ == '__main__':
    main()
# TODO(review): Plug the trained ViT in as the retriever and merge it into the model? Wouldn't that make the combined model too large?