@@ -463,3 +463,230 @@ def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10
463463 if not preserve_batch :
464464 bce = jnp .mean (bce )
465465 return bce
466+
467+
468+ @partial (jit , static_argnums = [2 , 3 ])
469+ def _compute_contingency_table ( ## vectorized construction of contingency matrix
470+ labels_true : jnp .ndarray ,
471+ labels_pred : jnp .ndarray ,
472+ n_classes : int ,
473+ n_clusters : int
474+ ) -> jnp .ndarray :
475+ ## Computes a contingency matrix table
476+ ## This routine expects true integer labels and predicted integer labels (1D arrays of size N)
477+
478+ # Create indicator masks across all unique classes/clusters
479+ # find unique IDs safely up to a static maximum size (or provide num_classes)
480+ # n_classes = n_true = jnp.max(labels_true) + 1
481+ # n_clusters = n_pred = jnp.max(labels_pred) + 1
482+
483+ # Broadcast to form a full one-hot lookup map
484+ true_mask = labels_true [:, None ] == jnp .arange (n_classes )
485+ pred_mask = labels_pred [:, None ] == jnp .arange (n_clusters )
486+
487+ # Contingency matrix is the matrix product of boolean indicators
488+ contingency = jnp .dot (true_mask .T .astype (jnp .float32 ), pred_mask .astype (jnp .float32 ))
489+ return contingency
490+
491+
492+ def measure_ARI (
493+ labels_true : jnp .ndarray ,
494+ labels_pred : jnp .ndarray
495+ ) -> jnp .ndarray :
496+ """
497+ Computes the adjusted random index (ARI), which measures similarity between two
498+ sets of indices (ground truth against a clustering's produced indices) via counting the
499+ pairs of data points assigned to same or different clusters (adjusted for chance). This
500+ measurement lies in `[0, 1]`, where `0` indicates a random labeling/assignment and `1` indicates
501+ perfect agreement.
502+
503+ Args:
504+ labels_true: 1D array of shape (n_samples,) with true integer class labels.
505+
506+ labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels.
507+
508+ Returns:
509+ scalar ARI of these two sets of indices
510+ """
511+ ## Dynamically find dimensions up to a statically bounded maximum
512+ n_classes = int (jnp .max (labels_true ) + 1 )
513+ n_clusters = int (jnp .max (labels_pred ) + 1 )
514+ return _calc_adjusted_rand_index (labels_true , labels_pred , n_classes , n_clusters )
515+
516+
517+ @partial (jit , static_argnums = [2 , 3 ])
518+ def _calc_adjusted_rand_index ( ## ARI
519+ labels_true : jnp .ndarray ,
520+ labels_pred : jnp .ndarray ,
521+ n_classes : int ,
522+ n_clusters : int
523+ ) -> jnp .ndarray :
524+ n_samples = labels_true .shape [0 ]
525+ if n_samples <= 1 :
526+ return jnp .array (1.0 )
527+
528+ ## Get contingency matrix (n_classes x n_clusters)
529+ contingency = _compute_contingency_table (
530+ labels_true ,
531+ labels_pred ,
532+ n_classes ,
533+ n_clusters
534+ )
535+
536+ ## Calculate combination sums n_ijC2 = (n_ij * (n_ij - 1)) / 2
537+ sum_nij_c2 = jnp .sum ((contingency * (contingency - 1.0 )) / 2.0 )
538+
539+ ## Sums across margins (rows and columns)
540+ sum_a = jnp .sum (contingency , axis = 1 )
541+ sum_b = jnp .sum (contingency , axis = 0 )
542+
543+ ## Margin pair combinations
544+ sum_a_c2 = jnp .sum ((sum_a * (sum_a - 1.0 )) / 2.0 )
545+ sum_b_c2 = jnp .sum ((sum_b * (sum_b - 1.0 )) / 2.0 )
546+
547+ ## Expected index and Max index math formulas
548+ total_c2 = (n_samples * (n_samples - 1.0 )) / 2.0
549+ expected_index = (sum_a_c2 * sum_b_c2 ) / total_c2
550+ max_index = (sum_a_c2 + sum_b_c2 ) / 2.0
551+
552+ ## Prevent division by zero if everything is perfectly clustered or uniform
553+ denominator = max_index - expected_index
554+ ari = jnp .where (denominator == 0.0 , 1.0 , (sum_nij_c2 - expected_index ) / denominator )
555+ return ari
556+
557+
558+ def measure_FMI (
559+ labels_true : jnp .ndarray ,
560+ labels_pred : jnp .ndarray
561+ ) -> jnp .ndarray :
562+ """
563+ Calculates the Fowlkes-Mallows Index (FMI), which measures similarity between two sets of
564+ indices - this score is the geometric mean of pair-wise recall and precision.
565+ This measurement lies in `[0, 1]`, where higher is better (indicating greater similarity between
566+ two clustering sets of identifiers).
567+
568+ Args:
569+ labels_true: 1D array of shape (n_samples,) with true integer class labels.
570+
571+ labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels.
572+
573+ Returns:
574+ scalar FMI of these two sets of indices
575+ """
576+ ## Dynamically find dimensions up to a statically bounded maximum
577+ n_classes = int (jnp .max (labels_true ) + 1 )
578+ n_clusters = int (jnp .max (labels_pred ) + 1 )
579+ return _measure_fowlkes_mallows_index (labels_true , labels_pred , n_classes , n_clusters )
580+
581+
582+ @partial (jit , static_argnums = [2 , 3 ])
583+ def _measure_fowlkes_mallows_index ( ## FMI
584+ labels_true : jnp .ndarray ,
585+ labels_pred : jnp .ndarray ,
586+ n_classes : int ,
587+ n_clusters : int
588+ ) -> jnp .ndarray :
589+ n_samples = labels_true .shape [0 ]
590+ # Handle edge case for single or empty samples safely
591+ if n_samples <= 1 :
592+ return jnp .array (0.0 , dtype = jnp .float32 )
593+
594+ contingency = _compute_contingency_table (labels_true , labels_pred , n_classes , n_clusters )
595+
596+ ## Compute marginal sums (sums along rows and columns)
597+ sum_true = jnp .sum (contingency , axis = 1 )
598+ sum_pred = jnp .sum (contingency , axis = 0 )
599+
600+ ## Calculate pairwise combinations using the matrix shortcut: nC2 = 0.5 * (sum(x^2) - N)
601+ # True Positives pair combinations (tk)
602+ tk = 0.5 * (jnp .sum (contingency ** 2 ) - n_samples )
603+ ## Total pairs clustered together in ground truth (tr)
604+ tr = 0.5 * (jnp .sum (sum_true ** 2 ) - n_samples )
605+ ## Total pairs clustered together in predictions (tc)
606+ tc = 0.5 * (jnp .sum (sum_pred ** 2 ) - n_samples )
607+
608+ ## Compute FMI = tk / sqrt(tr * tc)
609+ # Prevent division by zero if there are no pair splits/matches
610+ denominator = jnp .sqrt (tr * tc )
611+ fmi = jnp .where (denominator == 0.0 , 0.0 , tk / denominator )
612+ return fmi
613+
614+
615+ def measure_Vmeasure ( ## V-Measure
616+ labels_true : jnp .ndarray ,
617+ labels_pred : jnp .ndarray ,
618+ beta : float = 1.0
619+ ) -> jnp .ndarray :
620+ """
621+ Calculates the V-Measure scoring metric for class conformity. This measurement compares
622+ predicted cluster indices ("labels_pred") against ground truth indices ("labels_true") and
623+ represents the harmonic mean of homogeneity (where each cluster contains only members of a single class)
624+ as well as completeness (where all members of a given class are assigned to the same cluster).
625+ This measurement (higher is better) lies in `[0,1]` where `1` indicates perfect, correct clustering.
626+
627+ Args:
628+ labels_true: 1D array of shape (n_samples,) with true integer class labels
629+
630+ labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels
631+
632+ beta: Weight factor. Ratios > 1.0 favor completeness, < 1.0 favor homogeneity.
633+
634+ Returns:
635+ scalar V-measure of these two sets of indices
636+ """
637+ ## Dynamically find dimensions up to a statically bounded maximum
638+ n_classes = int (jnp .max (labels_true ) + 1 )
639+ n_clusters = int (jnp .max (labels_pred ) + 1 )
640+ return _measure_v_measure_score (labels_true , labels_pred , n_classes , n_clusters , beta )
641+
642+
643+ @partial (jit , static_argnums = [2 , 3 , 4 ])
644+ def _measure_v_measure_score ( ## V-Measure
645+ labels_true : jnp .ndarray ,
646+ labels_pred : jnp .ndarray ,
647+ n_classes : int ,
648+ n_clusters : int ,
649+ beta : float = 1.0
650+ ) -> jnp .ndarray :
651+ n_samples = labels_true .shape [0 ]
652+
653+ ## Handle edge case for single or empty samples safely
654+ if n_samples <= 1 :
655+ return jnp .array (0.0 , dtype = jnp .float32 )
656+
657+ contingency = _compute_contingency_table (labels_true , labels_pred , n_classes , n_clusters )
658+
659+ ## Calculate Marginal Sums (Row and Column totals)
660+ sum_true = jnp .sum (contingency , axis = 1 )
661+ sum_pred = jnp .sum (contingency , axis = 0 )
662+
663+ ## Compute Base Entropies H(True) and H(Pred)
664+ p_true = sum_true / n_samples
665+ h_true = - jnp .sum (jnp .where (p_true > 0.0 , p_true * jnp .log (p_true ), 0.0 ))
666+
667+ p_pred = sum_pred / n_samples
668+ h_pred = - jnp .sum (jnp .where (p_pred > 0.0 , p_pred * jnp .log (p_pred ), 0.0 ))
669+
670+ ## Compute Joint Entropy H(True, Pred)
671+ p_joint = contingency / n_samples
672+ h_joint = - jnp .sum (jnp .where (p_joint > 0.0 , p_joint * jnp .log (p_joint ), 0.0 ))
673+
674+ ## Derive Conditional Entropies: H(True|Pred) and H(Pred|True) using identity rule
675+ h_true_given_pred = h_joint - h_pred
676+ h_pred_given_true = h_joint - h_true
677+
678+ ## Compute Homogeneity (H) and Completeness (C)
679+ ## If base entropy is 0, the metric is perfectly satisfied (1.0)
680+ homogeneity = jnp .where (h_true == 0.0 , 1.0 , 1.0 - (h_true_given_pred / h_true ))
681+ completeness = jnp .where (h_pred == 0.0 , 1.0 , 1.0 - (h_pred_given_true / h_pred ))
682+
683+ ## Compute Weighted Harmonic Mean (V-Measure)
684+ denominator = beta * homogeneity + completeness
685+
686+ ## Prevent division by zero if both metrics are zero
687+ v_measure = jnp .where (
688+ denominator == 0.0 ,
689+ 0.0 ,
690+ (1.0 + beta ) * homogeneity * completeness / denominator
691+ )
692+ return v_measure
0 commit comments