Skip to content

Commit 22ba7aa

Browse files
author
Alexander Ororbia
committed
added class-conformity metrics in metric_utils; integrated kmeans-probe in utils.analysis
1 parent 1e3073a commit 22ba7aa

5 files changed

Lines changed: 363 additions & 26 deletions

File tree

ngclearn/utils/analysis/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
from .linear_probe import LinearProbe
33
from .attentive_probe import AttentiveProbe
44
from .knn_probe import KNNProbe
5-
5+
from .kmeans_probe import KMeansProbe
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import jax
2+
from ngcsimlib import deprecate_args
3+
from ngclearn.utils.analysis.probe import Probe
4+
from jax import jit, random, numpy as jnp, lax, nn
5+
from functools import partial as bind
6+
from ngclearn.utils.metric_utils import measure_ARI
7+
8+
@bind(jax.jit, static_argnums=[2])
def _run_kmeans_probe(_embeddings, centroids, n_clusters):
    """
    Runs one K-means (Lloyd) iteration: every embedding is assigned to its nearest
    centroid under squared L2/Euclidean distance, then each centroid is re-estimated
    as the mean of the embeddings assigned to it.

    Args:
        _embeddings: 2D array of shape (n_samples, n_features)

        centroids: 2D array of shape (n_clusters, n_features) holding the current centroids

        n_clusters: number of clusters (a static argument of the jitted routine)

    Returns:
        2D array of shape (n_clusters, n_features) holding the updated centroids
    """
    ## Broadcast distances: (n_samples, 1, n_features) - (1, n_clusters, n_features)
    distances = jnp.sum((_embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    labels_pred = jnp.argmin(distances, axis=1)
    ## Re-estimate centroids/means
    one_hot_preds = (labels_pred[:, None] == jnp.arange(n_clusters)).astype(jnp.float32)
    counts = jnp.sum(one_hot_preds, axis=0)[:, None]  ## (n_clusters, 1)
    ## clamp the divisor so that an empty cluster does not trigger a division by zero
    means = jnp.dot(one_hot_preds.T, _embeddings) / jnp.maximum(counts, 1.0)
    ## an empty cluster has nothing to average; keep its previous centroid rather than
    ## letting the (zero sum / clamped count) quotient drag it to the origin
    return jnp.where(counts > 0.0, means, centroids)
18+
19+
@bind(jax.jit, static_argnums=[2])
def _predict_with_probe(_embeddings, centroids, n_clusters):
    """
    Assigns each embedding to its nearest centroid (squared L2/Euclidean distance).

    Args:
        _embeddings: 2D array of shape (n_samples, n_features)

        centroids: 2D array of shape (n_clusters, n_features)

        n_clusters: number of clusters (a static argument of the jitted routine)

    Returns:
        integer cluster index per sample, one-hot encoding of those indices
        (one column per cluster)
    """
    ## pairwise offsets between samples and centroids: (n_samples, n_clusters, n_features)
    offsets = _embeddings[:, None, :] - centroids[None, :, :]
    sq_dists = jnp.sum(offsets ** 2, axis=-1)
    nearest = jnp.argmin(sq_dists, axis=1)
    return nearest, nn.one_hot(nearest, n_clusters)
26+
27+
class KMeansProbe(Probe):
    """
    This implements a K-means clustering probe, which is useful for evaluating the quality of
    encodings/embeddings in light of the ability to cluster downstream data. Currently, this
    probe only supports L2/Euclidean distance-based clustering.

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of the image feature);
            stored on the probe but not otherwise used by this class

        input_dim: input dimensionality of probe; stored on the probe but not otherwise
            used by this class

        out_dim: output dimensionality of probe - number of clusters for this probe to create

        batch_size: <Unused>

    """

    def __init__(
            self,
            dkey,
            source_seq_length,
            input_dim,
            out_dim=2, ## number of clusters/centroids to uncover
            batch_size=1,
            **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        ## advance the probe's key exactly as before (first of a 3-way split); the
        ## remaining sub-keys were never used, so they are no longer bound
        self.dkey = random.split(self.dkey, 3)[0]
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.n_clusters = self.out_dim = out_dim
        ## centroids that will be uncovered by this probe (None until `_init` is run)
        self.centroids : jax.Array = None

    @staticmethod
    def _flatten(embeddings):
        ## Collapses every non-batch axis so that embeddings become (n_samples, n_features);
        ## inputs that are already 2D (or lower) are passed through untouched
        if len(embeddings.shape) > 2:
            return jnp.reshape(embeddings, (embeddings.shape[0], -1))
        return embeddings

    def _init(self, embeddings):
        """
        Seeds the centroids with `n_clusters` distinct, randomly-chosen data points.

        Args:
            embeddings: data to draw the initial centroids from (flattened to 2D if needed)

        Raises:
            ValueError: if there are fewer samples than clusters
        """
        _embeddings = self._flatten(embeddings)
        n_samples = _embeddings.shape[0]
        if n_samples < self.n_clusters:
            raise ValueError(
                f"KMeansProbe needs at least n_clusters={self.n_clusters} samples to "
                f"initialize its centroids, but only {n_samples} were provided"
            )
        ## choose random data-points to serve as centroids at iteration 0
        ## NOTE: the split count is left at 15 on purpose - changing it could alter the
        ##       keys derived for a given seed (and thus the initial centroids)
        keys = random.split(self.dkey, 15)
        self.dkey, init_key = keys[0], keys[1]
        random_indices = random.choice(
            init_key, n_samples, shape=(self.n_clusters,), replace=False
        )
        self.centroids = _embeddings[random_indices]

    def process(self, embeddings, dkey=None):
        """
        Predicts a (one-hot) cluster assignment for each embedding using the current centroids.

        Args:
            embeddings: data to assign to clusters (flattened to 2D if needed)

            dkey: <Unused>

        Returns:
            one-hot cluster assignments of shape (B, C)

        Raises:
            RuntimeError: if the centroids have not been initialized yet
        """
        if self.centroids is None:
            raise RuntimeError(
                "KMeansProbe has no centroids yet; call `fit` or `update` before `process`"
            )
        _embeddings = self._flatten(embeddings)
        _, Y_pred = _predict_with_probe(_embeddings, self.centroids, self.n_clusters)
        return Y_pred ## (B, C)

    def update(self, embeddings, labels, dkey=None):
        """
        Runs a single K-means iteration on the given embeddings and returns the
        resulting cluster assignments.

        Args:
            embeddings: data to cluster (flattened to 2D if needed)

            labels: <Unused> (K-means is unsupervised)

            dkey: <Unused>

        Returns:
            loss (currently always `0.`), one-hot cluster assignments of shape (B, C)
        """
        _embeddings = self._flatten(embeddings)
        if self.centroids is None: ## lazily seed centroids if `fit` was not used
            self._init(_embeddings)
        self.centroids = _run_kmeans_probe(_embeddings, self.centroids, self.n_clusters)
        L = 0. ## FIXME: should be clustering loss
        predictions = self.process(_embeddings)
        return L, predictions

    def fit(self, dataset, dev_dataset=None, n_iter=20, patience=20):
        """
        Fits the centroids to a dataset with `n_iter` K-means iterations, reporting the
        adjusted Rand index (ARI) against the provided labels after every iteration.

        Args:
            dataset: tuple `(data, labels)`; the labels are reduced with an argmax over
                their last axis (presumably one-hot encoded - confirm against callers)

            dev_dataset: <Unused>

            n_iter: number of K-means iterations to run

            patience: <Unused>

        Returns:
            best ARI observed over all iterations (`0.` if no iteration was run); note that
            the centroids kept by the probe are those of the final iteration
        """
        data, labels = dataset
        _labels = jnp.argmax(labels, axis=-1)
        data = self._flatten(data) ## flatten once instead of at every iteration

        self._init(data) ## init K-means centroids
        ari = None ## ARI may be negative, so the best score cannot be seeded with 0
        for i in range(n_iter): ## Run vectorized K-Means optimization loop
            _L, py = self.update(data, labels)
            labels_pred = jnp.argmax(py, axis=1)
            ari_i = measure_ARI(_labels, labels_pred)
            print(f"\r{i}: ARI = {ari_i}", end="")
            if ari is None or ari_i > ari:
                ari = ari_i
        print()
        return 0. if ari is None else ari

ngclearn/utils/analysis/knn_probe.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -139,25 +139,25 @@ def update(self, embeddings, labels, dkey=None):
139139
Wy = labels
140140
self.probe_params = (Wx, Wy)
141141

142-
if __name__ == '__main__':
143-
seed = 42
144-
D = 7
145-
C = 5
146-
dkey = random.PRNGKey(seed)
147-
dkey, *subkeys = random.split(dkey, 3)
148-
knn = KNNProbe(
149-
subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean"
150-
)
151-
X = random.uniform(subkeys[1], shape=(10, D))
152-
Y = jnp.concat(
153-
[
154-
jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]),
155-
jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]),
156-
jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]),
157-
jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]),
158-
jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]])
159-
],
160-
axis=0
161-
)
162-
knn.update(X, Y) ## fit KNN to data
163-
print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y
142+
# if __name__ == '__main__':
143+
# seed = 42
144+
# D = 7
145+
# C = 5
146+
# dkey = random.PRNGKey(seed)
147+
# dkey, *subkeys = random.split(dkey, 3)
148+
# knn = KNNProbe(
149+
# subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean"
150+
# )
151+
# X = random.uniform(subkeys[1], shape=(10, D))
152+
# Y = jnp.concat(
153+
# [
154+
# jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]),
155+
# jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]),
156+
# jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]),
157+
# jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]),
158+
# jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]])
159+
# ],
160+
# axis=0
161+
# )
162+
# knn.update(X, Y) ## fit KNN to data
163+
# print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y

