@@ -79,67 +79,30 @@ def __init__(self, dataset, **kwargs):
7979 random_state = kwargs .pop ('random_state' , None )
8080 self .random_state_ = seed_random_state (random_state )
8181
82- @inherit_docstring_from (QueryStrategy )
83- def make_query (self ):
84- dataset = self .dataset
85- unlabeled_entry_ids , X_pool = dataset .get_unlabeled_entries ()
86- X_pool = np .asarray (X_pool )
87-
88- if len (unlabeled_entry_ids ) == 0 :
89- raise ValueError ("No unlabeled samples available" )
90-
91- # Get labeled data
92- labeled_entries = dataset .get_labeled_entries ()
93- X_labeled = np .asarray (labeled_entries [0 ])
94-
95- # Fallback to random if no labeled data
96- if len (X_labeled ) == 0 :
97- idx = self .random_state_ .randint (0 , len (unlabeled_entry_ids ))
98- return unlabeled_entry_ids [idx ]
99-
100- # Transform features if transformer is provided
101- if self .transformer is not None :
102- X_pool_t = np .asarray (self .transformer .transform (X_pool ))
103- X_labeled_t = np .asarray (self .transformer .transform (X_labeled ))
104- else :
105- X_pool_t = X_pool
106- X_labeled_t = X_labeled
107-
108- # Compute pairwise distances: (n_unlabeled, n_labeled)
109- dist_matrix = cdist (X_pool_t , X_labeled_t , metric = self .metric )
110-
111- # For each unlabeled point, find minimum distance to any labeled point
112- min_distances = np .min (dist_matrix , axis = 1 )
113-
114- # Select the unlabeled point with maximum min-distance (farthest)
115- max_dist = np .max (min_distances )
116- candidates = np .where (np .isclose (min_distances , max_dist ))[0 ]
117- selected_idx = self .random_state_ .choice (candidates )
118-
119- return unlabeled_entry_ids [selected_idx ]
120-
12182 def _get_scores (self ):
12283 """Return min-distances to labeled set for all unlabeled samples.
12384
12485 Returns
12586 -------
126- scores : list of (entry_id, score) tuples
127- Each score is the minimum distance from that unlabeled point
128- to any labeled point. Higher score means more informative.
87+ entry_ids : np.ndarray, shape (n_unlabeled,)
88+ Global entry IDs of unlabeled samples.
89+ scores : np.ndarray, shape (n_unlabeled,)
90+ Min-distance from each unlabeled point to any labeled point.
91+ Higher score means more informative.
12992 """
13093 dataset = self .dataset
13194 unlabeled_entry_ids , X_pool = dataset .get_unlabeled_entries ()
13295 X_pool = np .asarray (X_pool )
13396
13497 if len (unlabeled_entry_ids ) == 0 :
135- return []
98+ return np . array ([], dtype = int ), np . array ([], dtype = float )
13699
137100 labeled_entries = dataset .get_labeled_entries ()
138101 X_labeled = np .asarray (labeled_entries [0 ])
139102
140103 if len (X_labeled ) == 0 :
141- return list ( zip ( unlabeled_entry_ids ,
142- [ float ('inf' )] * len ( unlabeled_entry_ids ) ))
104+ return np . asarray ( unlabeled_entry_ids ), \
105+ np . full ( len ( unlabeled_entry_ids ), float ('inf' ))
143106
144107 if self .transformer is not None :
145108 X_pool_t = np .asarray (self .transformer .transform (X_pool ))
@@ -151,4 +114,23 @@ def _get_scores(self):
151114 dist_matrix = cdist (X_pool_t , X_labeled_t , metric = self .metric )
152115 min_distances = np .min (dist_matrix , axis = 1 )
153116
154- return list (zip (unlabeled_entry_ids , min_distances ))
117+ return np .asarray (unlabeled_entry_ids ), min_distances
118+
119+ @inherit_docstring_from (QueryStrategy )
120+ def make_query (self ):
121+ unlabeled_entry_ids , min_distances = self ._get_scores ()
122+
123+ if len (unlabeled_entry_ids ) == 0 :
124+ raise ValueError ("No unlabeled samples available" )
125+
126+ # Fallback to random if no labeled data (scores are all inf)
127+ if np .all (np .isinf (min_distances )):
128+ idx = self .random_state_ .randint (0 , len (unlabeled_entry_ids ))
129+ return unlabeled_entry_ids [idx ]
130+
131+ # Select the unlabeled point with maximum min-distance (farthest)
132+ max_dist = np .max (min_distances )
133+ candidates = np .where (np .isclose (min_distances , max_dist ))[0 ]
134+ selected_idx = self .random_state_ .choice (candidates )
135+
136+ return unlabeled_entry_ids [selected_idx ]
0 commit comments