-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_model.py
More file actions
363 lines (295 loc) · 12.1 KB
/
train_model.py
File metadata and controls
363 lines (295 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""
Train a DistilBERT model for ad detection using transcripts and ad segments.
Exports model to Core ML format for iPhone deployment.
"""
import json
import os
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_recall_fscore_support
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
DistilBertForSequenceClassification,
DistilBertTokenizer,
Trainer,
TrainingArguments,
EarlyStoppingCallback
)
from datasets import Dataset as HFDataset
import config
class AdDataset(Dataset):
    """Map-style torch dataset that tokenizes a single text per item on demand.

    Args:
        texts: indexable sequence of samples; each item is coerced with ``str()``.
        labels: indexable sequence of integer class labels aligned with ``texts``.
        tokenizer: callable tokenizer accepting ``truncation``, ``padding``,
            ``max_length`` and ``return_tensors`` keyword arguments
            (e.g. a Hugging Face tokenizer).
        max_length: length each encoding is truncated / padded to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        sample_text = str(self.texts[idx])
        sample_label = self.labels[idx]
        encoded = self.tokenizer(
            sample_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dimension; flatten it away so
        # the default collator can stack per-sample 1-D tensors.
        item = {name: encoded[name].flatten() for name in ('input_ids', 'attention_mask')}
        item['labels'] = torch.tensor(sample_label, dtype=torch.long)
        return item
def load_metadata():
    """Return the parsed JSON in ``config.METADATA_FILE``.

    Falls back to ``{"episodes": {}}`` when the file does not exist, or when
    reading/parsing it fails (the error is printed, not raised).
    """
    metadata_file = config.METADATA_FILE
    if not metadata_file.exists():
        return {"episodes": {}}
    try:
        with open(metadata_file, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    except Exception as exc:
        # Best-effort: a broken metadata file should not abort the pipeline.
        print(f"Error loading metadata: {exc}")
    return {"episodes": {}}
def _resolve_episode_path(episode_data, key):
    """Return ``config.BASE_DIR / episode_data[key]``, or None if the key is missing/empty."""
    relative = episode_data.get(key)
    if not relative:
        # Joining '' would yield BASE_DIR itself, which "exists" but is a directory.
        return None
    return config.BASE_DIR / relative


def _load_ad_ranges(ad_path):
    """Read ad segments from *ad_path* as ``{'start', 'end', 'text'}`` dicts ([] if absent)."""
    if ad_path is None or not ad_path.is_file():
        return []
    with open(ad_path, 'r', encoding='utf-8') as f:
        ad_data = json.load(f)
    return [
        {
            'start': ad.get('start_time', 0),
            'end': ad.get('end_time', 0),
            'text': ad.get('text', ''),
        }
        for ad in ad_data.get('ad_segments', [])
    ]


def _overlaps_any(start, end, ranges):
    """Return True if the interval (start, end) strictly overlaps any of *ranges*."""
    return any(start < r['end'] and end > r['start'] for r in ranges)


def _sliding_windows(text, window_size=400, overlap=100):
    """Split *text* into windows of *window_size* words that overlap by *overlap* words.

    Iteration stops once a window reaches the end of the text, so no trailing
    window is a strict subset of the one before it.
    """
    words = text.split()
    step = window_size - overlap
    windows = []
    for i in range(0, len(words), step):
        windows.append(' '.join(words[i:i + window_size]))
        if i + window_size >= len(words):
            break
    return windows


def prepare_training_data():
    """Build (texts, labels) pairs from episode transcripts and detected ad segments.

    Only episodes whose metadata has ``ad_detected`` set and whose transcript
    file exists are used. Each transcript segment with at least 10 characters
    becomes a sample labelled 1 (ad) if its time span overlaps any ad segment,
    else 0 (content). Segments longer than 512 characters are split into
    overlapping ~400-word windows that all inherit the segment's label.

    Transcripts without timed segments fall back to their full text as one
    content sample, but are skipped when the episode has ad segments, because
    untimed text cannot be aligned with ad timestamps.

    Returns:
        (texts, labels): parallel lists of strings and ints (0/1).
    """
    metadata = load_metadata()
    texts = []
    labels = []
    print("Preparing training data...")

    for episode_key, episode_data in metadata.get('episodes', {}).items():
        if not episode_data.get('ad_detected', False):
            continue

        transcript_path = _resolve_episode_path(episode_data, 'transcript_path')
        if transcript_path is None or not transcript_path.is_file():
            continue
        with open(transcript_path, 'r', encoding='utf-8') as f:
            transcript_data = json.load(f)

        ad_ranges = _load_ad_ranges(_resolve_episode_path(episode_data, 'ad_path'))

        segments = transcript_data.get('segments', [])
        if not segments:
            full_text = transcript_data.get('text', '')
            if not full_text:
                continue
            if ad_ranges:
                # A start=end=0 segment never overlaps an ad, so the whole
                # transcript (ads included) would be mislabelled as content.
                print(f"Skipping {episode_key}: no timed segments to align with detected ads")
                continue
            segments = [{'start': 0, 'end': 0, 'text': full_text}]

        for segment in segments:
            segment_start = segment.get('start', 0)
            segment_end = segment.get('end', segment_start + 5)  # Default 5 seconds if no end
            segment_text = (segment.get('text') or '').strip()
            if len(segment_text) < 10:
                continue

            label = 1 if _overlaps_any(segment_start, segment_end, ad_ranges) else 0

            if len(segment_text) > 512:
                # ~400 words keeps each window near the 512-token model limit.
                chunks = [w for w in _sliding_windows(segment_text) if len(w) > 20]
            else:
                chunks = [segment_text]
            texts.extend(chunks)
            labels.extend([label] * len(chunks))

    print(f"Total samples: {len(texts)}")
    print(f"Ads: {sum(labels)}, Content: {len(labels) - sum(labels)}")
    return texts, labels
def balance_dataset(texts, labels, seed=42):
    """Downsample the majority class so ads (1) and content (0) have equal counts.

    Uses a private ``numpy.random.Generator`` seeded with *seed*, so results are
    reproducible (including the final shuffle) and the global NumPy RNG state
    is left untouched.

    Args:
        texts: sequence of samples.
        labels: sequence of binary labels (1 = ad, 0 = content) aligned with ``texts``.
        seed: seed for the private random generator.

    Returns:
        (texts, labels) as shuffled lists. If either class is absent, both
        lists are empty.
    """
    texts = np.array(texts, dtype=object)
    labels = np.array(labels)
    rng = np.random.default_rng(seed)

    ad_indices = np.flatnonzero(labels == 1)
    content_indices = np.flatnonzero(labels == 0)
    per_class = min(len(ad_indices), len(content_indices))

    selected = [
        rng.choice(indices, per_class, replace=False) if len(indices) > per_class else indices
        for indices in (ad_indices, content_indices)
    ]
    balanced_indices = np.concatenate(selected)
    rng.shuffle(balanced_indices)
    return texts[balanced_indices].tolist(), labels[balanced_indices].tolist()
def train_model():
    """Fine-tune DistilBERT to classify transcript text as ad vs. content, then export it.

    Pipeline: prepare_training_data() -> balance_dataset() -> stratified 80/20
    split -> fine-tune with the Hugging Face Trainer (evaluated and checkpointed
    every epoch, best checkpoint chosen by F1, early stopping after 2 epochs
    without improvement) -> print evaluation, classification report and
    confusion matrix -> save model and tokenizer to
    ``config.MODEL_DIR / "final_model"`` -> attempt a Core ML export.

    Prints an error and returns early (None) when there is no data or when
    only one class is present.
    """
    import inspect

    # Prepare data
    texts, labels = prepare_training_data()
    if len(texts) == 0:
        print("Error: No training data available. Run download, transcription, and ad detection first.")
        return

    ad_count = sum(labels)
    if ad_count == 0 or ad_count == len(labels):
        # With a single class, balance_dataset() downsamples to zero samples and
        # the stratified split below would raise an opaque error.
        print("Error: Training data must contain both ad and content samples.")
        return

    # Balance dataset
    texts, labels = balance_dataset(texts, labels)
    print("\nBalanced dataset:")
    print(f"Total samples: {len(texts)}")
    print(f"Ads: {sum(labels)}, Content: {len(labels) - sum(labels)}")

    # Stratified split keeps the class ratio identical in train and test sets.
    X_train, X_test, y_train, y_test = train_test_split(
        texts, labels, test_size=0.2, random_state=42, stratify=labels
    )
    print(f"\nTrain set: {len(X_train)} samples")
    print(f"Test set: {len(X_test)} samples")

    # Load tokenizer and model
    model_name = "distilbert-base-multilingual-cased"
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2
    )

    # Create datasets
    train_dataset = HFDataset.from_dict({'text': X_train, 'label': y_train})
    test_dataset = HFDataset.from_dict({'text': X_test, 'label': y_test})

    def tokenize_function(examples):
        """Tokenize a batch, padding/truncating to config.MAX_SEQUENCE_LENGTH."""
        return tokenizer(
            examples['text'],
            truncation=True,
            padding='max_length',
            max_length=config.MAX_SEQUENCE_LENGTH
        )

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    test_dataset = test_dataset.map(tokenize_function, batched=True)
    train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
    test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

    # Newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy` and later dropped the old name; pass whichever keyword the
    # installed version accepts so the script runs on both.
    accepted_args = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_kwarg = (
        'eval_strategy' if 'eval_strategy' in accepted_args else 'evaluation_strategy'
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=str(config.MODEL_DIR / "checkpoints"),
        num_train_epochs=config.NUM_EPOCHS,
        per_device_train_batch_size=config.BATCH_SIZE,
        per_device_eval_batch_size=config.BATCH_SIZE,
        learning_rate=config.LEARNING_RATE,
        weight_decay=0.01,
        logging_dir=str(config.MODEL_DIR / "logs"),
        logging_steps=50,
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",  # the 'f1' key returned by compute_metrics
        greater_is_better=True,
        save_total_limit=2,
        seed=42,
        **{eval_strategy_kwarg: "epoch"},
    )

    def compute_metrics(eval_pred):
        """Return accuracy and binary precision/recall/F1 with 'ad' (1) as the positive class."""
        logits, label_ids = eval_pred
        predictions = np.argmax(logits, axis=1)
        precision, recall, f1, _ = precision_recall_fscore_support(
            label_ids, predictions, average='binary', zero_division=0
        )
        return {
            'accuracy': accuracy_score(label_ids, predictions),
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
    )

    # Train
    print("\n=== Training Model ===")
    trainer.train()

    # Evaluate
    print("\n=== Evaluating Model ===")
    eval_results = trainer.evaluate()
    print(f"Test Results: {eval_results}")

    # Final evaluation on test set
    predictions = trainer.predict(test_dataset)
    pred_labels = np.argmax(predictions.predictions, axis=1)
    print("\n=== Detailed Classification Report ===")
    print(classification_report(y_test, pred_labels, target_names=['Content', 'Ad']))
    print("\n=== Confusion Matrix ===")
    print(confusion_matrix(y_test, pred_labels))

    # Save model
    final_model_path = config.MODEL_DIR / "final_model"
    trainer.save_model(str(final_model_path))
    tokenizer.save_pretrained(str(final_model_path))
    print(f"\n✓ Model saved to: {final_model_path}")

    # Export to Core ML
    export_to_coreml(final_model_path)
def export_to_coreml(model_path):
    """Convert the fine-tuned classifier saved at *model_path* to a Core ML ML Program.

    The model is loaded on CPU and traced with TorchScript. The traced model
    takes fixed-shape ``(1, config.MAX_SEQUENCE_LENGTH)`` ``input_ids`` /
    ``attention_mask`` inputs and returns only the classification logits.
    coremltools then converts it, and the result is saved to
    ``config.MODEL_DIR / "ad_detector.mlpackage"``.

    Export is best-effort: any failure (including coremltools not being
    installed) is printed and swallowed, leaving the PyTorch model as the output.
    """
    try:
        import coremltools as ct

        print("\n=== Exporting to Core ML ===")

        # torchscript=True makes the Hugging Face model traceable
        # (tuple outputs instead of ModelOutput objects).
        model = DistilBertForSequenceClassification.from_pretrained(
            str(model_path), torchscript=True
        )
        model.eval()

        class _LogitsOnly(torch.nn.Module):
            """Expose forward(input_ids, attention_mask) -> logits for tracing."""

            def __init__(self, inner):
                super().__init__()
                self.inner = inner

            def forward(self, input_ids, attention_mask):
                # With no labels passed, element 0 of the output tuple is the logits.
                return self.inner(input_ids=input_ids, attention_mask=attention_mask)[0]

        seq_len = config.MAX_SEQUENCE_LENGTH
        example_ids = torch.zeros((1, seq_len), dtype=torch.long)
        example_mask = torch.ones((1, seq_len), dtype=torch.long)
        with torch.no_grad():
            traced = torch.jit.trace(_LogitsOnly(model), (example_ids, example_mask))

        # ct.convert needs a traced/scripted module, not a plain nn.Module.
        # Core ML model inputs do not support int64, hence int32.
        coreml_model = ct.convert(
            traced,
            source="pytorch",
            convert_to="mlprogram",  # required for saving as .mlpackage
            inputs=[
                ct.TensorType(name="input_ids", shape=(1, seq_len), dtype=np.int32),
                ct.TensorType(name="attention_mask", shape=(1, seq_len), dtype=np.int32)
            ],
            outputs=[
                ct.TensorType(name="logits", dtype=np.float32)
            ]
        )

        # Add metadata
        coreml_model.author = "CleanCast ML Model"
        coreml_model.short_description = "DistilBERT-based podcast ad detector"
        coreml_model.version = "1.0"

        # Save Core ML model
        coreml_path = config.MODEL_DIR / "ad_detector.mlpackage"
        coreml_model.save(str(coreml_path))
        print(f"✓ Core ML model saved to: {coreml_path}")
        print(" Note: You may need to fine-tune the Core ML export for your specific use case.")
    except Exception as e:
        # Top-level best-effort boundary: the PyTorch model is already saved.
        print(f"Warning: Could not export to Core ML: {e}")
        print(" Model saved in PyTorch format. You can convert manually later.")
        print(" Consider using coremltools or onnx for conversion.")
def _main():
    """Script entry point: train, evaluate, save and export the ad detector."""
    train_model()


if __name__ == "__main__":
    _main()