ngclearn/utils/metric_utils.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,3 +463,230 @@ def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10
463463
if not preserve_batch:
464464
bce = jnp.mean(bce)
465465
return bce
466+
467+
468+
@partial(jit, static_argnums=[2, 3])
def _compute_contingency_table( ## vectorized construction of contingency matrix
        labels_true: jnp.ndarray,
        labels_pred: jnp.ndarray,
        n_classes: int,
        n_clusters: int
) -> jnp.ndarray:
    """
    Builds the contingency table between true class labels and predicted cluster labels;
    entry `(i, j)` counts the samples with true label `i` and predicted label `j`.

    Args:
        labels_true: 1D array of size N with true integer labels

        labels_pred: 1D array of size N with predicted integer labels

        n_classes: number of rows of the table (static under jit)

        n_clusters: number of columns of the table (static under jit)

    Returns:
        float32 contingency matrix of shape (n_classes, n_clusters)
    """
    class_ids = jnp.arange(n_classes)
    cluster_ids = jnp.arange(n_clusters)
    ## one-hot indicator of each sample's class / cluster membership
    true_onehot = jnp.equal(labels_true[:, None], class_ids).astype(jnp.float32)
    pred_onehot = jnp.equal(labels_pred[:, None], cluster_ids).astype(jnp.float32)
    ## summing the per-sample outer products of the indicators yields the co-occurrence counts
    return jnp.dot(true_onehot.T, pred_onehot)
