-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvis2.py
More file actions
312 lines (245 loc) · 11.7 KB
/
vis2.py
File metadata and controls
312 lines (245 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from collections import Counter
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import logging
# Configure the root logger once at import time: INFO and above go both to
# "dataset_visualization.log" (in the current working directory) and to the
# console via a StreamHandler (stderr by default).
_log_handlers = [
    logging.FileHandler("dataset_visualization.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
class MovieGenreVisualizer:
    """Visualize the preprocessed multi-label movie genre dataset.

    ``data_dir`` must contain ``splits/train.pkl``, ``splits/val.pkl`` and
    ``splits/test.pkl`` (pickled pandas DataFrames) plus ``genre_config.json``
    with ``top_genres`` and ``genre_weights`` entries. Each split needs one
    ``genre_<name>`` column per top genre and an ``overview_length`` column.
    Genre columns are treated as indicators: a value equal to 1 means the
    movie has that genre (presumably 0/1 ints -- confirm against the
    preprocessing step).
    """

    def __init__(self, data_dir, output_dir, save_figures=False):
        """Store paths and create ``output_dir``.

        Args:
            data_dir: Directory holding ``splits/`` and ``genre_config.json``.
            output_dir: Directory for saved figures; created if missing.
            save_figures: If True, every figure is also written to
                ``output_dir`` as a PNG before it is shown. The default
                (False) keeps the show-only behavior.
        """
        self.data_dir = data_dir
        self.output_dir = output_dir
        self.save_figures = save_figures
        self.train_df = None
        self.val_df = None
        self.test_df = None
        self.genre_config = None
        self.top_genres = None
        self.genre_columns = None
        os.makedirs(output_dir, exist_ok=True)

    def load_data(self):
        """Load the three data splits and the genre configuration.

        Returns:
            True once everything has been loaded.

        Raises:
            FileNotFoundError: If a split pickle or the config file is missing.
            KeyError: If ``genre_config.json`` has no ``top_genres`` entry.
        """
        logging.info("Loading data...")
        splits_dir = os.path.join(self.data_dir, "splits")
        # NOTE: read_pickle can execute arbitrary code while unpickling; only
        # point data_dir at pickles produced by this project's own pipeline.
        self.train_df = pd.read_pickle(os.path.join(splits_dir, "train.pkl"))
        self.val_df = pd.read_pickle(os.path.join(splits_dir, "val.pkl"))
        self.test_df = pd.read_pickle(os.path.join(splits_dir, "test.pkl"))

        config_path = os.path.join(self.data_dir, "genre_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            self.genre_config = json.load(f)
        self.top_genres = self.genre_config["top_genres"]
        self.genre_columns = [f"genre_{genre}" for genre in self.top_genres]

        logging.info(f"Loaded {len(self.train_df)} training, {len(self.val_df)} validation, and {len(self.test_df)} test samples")
        logging.info(f"Working with {len(self.top_genres)} genres: {', '.join(self.top_genres)}")
        return True

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _ensure_loaded(self):
        """Load the data on first use so plot methods work without load_data()."""
        if self.train_df is None or self.genre_config is None:
            self.load_data()

    def _splits(self):
        """Return ``(label, DataFrame)`` pairs for the train/val/test splits."""
        return (("Train", self.train_df), ("Val", self.val_df), ("Test", self.test_df))

    def _finish_figure(self, fig, filename):
        """Optionally save ``fig`` to ``output_dir``, show it, then close it."""
        if self.save_figures:
            path = os.path.join(self.output_dir, filename)
            fig.savefig(path, dpi=150, bbox_inches="tight")
            logging.info(f"Saved figure to {path}")
        plt.show()
        # pyplot keeps a reference to every figure until it is closed.
        plt.close(fig)

    def _genre_combinations(self, df):
        """Return, per row of ``df``, its present genres sorted and comma-joined.

        A row with no genre set maps to the empty string. ``df`` is not modified.
        """
        flags = (df[self.genre_columns] == 1).to_numpy()
        names = np.array(self.top_genres, dtype=object)
        return pd.Series(
            [", ".join(sorted(names[row])) for row in flags],
            index=df.index,
            name="genre_combination",
        )

    def _cooccurrence_percent(self, df):
        """Return a (G, G) array of genre co-occurrence as % of the row genre.

        Entry ``[i, j]`` is 100 * (#movies with genres i and j) / (#movies
        with genre i); the diagonal is therefore 100 for every genre that
        appears at least once. Rows for genres with no movies are all zero.
        """
        flags = (df[self.genre_columns] == 1).to_numpy(dtype=float)
        # Matrix product of the indicator matrix counts pairwise co-occurrences;
        # its diagonal is the per-genre movie count.
        counts = flags.T @ flags
        per_genre = np.diag(counts)[:, None]
        percent = np.zeros_like(counts)
        np.divide(100.0 * counts, per_genre, out=percent, where=per_genre > 0)
        return percent

    def _overview_lengths_by_genre(self, df):
        """Return a long-form table with one row per (movie, present genre).

        Columns are ``Genre`` and ``Overview Length``; rows are grouped by genre
        in ``top_genres`` order, keeping the original row order within a genre.
        """
        flags = df[self.genre_columns].eq(1)
        flags.columns = list(self.top_genres)
        flags['Overview Length'] = df['overview_length']
        long_df = flags.melt(id_vars='Overview Length', var_name='Genre', value_name='_present')
        long_df = long_df.loc[long_df['_present'], ['Genre', 'Overview Length']]
        return long_df.reset_index(drop=True)

    def _primary_genres(self, genre_matrix):
        """Return the alphabetically first present genre of each row, or 'Unknown'."""
        primaries = []
        for row in np.asarray(genre_matrix):
            present = [genre for genre, flag in zip(self.top_genres, row) if flag == 1]
            primaries.append(min(present) if present else 'Unknown')
        return primaries

    def _plot_split_genre_counts(self, ax):
        """Draw a grouped bar chart of per-genre movie counts for each split."""
        x = np.arange(len(self.top_genres))
        width = 0.25
        for offset, (label, df) in zip((-width, 0.0, width), self._splits()):
            ax.bar(x + offset, df[self.genre_columns].sum().to_numpy(), width, label=label)
        ax.set_xticks(x)
        ax.set_xticklabels(self.top_genres, rotation=45, ha='right')
        ax.set_title('Genre Distribution Across Splits')
        ax.set_ylabel('Count')
        ax.legend()

    def _plot_label_cardinality(self, ax):
        """Draw overlaid histograms of the number of genres per movie, per split."""
        cardinalities = [
            (label, df[self.genre_columns].sum(axis=1)) for label, df in self._splits()
        ]
        for label, cardinality in cardinalities:
            # discrete=True centres one bar on each integer count, so movies with
            # 0 or more than 9 genres are neither dropped nor merged into a
            # neighbouring bin, and bars line up with the integer ticks.
            sns.histplot(cardinality, discrete=True, kde=False, label=label, alpha=0.7, ax=ax)
        ax.set_title('Number of Genres per Movie')
        ax.set_xlabel('Number of Genres')
        ax.set_ylabel('Count')
        all_counts = pd.concat([cardinality for _, cardinality in cardinalities])
        if not all_counts.empty:
            ax.set_xticks(range(int(all_counts.min()), int(all_counts.max()) + 1))
        ax.legend()

    def _plot_overview_lengths(self, ax):
        """Draw the training-set overview length histogram with its mean marked."""
        lengths = self.train_df['overview_length']
        sns.histplot(lengths, bins=20, kde=True, ax=ax)
        ax.set_title('Overview Length Distribution')
        ax.set_xlabel('Overview Length (characters)')
        ax.set_ylabel('Count')
        mean_length = lengths.mean()
        ax.axvline(mean_length, color='red', linestyle='--',
                   label=f'Mean: {mean_length:.0f} chars')
        ax.legend()

    def _plot_cooccurrence(self, ax):
        """Draw the training-set genre co-occurrence heatmap (% of row genre)."""
        sns.heatmap(self._cooccurrence_percent(self.train_df), annot=True, fmt=".1f",
                    cmap="YlGnBu", xticklabels=self.top_genres,
                    yticklabels=self.top_genres, ax=ax)
        ax.set_title('Genre Co-occurrence (% of Row Genre)')

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------

    def plot_dataset_stats(self):
        """Show a 2x2 overview figure.

        Panels: per-split genre counts, genres per movie, overview length
        distribution (train) and genre co-occurrence (train).
        """
        self._ensure_loaded()
        logging.info("Plotting dataset statistics...")
        fig, axes = plt.subplots(2, 2, figsize=(16, 14))
        fig.suptitle("Movie Dataset Statistics", fontsize=16)
        self._plot_split_genre_counts(axes[0, 0])
        self._plot_label_cardinality(axes[0, 1])
        self._plot_overview_lengths(axes[1, 0])
        self._plot_cooccurrence(axes[1, 1])
        plt.tight_layout()
        # Reserve the top 5% of the figure for the suptitle.
        plt.subplots_adjust(top=0.95)
        self._finish_figure(fig, "dataset_stats.png")
        logging.info("Dataset statistics plotted")

    def plot_class_weights(self):
        """Show the per-genre loss weights from ``genre_config['genre_weights']``."""
        self._ensure_loaded()
        logging.info("Plotting class weights...")
        genre_weights = self.genre_config["genre_weights"]
        weights_df = pd.DataFrame.from_dict(
            genre_weights, orient='index', columns=['weight']
        ).sort_values('weight', ascending=False)
        fig = plt.figure(figsize=(10, 6))
        sns.barplot(x=weights_df.index, y='weight', data=weights_df)
        plt.title('Class Weights for Loss Function')
        plt.xlabel('Genre')
        plt.ylabel('Weight')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        self._finish_figure(fig, "class_weights.png")
        logging.info("Class weights plotted")

    def plot_genre_combinations(self):
        """Show the 15 most frequent genre combinations in the training set."""
        self._ensure_loaded()
        logging.info("Plotting common genre combinations...")
        # Computed locally so the training DataFrame is not mutated.
        combo_counts = self._genre_combinations(self.train_df).value_counts()
        combo_df = pd.DataFrame({'Combination': combo_counts.index, 'Count': combo_counts.values})
        fig = plt.figure(figsize=(12, 8))
        sns.barplot(x='Count', y='Combination', data=combo_df.head(15))
        plt.title('Top 15 Genre Combinations')
        plt.tight_layout()
        self._finish_figure(fig, "genre_combinations.png")
        logging.info("Genre combinations plotted")

    def visualize_overview_lengths_by_genre(self):
        """Show a box plot of training-set overview length for each genre."""
        self._ensure_loaded()
        logging.info("Visualizing overview lengths by genre...")
        genre_lengths_df = self._overview_lengths_by_genre(self.train_df)
        fig = plt.figure(figsize=(12, 8))
        sns.boxplot(x='Genre', y='Overview Length', data=genre_lengths_df)
        plt.title('Overview Length by Genre')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        self._finish_figure(fig, "overview_length_by_genre.png")
        logging.info("Overview lengths by genre visualized")

    def visualize_tsne_clusters(self):
        """Show a 2-D t-SNE embedding of the training genre indicator vectors.

        Points are coloured by each movie's alphabetically first genre. The
        plot is skipped (with a warning) when fewer than 2 samples exist.
        """
        self._ensure_loaded()
        logging.info("Visualizing t-SNE clusters...")
        genre_matrix = self.train_df[self.genre_columns].to_numpy(dtype=float)
        n_samples = len(genre_matrix)
        if n_samples < 2:
            logging.warning("t-SNE needs at least 2 training samples; skipping")
            return
        primary_genres = np.array(self._primary_genres(genre_matrix), dtype=object)
        # scikit-learn requires perplexity < n_samples; cap it for tiny datasets.
        perplexity = min(30, n_samples - 1)
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        tsne_result = tsne.fit_transform(genre_matrix)

        fig = plt.figure(figsize=(12, 10))
        unique_genres = sorted(set(primary_genres))
        colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_genres)))
        for color, genre in zip(colors, unique_genres):
            mask = primary_genres == genre
            plt.scatter(
                tsne_result[mask, 0],
                tsne_result[mask, 1],
                c=[color],
                label=genre,
                alpha=0.6
            )
        plt.title('t-SNE Visualization of Movies by Genre')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        self._finish_figure(fig, "tsne_clusters.png")
        logging.info("t-SNE visualization complete")

    def run_all_visualizations(self):
        """Load the data, then produce every visualization in turn."""
        logging.info("Running all visualizations...")
        self.load_data()
        self.plot_dataset_stats()
        self.plot_class_weights()
        self.plot_genre_combinations()
        self.visualize_overview_lengths_by_genre()
        self.visualize_tsne_clusters()
        logging.info("All visualizations complete!")
def main():
    """Visualize the genre-balanced processed dataset.

    Input data and output directory are both resolved relative to the folder
    that contains this script, not the current working directory.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    dataset_subdir = "genre_balanced"
    MovieGenreVisualizer(
        os.path.join(here, "data", "processed", dataset_subdir),
        os.path.join(here, "visualizations", dataset_subdir),
    ).run_all_visualizations()


if __name__ == "__main__":
    main()