-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathutils.py
More file actions
19 lines (15 loc) · 759 Bytes
/
utils.py
File metadata and controls
19 lines (15 loc) · 759 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import numpy as np
import evaluate
# Load the "accuracy" metric module from the `evaluate` library once at import
# time; `compute_metrics` below uses this module-level object by default.
# NOTE(review): this call runs as a side effect of importing the module, and
# `evaluate.load` may need to fetch the metric from the Hugging Face Hub —
# confirm this is acceptable for offline / cold-start environments.
accuracy = evaluate.load("accuracy")
def preprocess_function(tokenizer, examples):
    """Encode the ``"text"`` field of *examples* and drop the attention mask.

    The tokenizer is called with ``truncation=True`` so that strings longer
    than the maximum input token length are truncated. The
    ``attention_mask`` entry is then removed from the encoded result.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, truncation=True)``
            that returns a mutable mapping. Presumably a Hugging Face
            tokenizer — TODO confirm against callers.
        examples: Mapping with a ``"text"`` key. Presumably a batch passed by
            ``datasets.Dataset.map``; verify against the caller.

    Returns:
        The tokenizer's output mapping, without ``attention_mask`` if it was
        present.
    """
    samples = tokenizer(examples["text"], truncation=True)
    # Supply a default so tokenizers that don't emit an attention mask (e.g.
    # called with return_attention_mask=False) don't raise KeyError here.
    # NOTE(review): why the mask is dropped isn't visible in this file —
    # presumably the downstream model/collator must not receive it; confirm.
    samples.pop('attention_mask', None)
    return samples
def compute_metrics(eval_pred, metric=None):
    """Score predicted classes against reference labels.

    Args:
        eval_pred: A ``(predictions, labels)`` pair that supports tuple
            unpacking. This is presumably the ``EvalPrediction`` passed by a
            Hugging Face ``Trainer`` — verify against the caller.
            ``predictions`` holds per-class scores with the class dimension
            last, presumably shape ``(n_samples, n_classes)``. If it is a
            tuple (models returning several outputs), its first element is
            used as the scores.
        metric: Object exposing ``compute(predictions=..., references=...)``.
            Defaults to the module-level ``accuracy`` metric, so existing
            single-argument callers behave as before.

    Returns:
        The result of ``metric.compute`` for the argmax predictions.
    """
    if metric is None:
        metric = accuracy
    predictions, labels = eval_pred
    # Some models return a tuple of outputs; assume the first entry holds the
    # class scores (common Trainer convention — confirm for this model).
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # Predicted class = index of the highest score along the class axis
    # (last axis; identical to axis=1 for 2-D score arrays).
    predicted_classes = np.argmax(predictions, axis=-1)
    return metric.compute(predictions=predicted_classes, references=labels)