User Tools

Site Tools


code:demo2d

Displaying the decision boundary of a classifier

Below is Python code for displaying the decision boundary of a classifier.

To use it, save the module below as demo2d.py and run a script such as:

using_demo2d.py
import demo2d
from sklearn import svm
 
def main():
    """Collect 2-D points interactively, then plot a linear SVM's decision surface."""
    # Opens an interactive figure: press '1' or '2' to add points, 'q' when
    # done, then close the window to continue.
    demo2d.get_data()

    classifier = svm.LinearSVC()

    # Fits the classifier on the collected points and plots its decision function.
    demo2d.decision_surface(classifier)


if __name__ == "__main__":
    main()
demo2d.py
import numpy
from numpy import arange
import matplotlib
from matplotlib import pylab
# Draw negative contour levels with solid (not dashed) lines.
pylab.rcParams['contour.negative_linestyle'] = 'solid'

"""
demo2d:  display decision boundaries and contours of the decision function
of a classifier on two dimensional data.

USAGE::

  first you need to generate some data; call
  demo2d.get_data()
  a data point is created by pressing the '1' or '2' key while the mouse
  pointer is at the position on the figure where you want the point to be.
  press 'q' when you're done, then close the window.
  demo2d.decision_surface(classifier) then fits the classifier to the data
  that was generated and plots its decision boundary and the contours of
  its decision function.
  demo2d.decision_surface can be called several times using different classifiers.
"""

# Points collected by get_data(): X holds [x, y] pairs, Y the matching
# class labels (0 for key '1', 1 for key '2').
X = []
Y = []
# Matplotlib format strings indexed by class label:
# class 0 -> red circles, class 1 -> blue plus signs.
plotStr = ['or', '+b']
# Plot window used both for data entry and for the decision-surface grid.
xmin = -1
xmax = 1
ymin = -1
ymax = 1
# Id of the key-press connection installed by get_data(); pick() uses it to
# disconnect itself.
binding_id = None
 
def pick(event):
    """Key-press handler installed by get_data() to collect labelled points.

    Pressing '1' or '2' while the pointer is inside the axes appends the
    pointer position to X, appends the label 0 or 1 to Y, and draws the
    point. Pressing 'q' disconnects this handler, but only once at least one
    point has been collected. All other keys are ignored.
    """
    key_to_class = {'1': 0, '2': 1}

    if event.key == 'q':
        # Refuse to stop before any data exists; decision_surface needs some.
        if len(X) == 0:
            return
        print('done creating data.  close this window and use the decision_surface function')
        pylab.disconnect(binding_id)
        return

    label = key_to_class.get(event.key)
    # Ignore unrelated keys and presses outside the axes (no data coordinates).
    if label is None or event.inaxes is None:
        return

    print('data coords', event.xdata, event.ydata)
    # X and Y are module-level lists, mutated in place (no rebinding needed).
    X.append([event.xdata, event.ydata])
    Y.append(label)
    pylab.plot([event.xdata], [event.ydata], plotStr[label])
    pylab.draw()
 
def get_data(**args):
    """Open an interactive figure in which the user enters training points.

    Marks the four corners of the plot window [xmin, xmax] x [ymin, ymax]
    with black dots, so the autoscaled axes include the whole window.
    It then connects ``pick`` to key presses and shows the figure.
    Keyword arguments are accepted but not used.
    """
    global binding_id

    corner_xs = [xmin, xmin, xmax, xmax]
    corner_ys = [ymin, ymax, ymin, ymax]
    instructions = "press the numbers 1 or 2 to generate data points and 'q' to quit"

    pylab.subplot(111)
    pylab.plot(corner_xs, corner_ys, '.k')
    pylab.title(instructions)
    # Remember the connection id so pick() can disconnect itself on 'q'.
    binding_id = pylab.connect('key_press_event', pick)
    pylab.show()
 
def scatter(X, Y, **args):
    """Plot each point of X with the marker style of its label in Y.

    X is a sequence of [x, y] pairs and Y the matching class labels, used as
    indices into plotStr. The optional keyword ``markersize`` (default 5) is
    forwarded to pylab.plot.
    """
    size = args.get('markersize', 5)
    for idx, point in enumerate(X):
        style = plotStr[Y[idx]]
        pylab.plot(point[0], point[1], style, markersize=size)
 
 
