from pathlib import Path
import pickle
from mhcpred.config import settings
import pandas as pd
from sklearn.metrics import accuracy_score, confusion_matrix, balanced_accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
models_path = Path(settings.models_path)
output_path = Path(settings.output_path)
This notebook contains the training metrics history and the classification metrics computed on the predictions by:
- mhcflurry (benchmark)
- mhcpred
In [25]:
Information on the training history
I would prefer to use TensorBoard, but it is not integrated in the mhcflurry package. The logged information is quite scarce, but when you run the training you see the loss for each step, not only for the whole epoch. Of course, this is a very basic form of logging and should be improved.
In [26]:
with open(str(models_path / "model.pickle"), "rb") as f:
    model = pickle.load(f)
In [27]:
model.fit_info
[{'learning_rate': 0.0010000000474974513,
'loss': [0.09700655937194824, 0.06465369462966919],
'val_loss': [0.06880103051662445, 0.05075661838054657],
'time': 524.7155420780182,
'num_points': 6628048}]
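As a quick visual check, the per-epoch losses stored in fit_info can be plotted directly. This is a minimal sketch, assuming fit_info keeps the structure shown above (one dict with "loss" and "val_loss" lists, one value per epoch).

history = model.fit_info[0]
epochs = range(1, len(history["loss"]) + 1)
plt.plot(epochs, history["loss"], label="train loss")
plt.plot(epochs, history["val_loss"], label="val loss")  # assumes both lists have the same length
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()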
Binary classification metrics
We compute the usual binary classification metrics on the unbalanced test dataset: accuracy, balanced accuracy, confusion matrix, and classification report, all from scikit-learn.
We report the balanced accuracy because the dataset is very unbalanced, so accuracy alone is not a good measure of performance (a model that always predicts False already scores quite well).
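To make this concrete, here is a small sketch with hypothetical counts in the same order of magnitude as our test set: a classifier that always predicts False already reaches about 95% accuracy, while its balanced accuracy collapses to 0.5.

import numpy as np

y_dummy = np.array([False] * 900_000 + [True] * 45_000)  # ~5% positives, like our test set
always_false = np.zeros_like(y_dummy)  # predict False everywhere

print(accuracy_score(y_true=y_dummy, y_pred=always_false))           # ~0.95
print(balanced_accuracy_score(y_true=y_dummy, y_pred=always_false))  # 0.5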
mhcflurry metrics
In [28]:
mhcflurry_rank_percentile_threshold = 2  # rank threshold for positive hits
# It comes from the mhcflurry article.
In [29]:
df = pd.read_csv(str(output_path / "mhcflurry_predictions.csv"))
y_pred = df.prediction_percentile.values <= mhcflurry_rank_percentile_threshold
y_true = df.hit.values
acc = accuracy_score(y_true=y_true, y_pred=y_pred)
confusion_mat = confusion_matrix(y_true=y_true, y_pred=y_pred)
balanced_acc = balanced_accuracy_score(y_true=y_true, y_pred=y_pred)
class_report = classification_report(y_true=y_true, y_pred=y_pred, output_dict=False)

disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat)
disp.plot()
plt.show()
In [30]:
print(class_report)
precision recall f1-score support
False 0.99 0.98 0.99 900996
True 0.67 0.86 0.76 45423
accuracy 0.97 946419
macro avg 0.83 0.92 0.87 946419
weighted avg 0.98 0.97 0.97 946419
In [31]:
acc, balanced_acc
(0.9731936911663861, 0.9217833819652606)
The metrics are quite good. However, the precision on the True class is low (0.67): the model has a tendency to predict True too often, so we get too many false positives. We can see it on the confusion matrix: 19234 false positives.
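As a side note, these counts can be read directly off the confusion matrix; a short sketch, relying on scikit-learn's convention that rows are true labels and columns are predicted labels:

tn, fp, fn, tp = confusion_mat.ravel()
print(f"TN={tn}  FP={fp}  FN={fn}  TP={tp}")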
mhcpred metrics
In [32]:
mhcpred_proba_threshold = 0.5  # by default, but we try to tune it later
In [33]:
df = pd.read_csv(str(output_path / "mhcpred_predictions.csv"))
y_true = df.hit.values
y_pred = df.predictions.values >= mhcpred_proba_threshold
acc = accuracy_score(y_true=df.hit.values, y_pred=y_pred)
confusion_mat = confusion_matrix(y_true=df.hit.values, y_pred=y_pred)
balanced_acc = balanced_accuracy_score(y_true=df.hit.values, y_pred=y_pred)
class_report = classification_report(y_true=y_true, y_pred=y_pred, output_dict=False)

disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat)
disp.plot()
plt.show()
In [34]:
acc, balanced_acc
(0.9731657332258088, 0.7775326900487307)
In [35]:
print(class_report)
precision recall f1-score support
False 0.98 0.99 0.99 900725
True 0.82 0.56 0.67 45416
accuracy 0.97 946141
macro avg 0.90 0.78 0.83 946141
weighted avg 0.97 0.97 0.97 946141
mhcpred performs worse than mhcflurry, as shown by the balanced accuracy. On the True class, this time the recall is poor (0.56): the model has a tendency to predict False too often, and the confusion matrix shows roughly 20000 false negatives. This suggests that lowering the threshold may improve the model.
Threshold tuning
We plot the precision recall curve to try to identify a better threshold.
In [36]:
from sklearn.metrics import precision_recall_curve, PrecisionRecallDisplay

precision, recall, thresholds = precision_recall_curve(y_true=y_true, probas_pred=df.predictions.values)
disp = PrecisionRecallDisplay(precision=precision, recall=recall)
disp.plot()
plt.show()
In [37]:
precision_recall_thresholds = pd.DataFrame({
    "precision": precision[:-1],
    "recall": recall[:-1],
    "thresholds": thresholds,
})
In [38]:
precision_recall_thresholds
|        | precision | recall   | thresholds |
|--------|-----------|----------|------------|
| 0      | 0.048001  | 1.000000 | 0.000114   |
| 1      | 0.048001  | 1.000000 | 0.000116   |
| 2      | 0.048001  | 1.000000 | 0.000117   |
| 3      | 0.048001  | 1.000000 | 0.000125   |
| 4      | 0.048002  | 1.000000 | 0.000125   |
| ...    | ...       | ...      | ...        |
| 889313 | 1.000000  | 0.000110 | 0.992152   |
| 889314 | 1.000000  | 0.000088 | 0.992280   |
| 889315 | 1.000000  | 0.000066 | 0.992347   |
| 889316 | 1.000000  | 0.000044 | 0.992431   |
| 889317 | 1.000000  | 0.000022 | 0.992971   |

889318 rows × 3 columns
A threshold of approx. 0.2 seems to be a good compromise for precision/recall.
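To make this choice less visual, one option is to compute the F1 score for every candidate threshold and pick the maximizer. This is only a sketch on top of the precision_recall_thresholds DataFrame built above; the f1 column and the prt alias are introduced here for illustration.

prt = precision_recall_thresholds.copy()
prt["f1"] = 2 * prt.precision * prt.recall / (prt.precision + prt.recall)
print(prt.loc[prt.f1.idxmax()])  # threshold with the highest F1, expected to be in the same region as 0.2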
In [39]:
mhcpred_proba_threshold = 0.2
In [40]:
df = pd.read_csv(str(output_path / "mhcpred_predictions.csv"))
y_true = df.hit.values
y_pred = df.predictions.values >= mhcpred_proba_threshold
acc = accuracy_score(y_true=df.hit.values, y_pred=y_pred)
confusion_mat = confusion_matrix(y_true=df.hit.values, y_pred=y_pred)
balanced_acc = balanced_accuracy_score(y_true=df.hit.values, y_pred=y_pred)
class_report = classification_report(y_true=y_true, y_pred=y_pred, output_dict=False)

disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat)
disp.plot()
plt.show()
In [41]:
acc, balanced_acc
(0.9697360118629252, 0.8462451280426622)
In [42]:
print(class_report)
precision recall f1-score support
False 0.99 0.98 0.98 900725
True 0.68 0.71 0.69 45416
accuracy 0.97 946141
macro avg 0.83 0.85 0.84 946141
weighted avg 0.97 0.97 0.97 946141
We have improved the balanced accuracy: the precision on the True class deteriorates (0.82 → 0.68), but the recall is much better (0.56 → 0.71).
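For reference, a small recap of the numbers obtained in this notebook (values copied from the cells above):

summary = pd.DataFrame(
    {
        "accuracy": [0.9732, 0.9732, 0.9697],
        "balanced_accuracy": [0.9218, 0.7775, 0.8462],
    },
    index=["mhcflurry (rank <= 2)", "mhcpred (p >= 0.5)", "mhcpred (p >= 0.2)"],
)
print(summary)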