490+
491+
492+
def measure_ARI(
        labels_true: jnp.ndarray,
        labels_pred: jnp.ndarray,
        n_classes=None,
        n_clusters=None
) -> jnp.ndarray:
    """
    Computes the adjusted Rand index (ARI), which measures similarity between two
    sets of indices (ground truth against a clustering's produced indices) via counting the
    pairs of data points assigned to same or different clusters (adjusted for chance). This
    measurement is at most `1` (perfect agreement), is close to `0` for a random
    labeling/assignment, and can be negative for assignments that agree less than
    chance would.

    Args:
        labels_true: 1D array of shape (n_samples,) with true integer class labels.

        labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels.

        n_classes: (optional) number of distinct classes; if None, it is set to
            `max(labels_true) + 1`. Labels at or above this value are not counted.

        n_clusters: (optional) number of distinct clusters; if None, it is set to
            `max(labels_pred) + 1`. Labels at or above this value are not counted.

    Returns:
        scalar ARI of these two sets of indices (`1` if there are fewer than two samples)
    """
    labels_true = jnp.asarray(labels_true)
    labels_pred = jnp.asarray(labels_pred)
    ## fewer than two samples form no pairs (and an empty array has no maximum to read)
    if labels_true.shape[0] <= 1:
        return jnp.array(1.0)
    ## table dimensions must be concrete Python ints since they size the contingency matrix
    n_classes = int(jnp.max(labels_true) + 1) if n_classes is None else int(n_classes)
    n_clusters = int(jnp.max(labels_pred) + 1) if n_clusters is None else int(n_clusters)
    return _calc_adjusted_rand_index(labels_true, labels_pred, n_classes, n_clusters)
