Source code for seqgra.evaluator.rocevaluator

"""
MIT - CSAIL - Gifford Lab - seqgra

ROC evaluator: creates ROC curves

@author: Konstantin Krismer
"""
from typing import Any, List

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

import seqgra.constants as c
from seqgra.learner import Learner
from seqgra.evaluator import Evaluator


class ROCEvaluator(Evaluator):
    def __init__(self, learner: Learner, output_dir: str,
                 silent: bool = False) -> None:
        super().__init__(c.EvaluatorID.ROC, "ROC curve", learner,
                         output_dir, silent=silent)

    def _evaluate_model(self, x: List[str], y: List[str],
                        annotations: List[str]) -> Any:
        encoded_y = self.learner.encode_y(y)
        y_hat = self.learner.predict(x=x)
        return (encoded_y, y_hat)

    def _save_results(self, results, set_name: str = "test",
                      suppress_plots: bool = False) -> None:
        if not suppress_plots:
            self.create_roc_curve(results[0], results[1],
                                  self.output_dir + set_name +
                                  "-roc-curve.pdf")

    def create_roc_curve(self, y_true, y_hat, file_name) -> None:
        """Create ROC curve.

        Plots ROC curves for each class label, including micro-average
        and macro-average curves. Saves plot as PDF in `file_name`.

        Arguments:
            y_true (array): one-hot encoded ground truth labels;
                shape = [n_samples, n_classes]
            y_hat (array): predicted class probabilities;
                shape = [n_samples, n_classes]
            file_name (str): path of the output PDF file
        """
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        n_classes = len(self.learner.definition.labels)
        for i in range(n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_hat[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        # compute micro-average ROC curve and ROC area
        fpr["micro"], tpr["micro"], _ = roc_curve(
            y_true.ravel(), y_hat.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # compute macro-average ROC curve and ROC area:
        # first aggregate all false positive rates
        all_fpr = np.unique(np.concatenate([fpr[i]
                                            for i in range(n_classes)]))

        # then interpolate all ROC curves at these points
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(n_classes):
            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

        # finally, average the interpolated curves and compute AUC
        mean_tpr /= n_classes

        fpr["macro"] = all_fpr
        tpr["macro"] = mean_tpr
        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

        # plot all ROC curves
        plt.figure(figsize=(7, 7))
        lines = []
        labels = []

        line, = plt.plot(fpr["micro"], tpr["micro"],
                         color="gold", linestyle=":", linewidth=2)
        lines.append(line)
        labels.append("micro-average (area = {0:0.2f})".format(
            roc_auc["micro"]))

        line, = plt.plot(fpr["macro"], tpr["macro"],
                         color="darkorange", linestyle=":", linewidth=2)
        lines.append(line)
        labels.append("macro-average (area = {0:0.2f})".format(
            roc_auc["macro"]))

        for i in range(n_classes):
            line, = plt.plot(fpr[i], tpr[i], linewidth=2)
            lines.append(line)
            labels.append("condition {0} (area = {1:0.2f})".format(
                self.learner.definition.labels[i], roc_auc[i]))

        plt.plot([0, 1], [0, 1], "k--", linewidth=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("ROC curve")
        plt.legend(lines, labels, bbox_to_anchor=(1.04, 1),
                   loc="upper left", prop=dict(size=14))
        plt.savefig(file_name, bbox_inches="tight")
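
The macro-averaging step in create_roc_curve (interpolating every per-class curve onto a shared false-positive-rate grid before averaging) can be exercised in isolation. The sketch below is a minimal, self-contained example and not part of seqgra: the toy y_true and y_hat arrays are made up for illustration, while roc_curve, auc, and np.interp are the same sklearn/numpy calls used above.

import numpy as np
from sklearn.metrics import roc_curve, auc

# toy one-hot labels and predicted probabilities,
# shape = [n_samples, n_classes] (synthetic data)
y_true = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
y_hat = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
n_classes = y_true.shape[1]

fpr, tpr = {}, {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_hat[:, i])

# interpolate every per-class curve onto a shared FPR grid, then average
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes

print("macro-average AUC: {0:0.2f}".format(auc(all_fpr, mean_tpr)))

The shared grid is what keeps the macro average well-defined: each call to roc_curve returns FPR/TPR arrays of different lengths (one entry per distinct threshold), so the per-class curves can only be summed after interpolation onto common x-values.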