1+ #!/usr/bin/env python3
2+ """
3+ Multi-Label Adaptive Classifier Example
4+
5+ This example demonstrates how to use the MultiLabelAdaptiveClassifier
6+ for text classification tasks where each text can belong to multiple categories.
7+
8+ Key features demonstrated:
9+ 1. Training with multi-label data
10+ 2. Making multi-label predictions
11+ 3. Adaptive threshold handling for many labels
12+ 4. Label-specific threshold customization
13+ 5. Saving and loading multi-label models
14+ """
15+
16+ import sys
17+ import os
18+ sys .path .insert (0 , os .path .join (os .path .dirname (__file__ ), '..' , 'src' ))
19+
20+ from adaptive_classifier import MultiLabelAdaptiveClassifier
21+ import torch
22+
23+
def create_sample_data():
    """Build the small multi-label corpus used by the demonstrations.

    Returns:
        A ``(texts, labels)`` pair of equal-length lists. ``texts[i]`` is a
        headline-style sentence and ``labels[i]`` is the list of topic tags
        attached to that sentence.
    """
    # One record per sentence; every record carries all tags that apply to it.
    samples = [
        {
            "text": "Scientists discover new species of butterfly in Amazon rainforest with unique wing patterns",
            "labels": ["science", "nature", "discovery", "biology"],
        },
        {
            "text": "Tech startup raises $50M in Series A funding to develop AI-powered healthcare solutions",
            "labels": ["technology", "business", "healthcare", "funding"],
        },
        {
            "text": "Climate change impacts ocean temperature causing coral bleaching in Great Barrier Reef",
            "labels": ["environment", "climate", "nature", "science"],
        },
        {
            "text": "NBA playoffs feature exciting games with record-breaking performances by star players",
            "labels": ["sports", "entertainment", "basketball"],
        },
        {
            "text": "New renewable energy technology could reduce costs by 40% according to MIT research",
            "labels": ["technology", "science", "environment", "energy"],
        },
        {
            "text": "Archaeological team uncovers 2000-year-old Roman artifacts in excavation site",
            "labels": ["history", "science", "discovery", "archaeology"],
        },
        {
            "text": "Stock market reaches new highs as investors show confidence in economic recovery",
            "labels": ["business", "finance", "economy"],
        },
        {
            "text": "Machine learning breakthrough helps doctors diagnose rare diseases more accurately",
            "labels": ["technology", "healthcare", "science", "ai"],
        },
        {
            "text": "Wildlife conservation efforts show success in protecting endangered tiger populations",
            "labels": ["nature", "environment", "conservation", "wildlife"],
        },
        {
            "text": "Olympic athletes prepare for upcoming games with intensive training programs",
            "labels": ["sports", "olympics", "training", "fitness"],
        },
        {
            "text": "Quantum computing research makes progress toward solving complex optimization problems",
            "labels": ["technology", "science", "computing", "research"],
        },
        {
            "text": "Sustainable agriculture practices help farmers reduce environmental impact while increasing yield",
            "labels": ["environment", "agriculture", "sustainability", "farming"],
        },
        {
            "text": "Music festival features artists from diverse genres attracting thousands of fans",
            "labels": ["entertainment", "music", "culture", "events"],
        },
        {
            "text": "Space agency announces plans for Mars mission with new rocket technology",
            "labels": ["science", "space", "technology", "exploration"],
        },
        {
            "text": "Educational technology helps students learn programming through interactive online courses",
            "labels": ["education", "technology", "programming", "learning"],
        },
    ]

    # Split the records into the two parallel lists the classifier API takes.
    texts = []
    labels = []
    for record in samples:
        texts.append(record["text"])
        labels.append(record["labels"])

    return texts, labels
