Skip to content

Commit 1740d91

Browse files
Add model evaluation example (#126)
Co-authored-by: Haihui.Wang <wanghh2000@163.com>
1 parent c577a2c commit 1740d91

1 file changed

Lines changed: 148 additions & 0 deletions

File tree

examples/model_eval.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import torch
2+
from transformers import BertTokenizer, BertForSequenceClassification
3+
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score
4+
import pandas as pd
5+
import logging
6+
from pycsghub.snapshot_download import snapshot_download
7+
8+
# pip install -U scikit-learn
9+
10+
def DownloadModel(local_path: str):
    """Download the sentiment model repository from the OpenCSG hub.

    Configures root logging (INFO level, timestamped format, stream handler),
    then calls ``snapshot_download`` for the fixed model repository and prints
    the value it returns.

    Args:
        local_path: Directory passed to ``snapshot_download`` as ``local_dir``.
    """
    # Configure logging before the download starts.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler()],
    )

    download_options = {
        "repo_type": "model",
        "local_dir": local_path,
        "endpoint": "https://hub.opencsg.com",
        # No access token by default; replace None with your token string
        # if the hub requires one.
        "token": None,
    }
    saved_to = snapshot_download(
        "wanghh2000/Erlangshen-RoBERTa-110M-Sentiment",
        **download_options,
    )

    print(f"Save model to {saved_to}")
34+
35+
class BERTEvaluator:
    """Evaluate a BERT sequence-classification model on text inputs.

    The constructor loads the tokenizer and model from ``model_path``, moves
    the model to CUDA when available (CPU otherwise) and switches it to
    evaluation mode. Labels are treated as binary: 0 = negative, 1 = positive.
    """

    def __init__(self, model_path):
        """
        Initialize the Evaluator
        model_path: Path to the trained model
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        print(f"Model loaded from {model_path}")
        self.model.to(self.device)
        self.model.eval()
        print("Set model to evaluation mode")

    @staticmethod
    def _count_batches(num_items, batch_size):
        """Return the number of batches needed for ``num_items`` (ceiling division)."""
        return (num_items + batch_size - 1) // batch_size

    def predict_single(self, text, max_length=128):
        """Predict a single text.

        Args:
            text: The text to classify.
            max_length: Token limit passed to the tokenizer; longer inputs
                are truncated.

        Returns:
            A ``(prediction, probabilities)`` tuple: the argmax class index as
            a Python int, and the softmax probabilities of the first row of
            the logits as a NumPy array.
        """
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask=attention_mask).logits

        prediction = torch.argmax(logits, dim=-1).item()
        probabilities = torch.softmax(logits, dim=-1).cpu().numpy()[0]

        return prediction, probabilities

    def evaluate_batch(self, texts, true_labels, batch_size=16, max_length=128):
        """Evaluate the model on a collection of texts, processed in batches.

        Args:
            texts: Sequence of input texts.
            true_labels: Ground-truth labels (0 = negative, 1 = positive),
                one per text.
            batch_size: Number of texts tokenized and scored per forward pass.
            max_length: Token limit passed to the tokenizer; longer inputs
                are truncated.

        Returns:
            A dict with keys ``accuracy``, ``precision``, ``recall``,
            ``f1_score`` (binary averaging, positive class = 1),
            ``predictions`` (list of predicted class indices) and
            ``classification_report`` (formatted text report).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 or ``texts`` and
                ``true_labels`` differ in length.
        """
        # Fail fast, before spending time on inference.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if len(texts) != len(true_labels):
            raise ValueError(
                f"texts and true_labels must have the same length, "
                f"got {len(texts)} and {len(true_labels)}"
            )

        all_predictions = []
        # Ceiling division, so an exact multiple of batch_size does not
        # report one batch too many.
        total_batches = self._count_batches(len(texts), batch_size)

        # Process the texts in batches
        for batch_no, start in enumerate(range(0, len(texts), batch_size), start=1):
            print(f"Processing batch {batch_no}/{total_batches}")
            batch_texts = texts[start:start + batch_size]

            # Encode the batch of texts
            encodings = self.tokenizer(
                batch_texts,
                truncation=True,
                padding=True,
                max_length=max_length,
                return_tensors='pt'
            )

            input_ids = encodings['input_ids'].to(self.device)
            attention_mask = encodings['attention_mask'].to(self.device)

            # Predict the batch of texts
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=-1)
            all_predictions.extend(predictions.cpu().numpy())

        print(f"Processed {len(texts)} texts and started calculating metrics")
        # Calculate metrics
        accuracy = accuracy_score(true_labels, all_predictions)
        precision = precision_score(true_labels, all_predictions, average='binary')
        recall = recall_score(true_labels, all_predictions, average='binary')
        f1 = f1_score(true_labels, all_predictions, average='binary')

        # Generate a detailed classification report. Passing labels explicitly
        # keeps the two target names valid even when only one class occurs in
        # the labels and predictions.
        report = classification_report(
            true_labels,
            all_predictions,
            labels=[0, 1],
            target_names=['negative', 'positive'],
        )

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'predictions': all_predictions,
            'classification_report': report
        }
117+
118+
if __name__ == "__main__":
    import sys
    from pathlib import Path

    # Where to store the model: the first command-line argument if given,
    # otherwise a folder under the current user's home directory (rather
    # than a path tied to one developer's machine).
    if len(sys.argv) > 1:
        model_path = sys.argv[1]
    else:
        model_path = str(Path.home() / "temp" / "Erlangshen-RoBERTa-110M-Sentiment")
    DownloadModel(local_path=model_path)

    # Initialize the Evaluator
    # model_path: Path to the trained model
    evaluator = BERTEvaluator(model_path=model_path)

    # Prepare test data
    test_texts = [
        "This movie is great!",
        "This service is terrible!",
        "The product quality is good, I recommend recommend it!",
        "The delivery is slow, I am not satisfied with it.",
    ]
    test_labels = [1, 0, 1, 0]  # 1: positive, 0: negative

    # Evaluate the model on the test data in batches
    results = evaluator.evaluate_batch(test_texts, test_labels)

    # Print the results
    print("\n" + "="*60)
    print("BERT Model Evaluation Results ")
    print("="*60)
    # Accuracy: the proportion of correct predictions among all predictions.
    print(f"Accuracy (Accuracy): {results['accuracy']:.4f}")
    # Precision: among all samples predicted as positive, how many are truly positive.
    print(f"Precision (Precision): {results['precision']:.4f}")
    # Recall: among all truly positive samples, how many were predicted as positive.
    print(f"Recall (Recall): {results['recall']:.4f}")
    # F1: the harmonic mean of precision and recall; high only when both are high.
    print(f"F1-Score (F1-Score): {results['f1_score']:.4f}")
    print("\nDetailed Classification Report:")
    print(results['classification_report'])

0 commit comments

Comments
 (0)