515+
516+
517+
@partial(jit, static_argnums=[2, 3])
def _calc_adjusted_rand_index( ## ARI
        labels_true: jnp.ndarray,
        labels_pred: jnp.ndarray,
        n_classes: int,
        n_clusters: int
) -> jnp.ndarray:
    """
    Jitted core of the adjusted Rand index (ARI) computation.

    Args:
        labels_true: 1D array of shape (n_samples,) with true integer class labels

        labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels

        n_classes: number of rows of the contingency table (static under jit)

        n_clusters: number of columns of the contingency table (static under jit)

    Returns:
        scalar ARI; `1` if there are fewer than two samples or if the maximum and
        expected indices coincide
    """
    n_samples = labels_true.shape[0]
    if n_samples <= 1: ## shape is static, so this branch is resolved at trace time
        return jnp.array(1.0)

    def _pairs(counts):
        ## total number of unordered pairs, sum_k C(counts_k, 2)
        return jnp.sum((counts * (counts - 1.0)) / 2.0)

    ## contingency matrix (n_classes x n_clusters)
    table = _compute_contingency_table(
        labels_true,
        labels_pred,
        n_classes,
        n_clusters
    )

    pairs_joint = _pairs(table) ## pairs sharing both class and cluster
    pairs_true = _pairs(jnp.sum(table, axis=1)) ## pairs sharing a class (row margins)
    pairs_pred = _pairs(jnp.sum(table, axis=0)) ## pairs sharing a cluster (column margins)

    ## expected (chance) index and maximum index
    pairs_total = (n_samples * (n_samples - 1.0)) / 2.0
    expected_index = (pairs_true * pairs_pred) / pairs_total
    max_index = (pairs_true + pairs_pred) / 2.0

    ## a zero gap means there is nothing to normalize by; report perfect agreement
    gap = max_index - expected_index
    return jnp.where(gap == 0.0, 1.0, (pairs_joint - expected_index) / gap)
556+
557+
558+
def measure_FMI(
        labels_true: jnp.ndarray,
        labels_pred: jnp.ndarray,
        n_classes=None,
        n_clusters=None
) -> jnp.ndarray:
    """
    Calculates the Fowlkes-Mallows Index (FMI), which measures similarity between two sets of
    indices - this score is the geometric mean of pair-wise recall and precision.
    This measurement lies in `[0, 1]`, where higher is better (indicating greater similarity between
    two clustering sets of identifiers).

    Args:
        labels_true: 1D array of shape (n_samples,) with true integer class labels.

        labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels.

        n_classes: (optional) number of distinct classes; if None, it is set to
            `max(labels_true) + 1`. Labels at or above this value are not counted.

        n_clusters: (optional) number of distinct clusters; if None, it is set to
            `max(labels_pred) + 1`. Labels at or above this value are not counted.

    Returns:
        scalar FMI of these two sets of indices (`0` if there are fewer than two samples)
    """
    labels_true = jnp.asarray(labels_true)
    labels_pred = jnp.asarray(labels_pred)
    ## fewer than two samples form no pairs (and an empty array has no maximum to read)
    if labels_true.shape[0] <= 1:
        return jnp.array(0.0, dtype=jnp.float32)
    ## table dimensions must be concrete Python ints since they size the contingency matrix
    n_classes = int(jnp.max(labels_true) + 1) if n_classes is None else int(n_classes)
    n_clusters = int(jnp.max(labels_pred) + 1) if n_clusters is None else int(n_clusters)
    return _measure_fowlkes_mallows_index(labels_true, labels_pred, n_classes, n_clusters)
580+
581+
582+
@partial(jit, static_argnums=[2, 3])
def _measure_fowlkes_mallows_index( ## FMI
        labels_true: jnp.ndarray,
        labels_pred: jnp.ndarray,
        n_classes: int,
        n_clusters: int
) -> jnp.ndarray:
    """
    Jitted core of the Fowlkes-Mallows index (FMI) computation.

    Args:
        labels_true: 1D array of shape (n_samples,) with true integer class labels

        labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels

        n_classes: number of rows of the contingency table (static under jit)

        n_clusters: number of columns of the contingency table (static under jit)

    Returns:
        scalar FMI; `0` if there are fewer than two samples or if the normalizer is zero
    """
    n_samples = labels_true.shape[0]
    if n_samples <= 1: ## shape is static, so this branch is resolved at trace time
        return jnp.array(0.0, dtype=jnp.float32)

    def _pair_count(counts):
        ## pair-counting shortcut: sum_k C(x_k, 2) = 0.5 * (sum_k x_k^2 - N)
        return 0.5 * (jnp.sum(counts ** 2) - n_samples)

    table = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters)

    tk = _pair_count(table) ## pairs grouped together in both labelings
    tr = _pair_count(jnp.sum(table, axis=1)) ## pairs grouped together in ground truth
    tc = _pair_count(jnp.sum(table, axis=0)) ## pairs grouped together in predictions

    ## FMI = tk / sqrt(tr * tc), guarded against a zero normalizer
    norm = jnp.sqrt(tr * tc)
    return jnp.where(norm == 0.0, 0.0, tk / norm)