96+
97+
def demonstrate_basic_usage():
    """Create a classifier, train it on the sample corpus, and print statistics.

    Returns:
        The trained ``MultiLabelAdaptiveClassifier`` instance.
    """
    divider = "=" * 60
    print(divider)
    print("MULTI-LABEL ADAPTIVE CLASSIFIER - BASIC USAGE")
    print(divider)

    clf = MultiLabelAdaptiveClassifier(
        model_name="distilbert/distilbert-base-cased",
        default_threshold=0.5,
        min_predictions=1,  # ask for at least one prediction per text
        max_predictions=5,  # cap the result at the top five predictions
    )

    texts, labels = create_sample_data()

    print(f"Training with {len(texts)} examples")
    print(f"Example text: {texts[0][:60]} ...")
    print(f"Example labels: {labels[0]} ")

    # Training step: hand all sentences and their tag lists to the classifier.
    clf.add_examples(texts, labels)

    summary = clf.get_label_statistics()
    print("\n Training completed:")
    print(f"- Total labels: {summary['num_classes']} ")
    print(f"- Total examples: {summary['total_examples']} ")
    print(f"- Adaptive threshold: {summary['adaptive_threshold']:.3f} ")

    return clf
131+
132+
def demonstrate_predictions(classifier):
    """Classify a handful of fixed sentences and print each result.

    Args:
        classifier: Object exposing ``predict_multilabel(text)`` that returns
            an iterable of ``(label, confidence)`` pairs.

    Returns:
        The list of sentences that were classified.
    """
    divider = "=" * 60
    print("\n " + divider)
    print("MAKING PREDICTIONS")
    print(divider)

    test_texts = [
        "Researchers develop new AI algorithm for medical diagnosis",
        "Football team wins championship in exciting final match",
        "Solar panel efficiency increases with new manufacturing technique",
        "Ancient civilization discovered through satellite imagery analysis",
    ]

    for sentence in test_texts:
        print(f"\n Text: {sentence} ")

        scored = classifier.predict_multilabel(sentence)

        print("Predictions:")
        if not scored:
            # Nothing came back for this sentence; say so explicitly.
            print(" No predictions above threshold")
            continue
        for label, confidence in scored:
            print(f" {label} : {confidence:.4f} ")

    return test_texts
162+
163+
def demonstrate_threshold_adjustment(classifier):
    """Show how the prediction count changes as the threshold is varied.

    One fixed sentence is classified at several thresholds; for each threshold
    a table row is printed with the number of predictions and the first three
    labels returned.

    Args:
        classifier: Object exposing
            ``predict_multilabel(text, threshold=...)`` that returns a
            sequence of ``(label, confidence)`` pairs.
    """
    divider = "=" * 60
    print("\n " + divider)
    print("THRESHOLD ADJUSTMENT")
    print(divider)

    test_text = "AI researchers publish breakthrough study on climate modeling using machine learning"

    print(f"Test text: {test_text} ")

    # Table header, followed by one row per threshold.
    print(f"\n {'Threshold':<10} {'Predictions':<12} {'Labels'} ")
    print("-" * 50)

    for cutoff in [0.1, 0.3, 0.5, 0.7, 0.9]:
        hits = classifier.predict_multilabel(test_text, threshold=cutoff)
        # Only the first three labels are listed to keep the row short.
        leading = hits[:3]
        labels_str = ", ".join(name for name, _ in leading)

        print(f"{cutoff:<10.1f} {len(hits):<12} {labels_str} ")
186+
187+
def demonstrate_saving_loading(classifier):
    """Save the classifier, load it back, and classify one sentence with the copy.

    Args:
        classifier: Trained classifier exposing ``save(path)``.

    Returns:
        The object returned by ``MultiLabelAdaptiveClassifier.load`` for the
        saved path.
    """
    divider = "=" * 60
    print("\n " + divider)
    print("SAVING AND LOADING")
    print(divider)

    # Persist the model to a path relative to the current directory.
    save_path = "./multilabel_classifier"
    print(f"Saving classifier to {save_path} ")
    classifier.save(save_path)

    print("Loading classifier...")
    restored = MultiLabelAdaptiveClassifier.load(save_path)

    # Smoke-test the reloaded model on a single sentence.
    sample = "New medical technology helps treat cancer patients"

    print("\n Testing loaded classifier:")
    print(f"Text: {sample} ")

    scored = restored.predict_multilabel(sample)
    print("Predictions:")
    for label, confidence in scored:
        print(f" {label} : {confidence:.4f} ")

    return restored
