""" explore the confusion matrix and ROC curve using scikit-learn http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html """ import numpy as np import matplotlib.pyplot as plt from sklearn.metrics import roc_curve, auc from sklearn import cross_validation from sklearn.linear_model import LogisticRegression from sklearn.metrics import confusion_matrix import perceptron2 ############################################################################### # first let's read in the heart dataset data=np.genfromtxt("../data/heart_scale.data", delimiter=",") X=data[:,1:] y=data[:,0] X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y, test_size=0.4, random_state=3) ############################################# # confusion matrix classifier = perceptron2.Perceptron() classifier.fit(X_train, y_train) y_pred = classifier.predict(X_test) # By definition a confusion matrix C is such that C_{i, j} is equal to the # number of observations known to be in group i but predicted to be in group j. print(confusion_matrix(y_test, y_pred)) ############################################# # ROC analysis # compute ROC curve for the perceptron classifier = perceptron2.Perceptron() classifier.fit(X_train, y_train) scores = classifier.decision_function(X_test) fpr_per, tpr_per, _ = roc_curve(y_test, scores) # compute ROC curve for logistic regression # we'll compare two variants with varying levels # of regularization; note that the parameter C of # the scikit-learn LogisticRegression class is inversely # proportional to the strength of regularization classifier = LogisticRegression(C=0.1) classifier.fit(X_train, y_train) scores = classifier.decision_function(X_test) fpr_lr, tpr_lr, _ = roc_curve(y_test, scores) classifier = LogisticRegression(C=1000) classifier.fit(X_train, y_train) scores = classifier.decision_function(X_test) fpr_lr2, tpr_lr2, _ = roc_curve(y_test, scores) plt.figure(1) plt.plot([0, 1], [0, 1], 'k--') plt.plot(fpr_per, tpr_per, label='perceptron') plt.plot(fpr_lr, tpr_lr, label='logistic regression') plt.plot(fpr_lr2, tpr_lr2, label='logistic regression') plt.xlabel('False positive rate') plt.ylabel('True positive rate') plt.title('ROC curve') plt.legend(loc='best') plt.show()