diff --git a/config/train_config.yaml b/config/train_config.yaml
index 3077e59..40281ce 100644
--- a/config/train_config.yaml
+++ b/config/train_config.yaml
@@ -12,7 +12,8 @@ training_settings:
   weight_decay: 0.01
   report_to: "wandb"
   logging_strategy: "epoch"
-  save_strategy: "epoch"
+  save_strategy: "no"
+  eval_strategy: "epoch"
   save_total_limit: 1
   class_weighting: False # if True a custom loss function will be used which aims to deal with label imbalance
   output_weighting: False # if True custom thresholds will be set for the logits if False threshold = 0.5 for all
diff --git a/src/train.py b/src/train.py
index 3136d43..1461e22 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,4 +1,5 @@
 import argparse
+from functools import partial
 import json
 import os
@@ -16,11 +17,13 @@
     Trainer,
     TrainingArguments,
 )
-
 import wandb
 
 from utils import load_yaml_config
 
+os.environ["WANDB_LOG_MODEL"] = "end"
+
+
 def init_device():
     """Initialize device to use for training.
@@ -84,7 +87,7 @@ def __len__(self):
         return len(self.labels)
 
 
-def train(train_data, test_data, model_path, config, class_counts, class_weighting):
+def train(train_data, test_data, model_path, config, class_counts, class_weighting, class_labels, test_counts):
     """Finetune a model from the config for the UKHRA data.
 
     Args:
