-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
120 lines (101 loc) · 3.24 KB
/
model.py
File metadata and controls
120 lines (101 loc) · 3.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Sentiment analysis model module using DistilBERT via HuggingFace Transformers.
Model: distilbert-base-uncased-finetuned-sst-2-english
"""
from transformers import pipeline
# Classifier label -> emoji displayed alongside the result.
EMOJI_MAP = dict(
    POSITIVE="😊",
    NEGATIVE="😞",
)
# Build the classifier eagerly, when this module is imported, so the first
# user request does not have to wait for the model to load.
# HuggingFace caches the weights locally after the first download.
print("Loading DistilBERT model…", flush=True)
_pipeline = pipeline(
    task="sentiment-analysis",
    device=-1,  # -1 selects the CPU
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
print("Model ready.", flush=True)
def _get_pipeline():
    """Hand back the sentiment-analysis pipeline built when the module loaded."""
    loaded = _pipeline
    return loaded
def analyze_text(text: str) -> dict:
    """
    Analyze the sentiment of a single text string.

    Args:
        text: The input text to classify. Non-string values are converted
            with ``str()``; ``None`` and blank/whitespace-only input are
            rejected without running the model.

    Returns:
        A dict with keys:
            label (str): "POSITIVE" or "NEGATIVE" on success, "ERROR" otherwise
            score (float): Confidence score in [0, 1]; 0.0 on error
            emoji (str): Emoji for the label ("❓" for errors/unknown labels)
            error (str | None): Error message if something went wrong, else None

    Inference failures are reported through the "error" key rather than raised.
    """
    # Normalise once so the emptiness guard and the model see the same string;
    # calling .strip() on the raw object would break for non-str input.
    cleaned = str(text).strip() if text else ""
    if not cleaned:
        return {
            "label": "ERROR",
            "score": 0.0,
            "emoji": "❓",
            "error": "Input text cannot be empty. Please enter some text to analyze.",
        }
    try:
        pipe = _get_pipeline()
        # The 512-character slice bounds the work done on very long input, but
        # the model's limit is counted in tokens, not characters, so the slice
        # alone is no guarantee. truncation=True makes the tokenizer cut the
        # sequence to the model's maximum length instead of failing.
        raw = pipe(cleaned[:512], truncation=True)[0]
        label = raw["label"]
        score = raw["score"]
    except Exception as exc:
        # UI boundary: convert any inference failure into an error result.
        return {
            "label": "ERROR",
            "score": 0.0,
            "emoji": "❓",
            "error": f"Inference failed: {exc}",
        }
    return {
        "label": label,
        "score": score,
        "emoji": EMOJI_MAP.get(label, "❓"),
        "error": None,
    }
def analyze_batch(texts: list) -> list:
    """
    Analyze the sentiment of a list of text strings.

    Each entry is classified individually so that one failing item does not
    affect the others. Falsy entries (e.g. None, "") and entries whose
    ``str()`` is blank are marked "SKIPPED" instead of being sent to the model.

    Args:
        texts: A list of text strings to classify. Non-string items are
            converted with ``str()``. ``None`` or an empty list yields ``[]``.

    Returns:
        A list of dicts, one per input item, in input order. Each dict has the
        same keys as the return value of analyze_text(); "label" is
        "POSITIVE"/"NEGATIVE" on success, "SKIPPED" for blank entries, or
        "ERROR" when inference failed for that item.
    """
    if not texts:
        return []

    # Fetched lazily inside the per-item try: a batch of only blank entries
    # never needs the model, and a failure to obtain it is reported per item
    # instead of crashing the whole batch.
    pipe = None
    results = []
    for text in texts:
        cleaned = str(text).strip() if text else ""
        if not cleaned:
            results.append({
                "label": "SKIPPED",
                "score": 0.0,
                "emoji": "❓",
                "error": "Empty or missing text",
            })
            continue
        try:
            if pipe is None:
                pipe = _get_pipeline()
            # The character slice bounds the input size; truncation=True makes
            # the tokenizer enforce the model's token limit, which a
            # character cut alone cannot guarantee.
            raw = pipe(cleaned[:512], truncation=True)[0]
            label = raw["label"]
            results.append({
                "label": label,
                "score": raw["score"],
                "emoji": EMOJI_MAP.get(label, "❓"),
                "error": None,
            })
        except Exception as exc:
            # Per-item isolation: record the failure and continue the batch.
            results.append({
                "label": "ERROR",
                "score": 0.0,
                "emoji": "❓",
                "error": str(exc),
            })
    return results