3 changes: 2 additions & 1 deletion config/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ training_settings:
weight_decay: 0.01
report_to: "wandb"
logging_strategy: "epoch"
save_strategy: "epoch"
save_strategy: "no"
eval_strategy: "epoch"
save_total_limit: 1
class_weighting: False # if True a custom loss function will be used which aims to deal with label imbalance
output_weighting: False # if True custom thresholds will be set for the logits if False threshold = 0.5 for all
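Note: with save_strategy "no" the Trainer writes no intermediate checkpoints; the model is only persisted by the explicit trainer.save_model() call at the end of train(), while eval_strategy "epoch" still runs a full evaluation pass after every epoch. A minimal sketch of how these settings land in transformers.TrainingArguments (argument names per recent transformers releases; older versions spell it evaluation_strategy):

from transformers import TrainingArguments

# illustrative values; the real ones come from train_config.yaml
args = TrainingArguments(
    output_dir="models/example",  # hypothetical path
    save_strategy="no",           # no per-epoch checkpoints
    eval_strategy="epoch",        # evaluate after each epoch
    save_total_limit=1,
    logging_strategy="epoch",
    report_to="wandb",
)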
src/train.py (212 changes: 106 additions & 106 deletions)
@@ -1,4 +1,5 @@
import argparse
+from functools import partial
import json
import os

@@ -16,11 +17,13 @@
Trainer,
TrainingArguments,
)

import wandb
from utils import load_yaml_config


os.environ["WANDB_LOG_MODEL"] = "end"
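Note: WANDB_LOG_MODEL="end" tells the Weights & Biases integration to upload the final model as an artifact once training finishes ("checkpoint" would upload at every save, which never fires here given save_strategy "no").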


def init_device():
"""Initialize device to use for training.

@@ -84,17 +87,18 @@ def __len__(self):
return len(self.labels)


-def train(train_data, test_data, model_path, config, class_counts, class_weighting):
+def train(train_data, test_data, model_path, config, class_counts, class_weighting, class_labels, test_counts):
"""Finetune a model from the config for the UKHRA data.

Args:
train_data (pd.DataFrame): Training dataframe
test_data (pd.DataFrame): Test dataframe
model_path (str): Path to save model.
config (dict): Configuration dictionary.
-
-Returns:
-dict: Evaluation metrics.
+class_counts (list): List of class counts for train dataset.
+class_weighting (bool): Whether to use class weighting in the loss function.
+class_labels (list): List of class labels.
+test_counts (list): List of class counts for test dataset.
"""
# tokenize data and create datasets
tokenizer = AutoTokenizer.from_pretrained(config["training_settings"]["model"])
@@ -162,9 +166,17 @@ def train(train_data, test_data, model_path, config, class_counts, class_weighti
save_total_limit=config["training_settings"]["save_total_limit"],
output_dir=model_path,
logging_strategy=config["training_settings"]["logging_strategy"],
+eval_strategy=config["training_settings"]["eval_strategy"]
)

-compute_metrics = prepare_compute_metrics(config)
+compute_metrics_fn = partial(
+compute_and_plot_metrics,
+config=config,
+class_labels=class_labels,
+train_counts=class_counts,
+test_counts=test_counts
+)
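Note: the Trainer invokes compute_metrics with a single EvalPrediction argument, so the extra context (config, labels, counts) has to be bound in advance; functools.partial freezes those keyword arguments. A minimal sketch of the same pattern, with a hypothetical reduced signature for illustration:

from functools import partial

def metrics_fn(eval_pred, config, class_labels):
    # stand-in for compute_and_plot_metrics
    return {"n_labels": len(class_labels)}

bound = partial(metrics_fn, config={}, class_labels=["a", "b"])
bound((None, None))  # the Trainer supplies only eval_pred at call time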

# initialize trainer depending on class weighting option
if class_weighting:
total_count = sum(class_counts)
@@ -177,7 +189,7 @@ def train(train_data, test_data, model_path, config, class_counts, class_weighti
train_dataset=train_dataset,
eval_dataset=test_dataset,
data_collator=data_collator,
-compute_metrics=compute_metrics,
+compute_metrics=compute_metrics_fn,
class_weights=class_weights,
)
else:
@@ -187,128 +199,108 @@ def train(train_data, test_data, model_path, config, class_counts, class_weighti
train_dataset=train_dataset,
eval_dataset=test_dataset,
data_collator=data_collator,
-compute_metrics=compute_metrics,
+compute_metrics=compute_metrics_fn,
)
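Note: the class-weighting branch derives per-class weights from the training counts and hands them to a Trainer subclass with a custom loss; the subclass itself is outside this diff. A hypothetical sketch of that idea for multi-label classification, assuming the weights feed BCE's pos_weight; the class name and formula are illustrative, not the repo's:

import torch
from transformers import Trainer

class WeightedLossTrainer(Trainer):  # name assumed for illustration
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights  # tensor of shape [num_labels]

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # up-weight positives of rare classes to counter label imbalance
        loss_fct = torch.nn.BCEWithLogitsLoss(pos_weight=self.class_weights)
        loss = loss_fct(outputs.logits, labels.float())
        return (loss, outputs) if return_outputs else loss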

# train and evaluate
trainer.train()
-metrics = trainer.evaluate()
+trainer.evaluate()

# save model and tokenizer
tokenizer.save_pretrained(model_path + "/tokenizer")
trainer.save_model(output_dir=model_path)

-return metrics


-def prepare_compute_metrics(config):
-"""Wrapper for compute_metrics so config can be accessed
+def compute_and_plot_metrics(eval_pred, config, class_labels, train_counts, test_counts):
+"""
+Compute evaluation metrics and plot them in wandb.

This function is used as the compute_metrics function in the Trainer,
so it is called at the end of each evaluation phase during training.

-Args: config (dict): Configuration dictionary from yaml file.
+Args:
+eval_pred (tuple): Tuple containing logits and labels.
+config (dict): Configuration dictionary from yaml file.
+class_labels (list): List of class labels.
+train_counts (list): List of class counts for train dataset.
+test_counts (list): List of class counts for test dataset.

Returns:
-function: compute_metrics function
+dict: Evaluation metrics (f1, f1_macro, f1_micro, precision, recall)
"""
+metrics = compute_metrics(eval_pred=eval_pred, config=config)
+table = plot_metrics(
+metrics=metrics,
+class_labels=class_labels,
+train_counts=train_counts,
+test_counts=test_counts
+)
+wandb.log({"metrics and value_count table": table})
+return metrics
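Note: compute_and_plot_metrics runs inside the evaluation loop, before the Trainer prefixes the returned keys with eval_, which is why plot_metrics below now reads metrics["f1"] rather than metrics["eval_f1"]. A sketch of the difference, with illustrative values:

# inside compute_metrics the keys are bare:
metrics = {"f1": [0.8, 0.7], "f1_macro": 0.75}
# while trainer.evaluate() returns them prefixed, e.g.
# {"eval_f1": [0.8, 0.7], "eval_f1_macro": 0.75, "eval_loss": ...}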

-def compute_metrics(eval_pred):
-"""Compute evaluation metrics to be used in the Trainer.

-Args: eval_pred (tuple): Tuple containing logits and labels.

-Returns:
-dict: Evaluation metrics (f1, f1_macro, f1_micro, precision,
-recall)
-"""
-logits, labels = eval_pred
-# apply sigmoid to logits
-logits = torch.sigmoid(torch.tensor(logits)).cpu().detach().numpy()

-num_tags_predicted = []
-if config["training_settings"]["output_weighting"]:
-thresholds = [1] * labels.shape[1]
-if (
-config["training_settings"]["category"] == "RA"
-or config["training_settings"]["category"] == "top_RA"
-):
-# make a list of increasing thresholds same length as the
-# number of labels
-thresholds[0] = 0.2
-thresholds[1] = 0.5
-thresholds[2] = 0.8
-thresholds[3] = 0.95
-else:
-thresholds[0] = 0.2
-thresholds[1] = 0.6
-thresholds[2] = 0.8
-thresholds[3] = 0.9

-# Prepare an array to hold your predictions
-predictions = np.zeros_like(logits)

-# Loop through each sample's logits
-for i, logit in enumerate(logits):
-# Get the indices of the logits sorted by value in descending order
-sorted_indices = np.argsort(logit)[::-1]

-# Assign 1 to the top logits that exceed their respective thresholds
-for rank, idx in enumerate(sorted_indices):
-if logit[idx] > thresholds[rank]:
-predictions[i, idx] = 1
-num_tags_predicted.append(np.sum(predictions[i]))
-else:
-predictions = np.where(logits > 0.5, 1, 0)
-for prediction in predictions:
-num_tags_predicted.append(np.sum(prediction))
+def compute_metrics(eval_pred, config):
+"""Compute evaluation metrics to be used in the Trainer.

-print(
-"num_tags_predicted: ",
-pd.Series(num_tags_predicted).value_counts().sort_index(),
-)
-log_data = pd.Series(num_tags_predicted).value_counts()
-wandb.log({"num_tags_predicted": log_data.sort_index().to_json()})

-# compute actual metrics
-f1_macro = f1_score(labels, predictions, average="macro")
-f1_micro = f1_score(labels, predictions, average="micro")
-f1 = f1_score(labels, predictions, average=None)
-precision = precision_score(labels, predictions, average=None)
-recall = recall_score(labels, predictions, average=None)
-metrics = {
-"f1": f1,
-"f1_macro": f1_macro,
-"f1_micro": f1_micro,
-"precision": precision,
-"recall": recall,
-}
-print(metrics)
-return metrics

-return compute_metrics
+Args:
+eval_pred (tuple): Tuple containing logits and labels.
+config (dict): Configuration dictionary from yaml file.

+Returns:
+dict: Evaluation metrics (f1, f1_macro, f1_micro, precision, recall)
+"""
+logits, labels = eval_pred
+# apply sigmoid to logits
+logits = torch.sigmoid(torch.tensor(logits)).cpu().detach().numpy()

+if config["training_settings"]["output_weighting"]:
+thresholds = [1] * labels.shape[1]
+category = config["training_settings"].get("category")
+if category in {"RA", "top_RA"}:
+# thresholds tuned for RA/top_RA category
+thresholds[:4] = [0.2, 0.5, 0.8, 0.95]
+else:
+thresholds[:4] = [0.2, 0.6, 0.8, 0.9]

+predictions = np.zeros_like(logits)
+for i, logit in enumerate(logits):
+sorted_indices = np.argsort(logit)[::-1]
+for rank, idx in enumerate(sorted_indices):
+if rank < len(thresholds) and logit[idx] > thresholds[rank]:
+predictions[i, idx] = 1
+else:
+predictions = np.where(logits > 0.5, 1, 0)

+f1_macro = f1_score(labels, predictions, average="macro")
+f1_micro = f1_score(labels, predictions, average="micro")
+f1 = f1_score(labels, predictions, average=None)
+precision = precision_score(labels, predictions, average=None)
+recall = recall_score(labels, predictions, average=None)
+metrics = {
+"f1": f1,
+"f1_macro": f1_macro,
+"f1_micro": f1_micro,
+"precision": precision,
+"recall": recall,
+}
+return metrics
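Note: the rank-based thresholding demands progressively more confidence for each additional tag: the top-ranked logit only needs to clear 0.2, the second 0.5 (or 0.6), and so on; since every threshold beyond the fourth stays at 1, a sample can never receive more than four tags. A worked example with the RA thresholds:

import numpy as np

probs = np.array([0.90, 0.10, 0.60, 0.30])  # sigmoid outputs for one sample
thresholds = [0.2, 0.5, 0.8, 0.95]          # RA / top_RA settings
ranked = np.argsort(probs)[::-1]            # indices: [0, 2, 3, 1]
# rank 0: probs[0] = 0.90 > 0.20 -> tag 0 predicted
# rank 1: probs[2] = 0.60 > 0.50 -> tag 2 predicted
# rank 2: probs[3] = 0.30 < 0.80 -> rejected
# rank 3: probs[1] = 0.10 < 0.95 -> rejected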

def plot_metrics(metrics, class_labels, train_counts, test_counts):
"""Plot evaluation metrics and value counts in wandb
"""Plot evaluation metrics and value counts in wandb.

Args:
metrics (dict): Evaluation metrics.
class_labels (list): List of class labels.
train_counts (list): List of class counts for train dataset.
test_counts (list): List of class counts for test dataset.

Returns:
wandb.Table: Table containing class labels, precision, recall, f1,
and value counts for train and test datasets.
"""
-# pull in label names
-with open(args.label_names_path) as f:
-label_names = {k: v for line in f for k, v in json.loads(line).items()}

-# add in the RA full name for reporting.
-if "RA" in config["training_settings"]["category"]:
-with open(args.label_names_path) as f:
-label_names = {k: v for line in f for k, v in json.loads(line).items()}
-class_labels = [f"{label}-{label_names[label]}" for label in class_labels]

f1 = metrics["eval_f1"]
precision = metrics["eval_precision"]
recall = metrics["eval_recall"]
f1 = metrics["f1"]
precision = metrics["precision"]
recall = metrics["recall"]
data = zip(
class_labels, precision, recall, f1, train_counts, test_counts, strict=False
)
@@ -317,8 +309,7 @@ def plot_metrics(metrics, class_labels, train_counts, test_counts):
data=[list(values) for values in data],
columns=["label", "precision", "recall", "f1", "train_count", "test_count"],
)

wandb.log({"metrics and value_count table": table})
return table


def run_training(args):
@@ -333,6 +324,15 @@ def run_training(args):
test_data = pd.read_parquet(args.test_path)

class_labels = list(train_data.columns[:-1])
+with open(args.label_names_path) as f:
+label_names = {k: v for line in f for k, v in json.loads(line).items()}

+# add in the RA full name for reporting.
+if "RA" in config["training_settings"]["category"]:
+class_labels = [f"{label}-{label_names[label]}" for label in class_labels]
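Note: the dict comprehension assumes label_names_path is a JSON-lines file with one small object per line, merged into a single lookup table. A sketch with made-up contents:

import json

# label_names.jsonl (path and contents illustrative):
#   {"1": "Underpinning Research"}
#   {"2": "Aetiology"}
with open("label_names.jsonl") as f:
    label_names = {k: v for line in f for k, v in json.loads(line).items()}
# label_names == {"1": "Underpinning Research", "2": "Aetiology"}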

train_counts = np.sum(train_data[train_data.columns[:-1]].to_numpy(), axis=0)
test_counts = np.sum(test_data[test_data.columns[:-1]].to_numpy(), axis=0)

@@ -344,17 +344,17 @@ def run_training(args):
wandb.log({"model_path": model_path})

class_weighting = config["training_settings"]["class_weighting"]
-metrics = train(
+train(
train_data,
test_data,
model_path=model_path,
config=config,
class_counts=train_counts,
class_weighting=class_weighting,
+class_labels=class_labels,
+test_counts=test_counts,
)

-plot_metrics(metrics, class_labels, train_counts, test_counts)

if __name__ == "__main__":
# parse arguments