@@ -92,9 +95,10 @@ def train(train_data, test_data, model_path, config, class_counts, class_weighti
         test_data (pd.DataFrame): Test dataframe
         model_path (str): Path to save model.
         config (dict): Configuration dictionary.
-
-    Returns:
-        dict: Evaluation metrics.
+        class_counts (list): List of class counts for train dataset.
+        class_weighting (bool): Whether to use class weighting in the loss function.
+        class_labels (list): List of class labels.
+        test_counts (list): List of class counts for test dataset.
""" # tokenize data and create datasets tokenizer = AutoTokenizer.from_pretrained(config["training_settings"]["model"]) @@ -162,9 +166,17 @@ def train(train_data, test_data, model_path, config, class_counts, class_weighti save_total_limit=config["training_settings"]["save_total_limit"], output_dir=model_path, logging_strategy=config["training_settings"]["logging_strategy"], + eval_strategy=config["training_settings"]["eval_strategy"] ) - compute_metrics = prepare_compute_metrics(config) + compute_metrics_fn = partial( + compute_and_plot_metrics, + config=config, + class_labels=class_labels, + train_counts=class_counts, + test_counts=test_counts + ) + # initialize trainer depending on class weighting option if class_weighting: total_count = sum(class_counts) @@ -177,7 +189,7 @@ def train(train_data, test_data, model_path, config, class_counts, class_weighti train_dataset=train_dataset, eval_dataset=test_dataset, data_collator=data_collator, - compute_metrics=compute_metrics, + compute_metrics=compute_metrics_fn, class_weights=class_weights, ) else: @@ -187,107 +199,94 @@ def train(train_data, test_data, model_path, config, class_counts, class_weighti train_dataset=train_dataset, eval_dataset=test_dataset, data_collator=data_collator, - compute_metrics=compute_metrics, + compute_metrics=compute_metrics_fn, ) # train and evaluate trainer.train() - metrics = trainer.evaluate() + trainer.evaluate() # save model and tokenizer tokenizer.save_pretrained(model_path + "/tokenizer") trainer.save_model(output_dir=model_path) - return metrics - -def prepare_compute_metrics(config): - """Wrapper for compute_metrics so config can be accessed +def compute_and_plot_metrics(eval_pred, config, class_labels, train_counts, test_counts): + """ + Compute evaluatioon metrics and plot them in wandb. + + This function is used as the compute_metrics function in the Trainer, + so it is called at the end of each evaluation phase during training. - Args: config (dict): Configuration dictionary from yaml file. + Args: + eval_pred (tuple): Tuple containing logits and labels. + config (dict): Configuration dictionary from yaml file. + class_labels (list): List of class labels. + train_counts (list): List of class counts for train dataset. + test_counts (list): List of class counts for test dataset. Returns: - function: compute_metrics function + dict: Evaluation metrics (f1, f1_macro, f1_micro, precision, recall) """ + metrics = compute_metrics(eval_pred=eval_pred, config=config) + table = plot_metrics( + metrics=metrics, + class_labels=class_labels, + train_counts=train_counts, + test_counts=test_counts + ) + wandb.log({"metrics and value_count table": table}) + return metrics - def compute_metrics(eval_pred): - """Compute evaluation metrics to be used in the Trainer. - - Args: eval_pred (tuple): Tuple containing logits and labels. 
-        Returns:
-            dict: Evaluation metrics (f1, f1_macro, f1_micro, precision,
-            recall)
-        """
-        logits, labels = eval_pred
-        # apply sigmoid to logits
-        logits = torch.sigmoid(torch.tensor(logits)).cpu().detach().numpy()
-
-        num_tags_predicted = []
-        if config["training_settings"]["output_weighting"]:
-            thresholds = [1] * labels.shape[1]
-            if (
-                config["training_settings"]["category"] == "RA"
-                or config["training_settings"]["category"] == "top_RA"
-            ):
-                # make a list of increasing thresholds same length as the
-                # number of labels
-                thresholds[0] = 0.2
-                thresholds[1] = 0.5
-                thresholds[2] = 0.8
-                thresholds[3] = 0.95
-            else:
-                thresholds[0] = 0.2
-                thresholds[1] = 0.6
-                thresholds[2] = 0.8
-                thresholds[3] = 0.9
-
-            # Prepare an array to hold your predictions
-            predictions = np.zeros_like(logits)
-
-            # Loop through each sample's logits
-            for i, logit in enumerate(logits):
-                # Get the indices of the logits sorted by value in descending order
-                sorted_indices = np.argsort(logit)[::-1]
-
-                # Assign 1 to the top logits that exceed their respective thresholds
-                for rank, idx in enumerate(sorted_indices):
-                    if logit[idx] > thresholds[rank]:
-                        predictions[i, idx] = 1
-                num_tags_predicted.append(np.sum(predictions[i]))
-        else:
-            predictions = np.where(logits > 0.5, 1, 0)
-            for prediction in predictions:
-                num_tags_predicted.append(np.sum(prediction))
+def compute_metrics(eval_pred, config):
+    """Compute evaluation metrics to be used in the Trainer.
 
-        print(
-            "num_tags_predicted: ",
-            pd.Series(num_tags_predicted).value_counts().sort_index(),
-        )
-        log_data = pd.Series(num_tags_predicted).value_counts()
-        wandb.log({"num_tags_predicted": log_data.sort_index().to_json()})
-
-        # compute actual metrics
-        f1_macro = f1_score(labels, predictions, average="macro")
-        f1_micro = f1_score(labels, predictions, average="micro")
-        f1 = f1_score(labels, predictions, average=None)
-        precision = precision_score(labels, predictions, average=None)
-        recall = recall_score(labels, predictions, average=None)
-        metrics = {
-            "f1": f1,
-            "f1_macro": f1_macro,
-            "f1_micro": f1_micro,
-            "precision": precision,
-            "recall": recall,
-        }
-        print(metrics)
-        return metrics
-
-    return compute_metrics
+    Args:
+        eval_pred (tuple): Tuple containing logits and labels.
+        config (dict): Configuration dictionary from yaml file.
+
+    Returns:
+        dict: Evaluation metrics (f1, f1_macro, f1_micro, precision, recall)
+    """
+    logits, labels = eval_pred
+    # apply sigmoid to logits
+    logits = torch.sigmoid(torch.tensor(logits)).cpu().detach().numpy()
+
+    if config["training_settings"]["output_weighting"]:
+        thresholds = [1] * labels.shape[1]
+        category = config["training_settings"].get("category")
+        if category in {"RA", "top_RA"}:
+            # thresholds tuned for RA/top_RA category
+            thresholds[:4] = [0.2, 0.5, 0.8, 0.95]
+        else:
+            thresholds[:4] = [0.2, 0.6, 0.8, 0.9]
+
+        predictions = np.zeros_like(logits)
+        for i, logit in enumerate(logits):
+            sorted_indices = np.argsort(logit)[::-1]
+            for rank, idx in enumerate(sorted_indices):
+                if rank < len(thresholds) and logit[idx] > thresholds[rank]:
+                    predictions[i, idx] = 1
+    else:
+        predictions = np.where(logits > 0.5, 1, 0)
+
+    f1_macro = f1_score(labels, predictions, average="macro")
+    f1_micro = f1_score(labels, predictions, average="micro")
+    f1 = f1_score(labels, predictions, average=None)
+    precision = precision_score(labels, predictions, average=None)
+    recall = recall_score(labels, predictions, average=None)
+    metrics = {
+        "f1": f1,
+        "f1_macro": f1_macro,
+        "f1_micro": f1_micro,
+        "precision": precision,
+        "recall": recall,
+    }
+    return metrics
 
 
 def plot_metrics(metrics, class_labels, train_counts, test_counts):
-    """Plot evaluation metrics and value counts in wandb
+    """Plot evaluation metrics and value counts in wandb.
 
     Args:
         metrics (dict): Evaluation metrics.
@@ -295,20 +294,13 @@
         train_counts (list): List of class counts for train dataset.
         test_counts (list): List of class counts for test dataset.
+
+    Returns:
+        wandb.Table: Table containing class labels, precision, recall, f1,
+            and value counts for train and test datasets.
     """
-    # pull in label names
-    with open(args.label_names_path) as f:
-        label_names = {k: v for line in f for k, v in json.loads(line).items()}
-
-    # add in the RA full name for reporting.
-    if "RA" in config["training_settings"]["category"]:
-        with open(args.label_names_path) as f:
-            label_names = {k: v for line in f for k, v in json.loads(line).items()}
-        class_labels = [f"{label}-{label_names[label]}" for label in class_labels]
-
-    f1 = metrics["eval_f1"]
-    precision = metrics["eval_precision"]
-    recall = metrics["eval_recall"]
+    f1 = metrics["f1"]
+    precision = metrics["precision"]
+    recall = metrics["recall"]
     data = zip(
         class_labels, precision, recall, f1, train_counts, test_counts, strict=False
    )
@@ -317,8 +309,7 @@
         data=[list(values) for values in data],
         columns=["label", "precision", "recall", "f1", "train_count", "test_count"],
     )
-
-    wandb.log({"metrics and value_count table": table})
+    return table
 
 
 def run_training(args):
@@ -333,6 +324,13 @@
     test_data = pd.read_parquet(args.test_path)
     class_labels = list(train_data.columns[:-1])
 
+    with open(args.label_names_path) as f:
+        label_names = {k: v for line in f for k, v in json.loads(line).items()}
+
+    # add in the RA full name for reporting.
+ if "RA" in config["training_settings"]["category"]: + with open(args.label_names_path) as f: + label_names = {k: v for line in f for k, v in json.loads(line).items()} + class_labels = [f"{label}-{label_names[label]}" for label in class_labels] + train_counts = np.sum(train_data[train_data.columns[:-1]].to_numpy(), axis=0) test_counts = np.sum(test_data[test_data.columns[:-1]].to_numpy(), axis=0) @@ -344,17 +344,17 @@ def run_training(args): wandb.log({"model_path": model_path}) class_weighting = config["training_settings"]["class_weighting"] - metrics = train( + train( train_data, test_data, model_path=model_path, config=config, class_counts=train_counts, class_weighting=class_weighting, + class_labels=class_labels, + test_counts=test_counts, ) - plot_metrics(metrics, class_labels, train_counts, test_counts) - if __name__ == "__main__": # parse arguments