Skip to content

Commit 62ee693

Browse files
committed
add a simple explainability example
1 parent 83108b2 commit 62ee693

1 file changed

Lines changed: 232 additions & 0 deletions

File tree

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
"""
2+
Simple Explainability Example with ASCII Visualization
3+
"""
4+
5+
import numpy as np
6+
import sys
7+
from torchTextClassifiers import create_fasttext
8+
9+
10+
def _print_word_histogram(text, all_scores, show_top_word=False):
    """Render an ASCII histogram of per-word contribution scores.

    Args:
        text: Input sentence; its whitespace-split tokens label the bars.
        all_scores: Explainability scores as returned by
            ``predict_and_explain`` for a single input (read via
            ``all_scores[0][0]``); may be None when unavailable.
        show_top_word: When True, additionally print the highest-scoring
            word (used by the interactive mode).
    """
    if all_scores is None or len(all_scores) == 0:
        print("⚠️ No explainability scores available")
        return

    scores_data = all_scores[0][0]
    # Scores may arrive as an array-like (tensor/ndarray) or a bare scalar.
    if hasattr(scores_data, 'tolist'):
        scores = scores_data.tolist()
    else:
        scores = [float(scores_data)]

    words = text.split()
    if len(words) != len(scores):
        print(f"⚠️ Word/score mismatch: {len(words)} words vs {len(scores)} scores")
        return

    print("\n📊 Word Contribution Histogram:")
    print("-" * 50)

    # Scale bars relative to the largest score so the widest bar spans
    # the full bar width.
    max_score = max(scores) if scores else 1
    bar_width = 30  # max bar width in characters

    for word, score in zip(words, scores):
        bar_length = int((score / max_score) * bar_width)
        bar = "█" * bar_length
        print(f"{word:>12} | {bar:<30} {score:.4f}")

    print("-" * 50)

    if show_top_word:
        top_word = max(zip(words, scores), key=lambda x: x[1])
        print(f"💡 Most influential word: '{top_word[0]}' (score: {top_word[1]:.4f})")


def main():
    """Train a small fastText sentiment classifier and show word-level explanations.

    Trains on a tiny hard-coded positive/negative dataset, explains a fixed
    set of test sentences, and — when ``--interactive`` is passed on the
    command line — enters a REPL that explains user-supplied text.
    """
    print("🔍 Simple Explainability Example")

    # Enhanced training data with more diverse examples
    X_train = np.array([
        # Positive examples
        "I love this product",
        "Great quality and excellent service",
        "Amazing design and fantastic performance",
        "Outstanding value for money",
        "Excellent customer support team",
        "Love the innovative features",
        "Perfect solution for my needs",
        "Highly recommend this item",
        "Superb build quality",
        "Wonderful experience overall",
        "Great value and fast delivery",
        "Excellent product with amazing results",
        "Love this fantastic design",
        "Perfect quality and great price",
        "Amazing customer service experience",

        # Negative examples
        "This is terrible quality",
        "Poor design and cheap materials",
        "Awful experience with this product",
        "Terrible customer service response",
        "Completely disappointing purchase",
        "Poor quality and overpriced item",
        "Awful build quality issues",
        "Terrible value for money",
        "Disappointing performance results",
        "Poor service and bad experience",
        "Awful design and cheap feel",
        "Terrible product with many issues",
        "Disappointing quality and poor value",
        "Bad experience with customer support",
        "Poor construction and awful materials"
    ])

    y_train = np.array([
        # Positive labels (1)
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        # Negative labels (0)
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
    ])

    X_val = np.array([
        "Good product with decent quality",
        "Bad quality and poor service",
        "Excellent value and great design",
        "Terrible experience and awful quality"
    ])
    y_val = np.array([1, 0, 1, 0])

    # Create classifier
    classifier = create_fasttext(
        embedding_dim=50,
        sparse=False,
        num_tokens=1000,
        min_count=1,
        min_n=3,
        max_n=6,
        len_word_ngrams=2,
        num_classes=2,
        direct_bagging=False  # Required for explainability
    )

    # Train
    classifier.build(X_train, y_train)
    classifier.train(X_train, y_train, X_val, y_val, num_epochs=25, batch_size=8, verbose=False)

    # Test examples with different sentiments
    test_texts = [
        "This product is amazing!",
        "Poor quality and terrible service",
        "Great value for money",
        "Completely disappointing and awful experience",
        "Love this excellent design"
    ]

    print(f"\n🔍 Testing explainability on {len(test_texts)} examples:")
    print("=" * 60)

    for i, test_text in enumerate(test_texts, 1):
        print(f"\n📝 Example {i}:")
        print(f"Text: '{test_text}'")

        # Get prediction
        prediction = classifier.predict(np.array([test_text]))[0]
        print(f"Prediction: {'Positive' if prediction == 1 else 'Negative'}")

        # Get explainability scores and render them
        try:
            _, _, all_scores, _ = classifier.predict_and_explain(np.array([test_text]))
            _print_word_histogram(test_text, all_scores)
        except Exception as e:
            print(f"⚠️ Explainability failed: {e}")

        # Analysis completed for this example
        print(f"✅ Analysis completed for example {i}")

    print(f"\n🎉 Explainability analysis completed for {len(test_texts)} examples!")

    # Interactive section for user input (only if --interactive flag is provided)
    if "--interactive" in sys.argv:
        print("\n" + "="*60)
        print("🎯 Interactive Explainability Mode")
        print("="*60)
        print("Enter your own text to see predictions and explanations!")
        print("Type 'quit' or 'exit' to end the session.\n")

        while True:
            try:
                user_text = input("💬 Enter text: ").strip()

                if user_text.lower() in ['quit', 'exit', 'q']:
                    print("👋 Thanks for using the explainability tool!")
                    break

                if not user_text:
                    print("⚠️ Please enter some text.")
                    continue

                print(f"\n🔍 Analyzing: '{user_text}'")

                # Get prediction
                prediction = classifier.predict(np.array([user_text]))[0]
                sentiment = "Positive" if prediction == 1 else "Negative"
                print(f"🎯 Prediction: {sentiment}")

                # Get explainability scores; interactive mode also surfaces
                # the single most influential word.
                try:
                    _, _, all_scores, _ = classifier.predict_and_explain(np.array([user_text]))
                    _print_word_histogram(user_text, all_scores, show_top_word=True)
                except Exception as e:
                    print(f"⚠️ Explainability failed: {e}")
                    print("🔍 Prediction available, but detailed explanation unavailable.")

                print("\n" + "-"*50)

            except KeyboardInterrupt:
                print("\n👋 Session interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"⚠️ Error: {e}")
                continue
    else:
        print("\n💡 Tip: Use --interactive flag to enter interactive mode for custom text analysis!")
        print(" Example: uv run python examples/simple_explainability_example.py --interactive")
229+
230+
231+
# Script entry point: run the example only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)