This is an old revision of the document!
Below is Python code for displaying the decision boundary of a classifier.
To use it:
    import demo2d
    from sklearn import svm
    demo2d.get_data()
    classifier = svm.LinearSVC()
    demo2d.decision_surface(classifier)
"""
demo2d: display decision boundaries and contours of the decision function
of a classifier on two dimensional data.

USAGE::

    demo2d.get_data()
    demo2d.decision_surface(classifier)

First generate some data by calling ``demo2d.get_data()``.  Data points are
created by pressing '1' or '2' while the mouse pointer is at the position on
the figure where the point should go; press 'q' when you're done.

``demo2d.decision_surface(classifier)`` then fits the classifier to the data
that was generated and plots the decision boundary and contours of its
decision function.  It can be called several times using different
classifiers.  The classifier must provide ``fit(X, Y)`` and
``decision_function(points)``.
"""

import numpy
from numpy import arange
import matplotlib
from matplotlib import pylab

# Draw negative-level contours with solid lines as well.
pylab.rcParams['contour.negative_linestyle'] = 'solid'

# Data collected by get_data(): X holds [x, y] pairs, Y the class (0 or 1).
X = []
Y = []

# Matplotlib format strings indexed by class label.
#plotStr = ['or', 'ob']
plotStr = ['or', '+b']

# Bounds of the plotting area and of the evaluation grid.
xmin = -1
xmax = 1
ymin = -1
ymax = 1


def pick(event):
    """Key-press callback used by get_data().

    '1' / '2' add a point of class 0 / 1 at the pointer position (only when
    the pointer is inside the axes) and draw it.  'q' disconnects this
    handler, but is ignored while no point has been entered yet.
    """
    key_to_class = {'1': 0, '2': 1}

    if event.key == 'q':
        if len(X) == 0:
            return
        # Single-argument print() behaves the same on Python 2 and 3.
        print('done creating data. close this window and use the '
              'decision_surface function')
        pylab.disconnect(binding_id)
    elif event.key in key_to_class and event.inaxes is not None:
        label = key_to_class[event.key]
        print('data coords %s %s' % (event.xdata, event.ydata))
        X.append([event.xdata, event.ydata])
        Y.append(label)
        pylab.plot([event.xdata], [event.ydata], plotStr[label])
        pylab.draw()


def get_data(**args):
    """Open a figure and collect labeled 2-d points from key presses.

    The collected points are appended to the module-level lists X and Y.
    Keyword arguments are accepted but not used.
    """
    global binding_id

    pylab.subplot(111)
    # Plot the four corners so the axes cover the whole plotting area.
    pylab.plot([xmin, xmin, xmax, xmax], [ymin, ymax, ymin, ymax], '.k')
    pylab.title("press the numbers 1 or 2 to generate data points and 'q' to quit")
    binding_id = pylab.connect('key_press_event', pick)
    pylab.show()


def scatter(X, Y, **args):
    """Plot each point of X using the marker of its class label in Y.

    Keyword arguments:
      markersize -- marker size passed to pylab.plot (default 5)
    """
    markersize = args.get('markersize', 5)
    for point, label in zip(X, Y):
        pylab.plot(point[0], point[1], plotStr[label], markersize=markersize)


def _decision_values_on_grid(classifier, delta):
    """Evaluate classifier.decision_function on a regular grid.

    Returns an array Z with Z[i][j] holding the decision value at
    (x[i], y[j]), where x and y step through the plotting area by delta.
    """
    x = arange(xmin, xmax, delta)
    y = arange(ymin, ymax, delta)
    # 'ij' indexing plus C-order ravel enumerates the points with x as the
    # outer index and y as the inner one, so reshape() restores Z[i][j].
    grid_x, grid_y = numpy.meshgrid(x, y, indexing='ij')
    gridX = numpy.column_stack((grid_x.ravel(), grid_y.ravel()))
    results = classifier.decision_function(gridX)
    return numpy.asarray(results, dtype=float).reshape(len(x), len(y))


def decision_surface(classifier, fileName=None, **args):
    """Fit classifier to the collected data and plot its decision surface.

    The decision function is shown as a gray-scale image with contour lines
    on top, together with the data points.

    Arguments:
      classifier -- object providing fit(X, Y) and decision_function(points)
      fileName   -- if given, the figure is saved to this file

    Keyword arguments:
      numContours     -- 1: only the level-0 contour; 3: levels -1, 0, 1;
                         any other value is handed to pylab.contour as the
                         levels argument (default 3)
      title           -- figure title (default: no title)
      markersize      -- size of the data markers (default 5)
      fontsize        -- font size of tick labels and title
                         (default 'medium')
      contourFontsize -- font size of the contour labels (default 10)
      showColorbar    -- whether to add a colorbar (default False)
      show            -- whether to call pylab.show(); defaults to True
                         unless fileName is given
      delta           -- grid spacing for evaluating the decision function
                         (default 0.01)

    Raises ValueError if no data has been generated yet.
    """
    if len(X) == 0:
        raise ValueError('no data to fit; call get_data() first')
    classifier.fit(X, Y)

    numContours = args.get('numContours', 3)
    title = args.get('title')
    markersize = args.get('markersize', 5)
    fontsize = args.get('fontsize', 'medium')
    contourFontsize = args.get('contourFontsize', 10)
    showColorbar = args.get('showColorbar', False)
    show = args.get('show', fileName is None)
    delta = args.get('delta', 0.01)

    Z = _decision_values_on_grid(classifier, delta)
    # imshow/contour expect rows to run along y, hence the transpose.
    Zt = numpy.transpose(Z)
    extent = (xmin, xmax, ymin, ymax)

    #pylab.figure()
    im = pylab.imshow(Zt,
                      interpolation='bilinear',
                      origin='lower',
                      cmap=pylab.cm.gray,
                      extent=extent)

    if numContours == 1:
        C = pylab.contour(Zt, [0],
                          origin='lower',
                          linewidths=3,
                          colors='black',
                          extent=extent)
    elif numContours == 3:
        # The decision boundary (level 0) is drawn thicker than -1 and 1.
        C = pylab.contour(Zt, [-1, 0, 1],
                          origin='lower',
                          linewidths=(1, 3, 1),
                          colors='black',
                          extent=extent)
    else:
        C = pylab.contour(Zt, numContours,
                          origin='lower',
                          linewidths=2,
                          extent=extent)
    pylab.clabel(C, inline=1, fmt='%1.1f', fontsize=contourFontsize)

    # plot the data
    scatter(X, Y, markersize=markersize)

    xticklabels = pylab.getp(pylab.gca(), 'xticklabels')
    yticklabels = pylab.getp(pylab.gca(), 'yticklabels')
    pylab.setp(xticklabels, fontsize=fontsize)
    pylab.setp(yticklabels, fontsize=fontsize)

    if title is not None:
        pylab.title(title, fontsize=fontsize)
    if showColorbar:
        pylab.colorbar(im)

    # colormap:
    pylab.hot()

    if fileName is not None:
        pylab.savefig(fileName)
    if show:
        pylab.show()