|
| 1 | +import torch |
| 2 | +from transformers import BertTokenizer, BertForSequenceClassification |
| 3 | +from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score |
| 4 | +import pandas as pd |
| 5 | +import logging |
| 6 | +from pycsghub.snapshot_download import snapshot_download |
| 7 | + |
| 8 | +# pip install -U scikit-learn |
| 9 | + |
| 10 | +def DownloadModel(local_path: str): |
| 11 | + # token = "your access token" |
| 12 | + token = None |
| 13 | + endpoint = "https://hub.opencsg.com" |
| 14 | + repo_type = "model" |
| 15 | + repo_id = "wanghh2000/Erlangshen-RoBERTa-110M-Sentiment" |
| 16 | + local_dir = local_path |
| 17 | + |
| 18 | + # set log level |
| 19 | + logging.basicConfig( |
| 20 | + level=getattr(logging, "INFO"), |
| 21 | + format='%(asctime)s - %(levelname)s - %(message)s', |
| 22 | + datefmt='%Y-%m-%d %H:%M:%S', |
| 23 | + handlers=[logging.StreamHandler()] |
| 24 | + ) |
| 25 | + |
| 26 | + result = snapshot_download( |
| 27 | + repo_id, |
| 28 | + repo_type=repo_type, |
| 29 | + local_dir=local_dir, |
| 30 | + endpoint=endpoint, |
| 31 | + token=token) |
| 32 | + |
| 33 | + print(f"Save model to {result}") |
| 34 | + |
| 35 | +class BERTEvaluator: |
| 36 | + def __init__(self, model_path): |
| 37 | + """ |
| 38 | + Initialize the Evaluator |
| 39 | + model_path: Path to the trained model |
| 40 | + """ |
| 41 | + # model_name='bert-base-chinese' |
| 42 | + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 43 | + self.tokenizer = BertTokenizer.from_pretrained(model_path) |
| 44 | + self.model = BertForSequenceClassification.from_pretrained(model_path) |
| 45 | + print(f"Model loaded from {model_path}") |
| 46 | + self.model.to(self.device) |
| 47 | + self.model.eval() |
| 48 | + print(f"Set model to evaluation mode") |
| 49 | + |
| 50 | + def predict_single(self, text, max_length=128): |
| 51 | + """Predict a single text""" |
| 52 | + encoding = self.tokenizer( |
| 53 | + text, |
| 54 | + truncation=True, |
| 55 | + padding=True, |
| 56 | + max_length=max_length, |
| 57 | + return_tensors='pt' |
| 58 | + ) |
| 59 | + |
| 60 | + input_ids = encoding['input_ids'].to(self.device) |
| 61 | + attention_mask = encoding['attention_mask'].to(self.device) |
| 62 | + |
| 63 | + with torch.no_grad(): |
| 64 | + outputs = self.model(input_ids, attention_mask=attention_mask) |
| 65 | + logits = outputs.logits |
| 66 | + prediction = torch.argmax(logits, dim=-1).item() |
| 67 | + probabilities = torch.softmax(logits, dim=-1).cpu().numpy()[0] |
| 68 | + |
| 69 | + return prediction, probabilities |
| 70 | + |
| 71 | + def evaluate_batch(self, texts, true_labels, batch_size=16): |
| 72 | + """Evaluate the model on a batch of texts""" |
| 73 | + all_predictions = [] |
| 74 | + |
| 75 | + # Process the batch of texts in batches |
| 76 | + for i in range(0, len(texts), batch_size): |
| 77 | + print(f"Processing batch {i//batch_size+1}/{len(texts)//batch_size+1}") |
| 78 | + batch_texts = texts[i:i+batch_size] |
| 79 | + # batch_labels = true_labels[i:i+batch_size] |
| 80 | + |
| 81 | + # Encode the batch of texts |
| 82 | + encodings = self.tokenizer( |
| 83 | + batch_texts, |
| 84 | + truncation=True, |
| 85 | + padding=True, |
| 86 | + max_length=128, |
| 87 | + return_tensors='pt' |
| 88 | + ) |
| 89 | + |
| 90 | + input_ids = encodings['input_ids'].to(self.device) |
| 91 | + attention_mask = encodings['attention_mask'].to(self.device) |
| 92 | + |
| 93 | + # Predict the batch of texts |
| 94 | + with torch.no_grad(): |
| 95 | + outputs = self.model(input_ids, attention_mask=attention_mask) |
| 96 | + predictions = torch.argmax(outputs.logits, dim=-1) |
| 97 | + all_predictions.extend(predictions.cpu().numpy()) |
| 98 | + |
| 99 | + print(f"Processed {len(texts)} texts and started calculating metrics") |
| 100 | + # Calculate metrics |
| 101 | + accuracy = accuracy_score(true_labels, all_predictions) |
| 102 | + precision = precision_score(true_labels, all_predictions, average='binary') |
| 103 | + recall = recall_score(true_labels, all_predictions, average='binary') |
| 104 | + f1 = f1_score(true_labels, all_predictions, average='binary') |
| 105 | + |
| 106 | + # Generate a detailed classification report |
| 107 | + report = classification_report(true_labels, all_predictions, target_names=['negative', 'positive']) |
| 108 | + |
| 109 | + return { |
| 110 | + 'accuracy': accuracy, |
| 111 | + 'precision': precision, |
| 112 | + 'recall': recall, |
| 113 | + 'f1_score': f1, |
| 114 | + 'predictions': all_predictions, |
| 115 | + 'classification_report': report |
| 116 | + } |
| 117 | + |
| 118 | +if __name__ == "__main__": |
| 119 | + model_path = "/Users/hhwang/temp/Erlangshen-RoBERTa-110M-Sentiment" |
| 120 | + DownloadModel(local_path=model_path) |
| 121 | + |
| 122 | + # Initialize the Evaluator |
| 123 | + # model_path: Path to the trained model |
| 124 | + evaluator = BERTEvaluator(model_path=model_path) |
| 125 | + # Prepare test data |
| 126 | + test_texts = [ |
| 127 | + "This movie is great!", |
| 128 | + "This service is terrible!", |
| 129 | + "The product quality is good, I recommend recommend it!", |
| 130 | + "The delivery is slow, I am not satisfied with it.", |
| 131 | + ] |
| 132 | + test_labels = [1, 0, 1, 0] # 1: positive, 0: negative |
| 133 | + # Evaluate the model on the test data in batches |
| 134 | + results = evaluator.evaluate_batch(test_texts, test_labels) |
| 135 | + # Print the results |
| 136 | + print("\n" + "="*60) |
| 137 | + print("BERT Model Evaluation Results ") |
| 138 | + print("="*60) |
| 139 | + # Accuracy: The proportion of correct predictions among all predictions. Measures the overall correctness rate. |
| 140 | + print(f"Accuracy (Accuracy): {results['accuracy']:.4f}") |
| 141 | + # Among all samples predicted as positive by the model, how many are truly positive. Measures how accurate the model is when it predicts positive. |
| 142 | + print(f"Precision (Precision): {results['precision']:.4f}") |
| 143 | + # Among all samples that are truly positive, how many are correctly predicted as positive. Measures how complete the model is when it predicts positive. |
| 144 | + print(f"Recall (Recall): {results['recall']:.4f}") |
| 145 | + # Used to balance precision and recall. Measures whether the model's performance is balanced between positive and negative classes. |
| 146 | + print(f"F1-Score (F1-Score): {results['f1_score']:.4f}") |
| 147 | + print("\nDetailed Classification Report:") |
| 148 | + print(results['classification_report']) |
0 commit comments