Skip to content

Commit 08b2652

Browse files
committed
ran ruff formatter again
1 parent 24aa859 commit 08b2652

3 files changed

Lines changed: 35 additions & 9 deletions

File tree

examples/m_ap_and_top_k_accuracy.ipynb

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,14 @@
8080
}
8181
},
8282
"outputs": [],
83-
"source": "train_imgs, train_labels = zip(\n *[(img, label) for img, label, _ in train_dataset], strict=True\n)\nval_imgs, val_labels = zip(*[(img, label) for img, label, _ in val_dataset], strict=True)"
83+
"source": [
84+
"train_imgs, train_labels = zip(\n",
85+
" *[(img, label) for img, label, _ in train_dataset], strict=True\n",
86+
")\n",
87+
"val_imgs, val_labels = zip(\n",
88+
" *[(img, label) for img, label, _ in val_dataset], strict=True\n",
89+
")"
90+
]
8491
},
8592
{
8693
"cell_type": "markdown",
@@ -176,7 +183,15 @@
176183
}
177184
},
178185
"outputs": [],
179-
"source": "train_paths, train_labels = zip(\n *[(path, label) for _, label, path in train_dataset], strict=True\n)\nencodings_vlad = vlad_encoder.generate_encoding_map(train_paths)\nencodings_fisher = fisher_vector_encoder.generate_encoding_map(train_paths)\nencodings_pipeline = pipeline_with_pca.generate_encoding_map(train_paths)\ndataset_labels_dict = dict(zip(train_paths, train_labels, strict=True))"
186+
"source": [
187+
"train_paths, train_labels = zip(\n",
188+
" *[(path, label) for _, label, path in train_dataset], strict=True\n",
189+
")\n",
190+
"encodings_vlad = vlad_encoder.generate_encoding_map(train_paths)\n",
191+
"encodings_fisher = fisher_vector_encoder.generate_encoding_map(train_paths)\n",
192+
"encodings_pipeline = pipeline_with_pca.generate_encoding_map(train_paths)\n",
193+
"dataset_labels_dict = dict(zip(train_paths, train_labels, strict=True))"
194+
]
180195
},
181196
{
182197
"cell_type": "markdown",
@@ -570,4 +585,4 @@
570585
},
571586
"nbformat": 4,
572587
"nbformat_minor": 5
573-
}
588+
}

pyvisim/_utils.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def cosine_similarity(x: np.ndarray, y: np.ndarray) -> np.ndarray:
2929
x = x.reshape(1, -1) if len(x.shape) == 1 else x
3030
y = y.reshape(1, -1) if len(y.shape) == 1 else y
3131
if x.shape[-1] <= 1 or y.shape[-1] <= 1:
32-
raise ValueError(f"Cosine similarity requires at least 2 features. Got {x.shape[-1]} features for x and {y.shape[-1]} features for y.")
32+
raise ValueError(
33+
f"Cosine similarity requires at least 2 features. Got {x.shape[-1]} features for x and {y.shape[-1]} features for y."
34+
)
3335

3436
return cs(x, y)
3537

@@ -83,7 +85,10 @@ def cluster_and_return_labels(
8385
if n_clusters is None:
8486
raise ValueError("n_clusters must be specified for Spectral Clustering.")
8587
model = SpectralClustering(
86-
n_clusters=n_clusters, affinity="nearest_neighbors", random_state=42, **kwargs
88+
n_clusters=n_clusters,
89+
affinity="nearest_neighbors",
90+
random_state=42,
91+
**kwargs,
8792
)
8893
return model.fit_predict(data)
8994

@@ -153,9 +158,13 @@ def plot_and_save_heatmap(
153158
matrix = matrix.detach().cpu().numpy()
154159

155160
figsize = (
156-
matrix.shape[1] * 0.7,
157-
matrix.shape[0] * 0.7,
158-
) if figsize is None else figsize
161+
(
162+
matrix.shape[1] * 0.7,
163+
matrix.shape[0] * 0.7,
164+
)
165+
if figsize is None
166+
else figsize
167+
)
159168
plt.figure(figsize=figsize)
160169
sns.heatmap(
161170
matrix,

pyvisim/encoders/vlad.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def encode(self, images: Iterable[np.ndarray] | np.ndarray) -> np.ndarray:
103103
descriptors = self.pca.transform(descriptors.astype(np.float32))
104104

105105
if descriptors is None or descriptors.shape[0] == 0:
106-
raise ValueError("No descriptors found in the image. Cannot compute VLAD encoding.")
106+
raise ValueError(
107+
"No descriptors found in the image. Cannot compute VLAD encoding."
108+
)
107109

108110
labels = self.clustering_model.predict(descriptors.astype(np.float32))
109111
centroids = self.clustering_model.cluster_centers_

0 commit comments

Comments
 (0)