613+
614+
615+
def measure_Vmeasure( ## V-Measure
        labels_true: jnp.ndarray,
        labels_pred: jnp.ndarray,
        beta: float = 1.0,
        n_classes=None,
        n_clusters=None
) -> jnp.ndarray:
    """
    Calculates the V-Measure scoring metric for class conformity. This measurement compares
    predicted cluster indices ("labels_pred") against ground truth indices ("labels_true") and
    represents the harmonic mean of homogeneity (where each cluster contains only members of a single class)
    as well as completeness (where all members of a given class are assigned to the same cluster).
    This measurement (higher is better) lies in `[0,1]` where `1` indicates perfect, correct clustering.

    Args:
        labels_true: 1D array of shape (n_samples,) with true integer class labels

        labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels

        beta: Weight factor. Ratios > 1.0 favor completeness, < 1.0 favor homogeneity.

        n_classes: (optional) number of distinct classes; if None, it is set to
            `max(labels_true) + 1`. Labels at or above this value are not counted.

        n_clusters: (optional) number of distinct clusters; if None, it is set to
            `max(labels_pred) + 1`. Labels at or above this value are not counted.

    Returns:
        scalar V-measure of these two sets of indices (`1` if there are fewer than two
        samples, since both labelings then have zero entropy)
    """
    labels_true = jnp.asarray(labels_true)
    labels_pred = jnp.asarray(labels_pred)
    ## with fewer than two samples both labelings have zero entropy, i.e., homogeneity
    ## and completeness are trivially satisfied (and an empty array has no maximum to read)
    if labels_true.shape[0] <= 1:
        return jnp.array(1.0, dtype=jnp.float32)
    ## table dimensions must be concrete Python ints since they size the contingency matrix
    n_classes = int(jnp.max(labels_true) + 1) if n_classes is None else int(n_classes)
    n_clusters = int(jnp.max(labels_pred) + 1) if n_clusters is None else int(n_clusters)
    ## beta is handed over as a plain Python float
    return _measure_v_measure_score(labels_true, labels_pred, n_classes, n_clusters, float(beta))
641+
642+
643+
@partial(jit, static_argnums=[2, 3, 4])
def _measure_v_measure_score( ## V-Measure
        labels_true: jnp.ndarray,
        labels_pred: jnp.ndarray,
        n_classes: int,
        n_clusters: int,
        beta: float = 1.0
) -> jnp.ndarray:
    """
    Jitted core of the V-measure computation (weighted harmonic mean of homogeneity
    and completeness).

    Args:
        labels_true: 1D array of shape (n_samples,) with true integer class labels

        labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels

        n_classes: number of rows of the contingency table (static under jit)

        n_clusters: number of columns of the contingency table (static under jit)

        beta: weight factor (static under jit); > 1.0 favors completeness,
            < 1.0 favors homogeneity

    Returns:
        scalar V-measure; `1` if there are fewer than two samples
    """
    n_samples = labels_true.shape[0]

    ## With fewer than two samples both labelings have zero entropy, so homogeneity and
    ## completeness are both 1 under the rule used below - the score is therefore 1
    ## (shape is static, so this branch is resolved at trace time)
    if n_samples <= 1:
        return jnp.array(1.0, dtype=jnp.float32)

    def _entropy(p):
        ## Shannon entropy (in nats); zero-probability entries contribute nothing
        return -jnp.sum(jnp.where(p > 0.0, p * jnp.log(p), 0.0))

    contingency = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters)

    ## Calculate Marginal Sums (Row and Column totals)
    sum_true = jnp.sum(contingency, axis=1)
    sum_pred = jnp.sum(contingency, axis=0)

    ## Compute Base Entropies H(True) and H(Pred) and Joint Entropy H(True, Pred)
    h_true = _entropy(sum_true / n_samples)
    h_pred = _entropy(sum_pred / n_samples)
    h_joint = _entropy(contingency / n_samples)

    ## Derive Conditional Entropies: H(True|Pred) and H(Pred|True) using identity rule
    h_true_given_pred = h_joint - h_pred
    h_pred_given_true = h_joint - h_true

    ## Compute Homogeneity (H) and Completeness (C)
    ## If base entropy is 0, the metric is perfectly satisfied (1.0)
    homogeneity = jnp.where(h_true == 0.0, 1.0, 1.0 - (h_true_given_pred / h_true))
    completeness = jnp.where(h_pred == 0.0, 1.0, 1.0 - (h_pred_given_true / h_pred))

    ## Compute Weighted Harmonic Mean (V-Measure)
    denominator = beta * homogeneity + completeness

    ## Prevent division by zero if both metrics are zero
    v_measure = jnp.where(
        denominator == 0.0,
        0.0,
        (1.0 + beta) * homogeneity * completeness / denominator
    )
    return v_measure

0 commit comments

Comments
 (0)