def _decision_grid(classifier, delta):
    """Evaluate classifier.decision_function on a regular grid over the plot window.

    Returns a float array Z of shape (len(x), len(y)) where
    x = arange(xmin, xmax, delta), y = arange(ymin, ymax, delta) and
    Z[i, j] is the decision value at the point (x[i], y[j]).

    Raises ValueError if the decision function does not return exactly one
    value per grid point (e.g. a multi-class classifier).
    """
    x = arange(xmin, xmax, delta)
    y = arange(ymin, ymax, delta)
    # 'ij' indexing keeps x as the slow axis, so the flattened grid is ordered
    # (x[0], y[0]), (x[0], y[1]), ... and reshapes straight back to Z[i, j].
    grid_x, grid_y = numpy.meshgrid(x, y, indexing='ij')
    gridX = numpy.column_stack((grid_x.ravel(), grid_y.ravel()))

    results = numpy.asarray(classifier.decision_function(gridX), dtype=float)
    if results.size != len(x) * len(y):
        raise ValueError(
            'decision_function must return one value per point; got shape %s '
            'for %d points (multi-class classifiers are not supported)'
            % (results.shape, len(gridX)))
    return results.reshape(len(x), len(y))


def decision_surface(classifier, fileName=None, **args):
    """Fit *classifier* on the collected data and plot its decision function.

    The classifier is trained on the module-level X/Y collected by get_data().
    Its decision function is evaluated on a grid over the plot window and
    drawn as an image with contour lines on top, followed by the data points.

    Parameters
    ----------
    classifier : object
        Must provide ``fit(X, Y)`` and ``decision_function(points)``. The
        latter must return one value per point, as a binary classifier does.
    fileName : str, optional
        If given, the figure is saved to this path and, unless ``show`` is
        passed explicitly, not displayed.
    **args
        Optional keywords:
        numContours (default 3) -- 1 draws only the 0 level, 3 draws the
            -1, 0 and 1 levels; any other value is passed to pylab.contour
            as its levels argument.
        title (default None), markersize (default 5),
        fontsize (default 'medium'), contourFontsize (default 10),
        showColorbar (default False),
        show (default True when fileName is None, else False),
        delta (grid spacing, default 0.01).

    Raises
    ------
    ValueError
        If the decision function does not return one value per grid point.
    """
    # X and Y are read-only here, so no ``global`` declaration is needed.
    classifier.fit(X, Y)

    numContours = args.get('numContours', 3)
    title = args.get('title')
    markersize = args.get('markersize', 5)
    fontsize = args.get('fontsize', 'medium')
    contourFontsize = args.get('contourFontsize', 10)
    showColorbar = args.get('showColorbar', False)
    show = args.get('show', fileName is None)
    delta = args.get('delta', 0.01)

    Z = _decision_grid(classifier, delta)
    extent = (xmin, xmax, ymin, ymax)
    # imshow/contour index rows by y and columns by x, hence the transpose.
    image_data = Z.T

    im = pylab.imshow(image_data,
                      interpolation='bilinear', origin='lower',
                      cmap=pylab.cm.gray, extent=extent)

    if numContours == 1:
        # Decision boundary only.
        C = pylab.contour(image_data, [0], origin='lower',
                          linewidths=3, colors='black', extent=extent)
    elif numContours == 3:
        # Boundary (thick) plus the -1/+1 levels (thin).
        C = pylab.contour(image_data, [-1, 0, 1], origin='lower',
                          linewidths=(1, 3, 1), colors='black', extent=extent)
    else:
        C = pylab.contour(image_data, numContours, origin='lower',
                          linewidths=2, extent=extent)

    pylab.clabel(C, inline=1, fmt='%1.1f', fontsize=contourFontsize)

    # plot the data
    scatter(X, Y, markersize=markersize)
    axes = pylab.gca()
    pylab.setp(pylab.getp(axes, 'xticklabels'), fontsize=fontsize)
    pylab.setp(pylab.getp(axes, 'yticklabels'), fontsize=fontsize)

    if title is not None:
        pylab.title(title, fontsize=fontsize)
    if showColorbar:
        pylab.colorbar(im)

    # colormap: NOTE(review) this appears to replace the gray colormap given to
    # imshow above with 'hot'; kept unchanged for backward compatibility.
    pylab.hot()
    if fileName is not None:
        pylab.savefig(fileName)
    if show:
        pylab.show()
code/demo2d.txt ยท Last modified: 2016/09/29 14:35 by asa