216+
217+
def demonstrate_incremental_learning(classifier):
    """Add food-themed examples to an existing classifier and re-test it.

    Four extra sentences with their tag lists are passed to the classifier,
    one food-related sentence is then classified, and the label statistics
    are printed again.

    Args:
        classifier: Object exposing ``add_examples(texts, labels)``,
            ``predict_multilabel(text)`` and ``get_label_statistics()``.
    """
    divider = "=" * 60
    print("\n " + divider)
    print("INCREMENTAL LEARNING - ADDING NEW LABELS")
    print(divider)

    # Extra training sentences; extra_labels[i] holds the tags of extra_texts[i].
    extra_texts = [
        "Chef creates innovative fusion cuisine combining Asian and European flavors",
        "Food delivery service expands to new cities with sustainable packaging",
        "Restaurant industry adapts to new dining trends post-pandemic",
        "Cooking show features celebrity chefs competing in culinary challenges",
    ]
    extra_labels = [
        ["food", "cuisine", "cooking", "culture"],
        ["business", "food", "sustainability"],
        ["business", "food", "trends"],
        ["entertainment", "food", "cooking", "tv"],
    ]

    print("Adding new examples with 'food' and 'cooking' labels...")
    classifier.add_examples(extra_texts, extra_labels)

    # Probe the updated classifier with a sentence from the new topic area.
    probe = "Nutritionist recommends healthy meal planning for busy professionals"

    print("\n Testing with food-related text:")
    print(f"Text: {probe} ")

    scored = classifier.predict_multilabel(probe)
    print("Predictions:")
    for label, confidence in scored:
        print(f" {label} : {confidence:.4f} ")

    summary = classifier.get_label_statistics()
    print("\n Updated statistics:")
    print(f"- Total labels: {summary['num_classes']} ")
    print(f"- Total examples: {summary['total_examples']} ")
259+
260+
def main():
    """Run every demonstration in order and print a closing summary.

    The classifier trained by the basic-usage demo is reused for the
    prediction, threshold and save/load demos. The reloaded copy returned by
    the save/load demo is then used for the incremental-learning demo and for
    the final statistics.

    Raises:
        SystemExit: With status 1 if any step raises. The error message and
            traceback are printed first, so a failed run is visible to the
            user and detectable by the calling shell or CI job instead of
            finishing with exit status 0.
    """
    import traceback

    print("Multi-Label Adaptive Classifier Example")
    print("Fixing the 'No labels met the threshold criteria' issue\n ")

    try:
        # Basic usage
        classifier = demonstrate_basic_usage()

        # Making predictions
        demonstrate_predictions(classifier)

        # Threshold adjustment
        demonstrate_threshold_adjustment(classifier)

        # Saving and loading
        loaded_classifier = demonstrate_saving_loading(classifier)

        # Incremental learning (performed on the reloaded copy)
        demonstrate_incremental_learning(loaded_classifier)

        print("\n " + "=" * 60)
        print("EXAMPLE COMPLETED SUCCESSFULLY")
        print("=" * 60)

        # Final statistics
        final_stats = loaded_classifier.get_label_statistics()
        print("\n Final Model Statistics:")
        print(f"- Labels: {final_stats['num_classes']} ")
        print(f"- Examples: {final_stats['total_examples']} ")
        print(f"- Default threshold: {final_stats['default_threshold']} ")
        print(f"- Adaptive threshold: {final_stats['adaptive_threshold']:.3f} ")

    except Exception as e:
        # Top-level boundary: report the failure in full, then signal it
        # through the exit status rather than returning as if all went well.
        print(f"Error: {e} ")
        traceback.print_exc()
        sys.exit(1)
299+
300+
# Run the demonstrations only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
0 commit comments