Skip to content

Commit db670f5

Browse files
committed
add impl
1 parent 0202d37 commit db670f5

4 files changed

Lines changed: 982 additions & 0 deletions

File tree

examples/multilabel_usage.py

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Multi-Label Adaptive Classifier Example
4+
5+
This example demonstrates how to use the MultiLabelAdaptiveClassifier
6+
for text classification tasks where each text can belong to multiple categories.
7+
8+
Key features demonstrated:
9+
1. Training with multi-label data
10+
2. Making multi-label predictions
11+
3. Adaptive threshold handling for many labels
12+
4. Label-specific threshold customization
13+
5. Saving and loading multi-label models
14+
"""
15+
16+
import sys
17+
import os
18+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
19+
20+
from adaptive_classifier import MultiLabelAdaptiveClassifier
21+
import torch
22+
23+
24+
def create_sample_data():
    """Create sample multi-label training data.

    Returns:
        A ``(texts, labels)`` tuple of two parallel lists. ``texts[i]`` is a
        short news-style sentence and ``labels[i]`` is the list of category
        names attached to that sentence.
    """
    # Sample texts with multiple labels each
    training_data = [
        {
            "text": "Scientists discover new species of butterfly in Amazon rainforest with unique wing patterns",
            "labels": ["science", "nature", "discovery", "biology"],
        },
        {
            "text": "Tech startup raises $50M in Series A funding to develop AI-powered healthcare solutions",
            "labels": ["technology", "business", "healthcare", "funding"],
        },
        {
            "text": "Climate change impacts ocean temperature causing coral bleaching in Great Barrier Reef",
            "labels": ["environment", "climate", "nature", "science"],
        },
        {
            "text": "NBA playoffs feature exciting games with record-breaking performances by star players",
            "labels": ["sports", "entertainment", "basketball"],
        },
        {
            "text": "New renewable energy technology could reduce costs by 40% according to MIT research",
            "labels": ["technology", "science", "environment", "energy"],
        },
        {
            "text": "Archaeological team uncovers 2000-year-old Roman artifacts in excavation site",
            "labels": ["history", "science", "discovery", "archaeology"],
        },
        {
            "text": "Stock market reaches new highs as investors show confidence in economic recovery",
            "labels": ["business", "finance", "economy"],
        },
        {
            "text": "Machine learning breakthrough helps doctors diagnose rare diseases more accurately",
            "labels": ["technology", "healthcare", "science", "ai"],
        },
        {
            "text": "Wildlife conservation efforts show success in protecting endangered tiger populations",
            "labels": ["nature", "environment", "conservation", "wildlife"],
        },
        {
            "text": "Olympic athletes prepare for upcoming games with intensive training programs",
            "labels": ["sports", "olympics", "training", "fitness"],
        },
        {
            "text": "Quantum computing research makes progress toward solving complex optimization problems",
            "labels": ["technology", "science", "computing", "research"],
        },
        {
            "text": "Sustainable agriculture practices help farmers reduce environmental impact while increasing yield",
            "labels": ["environment", "agriculture", "sustainability", "farming"],
        },
        {
            "text": "Music festival features artists from diverse genres attracting thousands of fans",
            "labels": ["entertainment", "music", "culture", "events"],
        },
        {
            "text": "Space agency announces plans for Mars mission with new rocket technology",
            "labels": ["science", "space", "technology", "exploration"],
        },
        {
            "text": "Educational technology helps students learn programming through interactive online courses",
            "labels": ["education", "technology", "programming", "learning"],
        },
    ]

    # Split the records into the two parallel lists the classifier API takes.
    texts = [item["text"] for item in training_data]
    labels = [item["labels"] for item in training_data]

    return texts, labels
96+
97+
98+
def demonstrate_basic_usage():
    """Demonstrate basic multi-label classification.

    Builds a ``MultiLabelAdaptiveClassifier``, feeds it the sample data from
    ``create_sample_data`` and prints the label statistics it reports.

    Returns:
        The classifier instance, so later demonstrations can reuse it.
    """
    print("=" * 60)
    print("MULTI-LABEL ADAPTIVE CLASSIFIER - BASIC USAGE")
    print("=" * 60)

    # Create classifier
    classifier = MultiLabelAdaptiveClassifier(
        model_name="distilbert/distilbert-base-cased",
        default_threshold=0.5,
        min_predictions=1,  # Ensure at least 1 prediction
        max_predictions=5,  # Limit to top 5 predictions
    )

    # Load training data
    texts, labels = create_sample_data()

    print(f"Training with {len(texts)} examples")
    print(f"Example text: {texts[0][:60]}...")
    print(f"Example labels: {labels[0]}")

    # Train the classifier
    classifier.add_examples(texts, labels)

    # Get statistics
    stats = classifier.get_label_statistics()
    print("\nTraining completed:")
    print(f"- Total labels: {stats['num_classes']}")
    print(f"- Total examples: {stats['total_examples']}")
    print(f"- Adaptive threshold: {stats['adaptive_threshold']:.3f}")

    return classifier
131+
132+
133+
def demonstrate_predictions(classifier):
    """Demonstrate making predictions.

    Args:
        classifier: Object exposing ``predict_multilabel(text)``, which is
            iterated here as ``(label, confidence)`` pairs.

    Returns:
        The list of test sentences that were classified.
    """
    print("\n" + "=" * 60)
    print("MAKING PREDICTIONS")
    print("=" * 60)

    # Test texts
    test_texts = [
        "Researchers develop new AI algorithm for medical diagnosis",
        "Football team wins championship in exciting final match",
        "Solar panel efficiency increases with new manufacturing technique",
        "Ancient civilization discovered through satellite imagery analysis",
    ]

    for text in test_texts:
        print(f"\nText: {text}")

        # Make multi-label prediction
        predictions = classifier.predict_multilabel(text)

        print("Predictions:")
        if predictions:
            for label, confidence in predictions:
                print(f"  {label}: {confidence:.4f}")
        else:
            # An empty result means nothing was returned for this text.
            print("  No predictions above threshold")

    return test_texts
162+
163+
164+
def demonstrate_threshold_adjustment(classifier):
    """Demonstrate threshold adjustment for different scenarios.

    Classifies one fixed sentence at several thresholds and prints a small
    table with the number of predictions and the first three labels per row.

    Args:
        classifier: Object exposing
            ``predict_multilabel(text, threshold=...)``, which is iterated
            here as ``(label, confidence)`` pairs.
    """
    print("\n" + "=" * 60)
    print("THRESHOLD ADJUSTMENT")
    print("=" * 60)

    test_text = "AI researchers publish breakthrough study on climate modeling using machine learning"

    print(f"Test text: {test_text}")

    # Try different thresholds
    thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]

    print(f"\n{'Threshold':<10} {'Predictions':<12} {'Labels'}")
    print("-" * 50)

    for threshold in thresholds:
        predictions = classifier.predict_multilabel(test_text, threshold=threshold)
        # Only the first three labels are shown to keep each row short.
        labels_str = ", ".join(label for label, _ in predictions[:3])

        print(f"{threshold:<10.1f} {len(predictions):<12} {labels_str}")
186+
187+
188+
def demonstrate_saving_loading(classifier, save_path="./multilabel_classifier"):
    """Demonstrate saving and loading the model.

    Args:
        classifier: Object exposing ``save(path)``.
        save_path: Location handed to ``classifier.save`` and then to
            ``MultiLabelAdaptiveClassifier.load``. Defaults to the directory
            the example has always used.

    Returns:
        The classifier returned by ``MultiLabelAdaptiveClassifier.load``.
    """
    print("\n" + "=" * 60)
    print("SAVING AND LOADING")
    print("=" * 60)

    # Save the model
    print(f"Saving classifier to {save_path}")
    classifier.save(save_path)

    # Load the model
    print("Loading classifier...")
    loaded_classifier = MultiLabelAdaptiveClassifier.load(save_path)

    # Verify it works
    test_text = "New medical technology helps treat cancer patients"

    print("\nTesting loaded classifier:")
    print(f"Text: {test_text}")

    predictions = loaded_classifier.predict_multilabel(test_text)
    print("Predictions:")
    for label, confidence in predictions:
        print(f"  {label}: {confidence:.4f}")

    return loaded_classifier
216+
217+
218+
def demonstrate_incremental_learning(classifier):
    """Demonstrate adding new labels incrementally.

    Adds four food-related examples to the classifier, classifies one
    food-related sentence and prints the updated label statistics.

    Args:
        classifier: Object exposing ``add_examples(texts, labels)``,
            ``predict_multilabel(text)`` and ``get_label_statistics()``.
    """
    print("\n" + "=" * 60)
    print("INCREMENTAL LEARNING - ADDING NEW LABELS")
    print("=" * 60)

    # Add new examples with new labels
    new_texts = [
        "Chef creates innovative fusion cuisine combining Asian and European flavors",
        "Food delivery service expands to new cities with sustainable packaging",
        "Restaurant industry adapts to new dining trends post-pandemic",
        "Cooking show features celebrity chefs competing in culinary challenges",
    ]

    # new_labels[i] holds the labels for new_texts[i].
    new_labels = [
        ["food", "cuisine", "cooking", "culture"],
        ["business", "food", "sustainability"],
        ["business", "food", "trends"],
        ["entertainment", "food", "cooking", "tv"],
    ]

    print("Adding new examples with 'food' and 'cooking' labels...")
    classifier.add_examples(new_texts, new_labels)

    # Test with food-related text
    food_text = "Nutritionist recommends healthy meal planning for busy professionals"

    print("\nTesting with food-related text:")
    print(f"Text: {food_text}")

    predictions = classifier.predict_multilabel(food_text)
    print("Predictions:")
    for label, confidence in predictions:
        print(f"  {label}: {confidence:.4f}")

    # Show updated statistics
    stats = classifier.get_label_statistics()
    print("\nUpdated statistics:")
    print(f"- Total labels: {stats['num_classes']}")
    print(f"- Total examples: {stats['total_examples']}")
259+
260+
261+
def main():
    """Main function to run all demonstrations.

    Runs the demonstrations in order and prints the final statistics of the
    reloaded classifier.

    Raises:
        SystemExit: With status 1 if any demonstration raises, after the
            error message and traceback have been printed.
    """
    print("Multi-Label Adaptive Classifier Example")
    print("Fixing the 'No labels met the threshold criteria' issue\n")

    try:
        # Basic usage
        classifier = demonstrate_basic_usage()

        # Making predictions
        demonstrate_predictions(classifier)

        # Threshold adjustment
        demonstrate_threshold_adjustment(classifier)

        # Saving and loading
        loaded_classifier = demonstrate_saving_loading(classifier)

        # Incremental learning (performed on the reloaded classifier)
        demonstrate_incremental_learning(loaded_classifier)

        print("\n" + "=" * 60)
        print("EXAMPLE COMPLETED SUCCESSFULLY")
        print("=" * 60)

        # Final statistics
        final_stats = loaded_classifier.get_label_statistics()
        print("\nFinal Model Statistics:")
        print(f"- Labels: {final_stats['num_classes']}")
        print(f"- Examples: {final_stats['total_examples']}")
        print(f"- Default threshold: {final_stats['default_threshold']}")
        print(f"- Adaptive threshold: {final_stats['adaptive_threshold']:.3f}")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        # Report the failure through the exit status instead of ending the
        # script as if it had succeeded.
        raise SystemExit(1) from e
299+
300+
301+
# Run the demonstrations only when executed as a script, not on import.
if __name__ == "__main__":
    main()

src/adaptive_classifier/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from .classifier import AdaptiveClassifier
22
from .models import Example, AdaptiveHead, ModelConfig
33
from .memory import PrototypeMemory
4+
from .multilabel import MultiLabelAdaptiveClassifier, MultiLabelAdaptiveHead
45
from huggingface_hub import ModelHubMixin
56

67
__version__ = "0.0.17"
78

89
__all__ = [
910
"AdaptiveClassifier",
11+
"MultiLabelAdaptiveClassifier",
12+
"MultiLabelAdaptiveHead",
1013
"Example",
1114
"AdaptiveHead",
1215
"ModelConfig",

0 commit comments

Comments
 (0)