User Tools

Site Tools


code:nearest_neighbors

Differences

This shows you the differences between two versions of the page.

Link to this comparison view

code:nearest_neighbors [2016/11/07 19:53] (current)
asa created
Line 1: Line 1:
 +
 +===== Nearest neighbor classification =====
 +
 +First, a helper that fits a classifier and plots its decision boundary:
 +
 +<file python decision_boundary.py>
 +
 +import numpy as np
 +import matplotlib.pyplot as plt
 +from matplotlib.colors import ListedColormap
 +
def plot_boundary(classifier, X, y, h=0.02):
    """Fit ``classifier`` on (X, y) and plot its decision regions.

    The classifier is trained on ``X``/``y`` and then evaluated on a regular
    mesh that extends one unit beyond the data on every side. The predicted
    class of each mesh cell is drawn as a light background, and the training
    points are overlaid in bold colors. The figure is shown with
    ``plt.show()``.

    Parameters
    ----------
    classifier : estimator
        Object with scikit-learn style ``fit(X, y)`` and ``predict(X)``
        methods. It is (re)fit in place.
    X : numpy.ndarray
        Training features. Columns 0 and 1 define the plot axes. Because the
        mesh passed to ``predict`` has exactly two columns, ``X`` is expected
        to have exactly two feature columns.
    y : array-like
        Numeric class labels for the rows of ``X`` (presumably integer class
        ids, as in the iris example). The color maps hold three colors.
    h : float, optional
        Mesh step in feature units (default 0.02). Smaller values give a
        smoother boundary at the cost of more points passed to ``predict``.

    Raises
    ------
    ValueError
        If ``h`` is not positive.
    """
    if h <= 0:
        # np.arange would divide by zero (h == 0) or yield an empty mesh (h < 0).
        raise ValueError("mesh step h must be positive, got %r" % (h,))

    classifier.fit(X, y)

    # Light colors for the predicted regions, bold colors for the samples.
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

    # Build a mesh covering the data with a one-unit margin on every side.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    # Classify every mesh point, then fold the flat predictions back onto the grid.
    Z = classifier.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    # Pin both color scales to the full label range. Without this, the
    # background is normalized over the predicted classes only. If predict()
    # never outputs some class, the region colors would no longer match the
    # point colors.
    vmin, vmax = np.min(y), np.max(y)

    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light, vmin=vmin, vmax=vmax)

    # Overlay the training points on the predicted regions.
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, vmin=vmin, vmax=vmax)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.show()
 +
 +</file>
 +
 +<file python nearest_neighbors.py>
 +
"""
Nearest neighbor classification with scikit-learn.

Full details at:
https://scikit-learn.org/stable/modules/neighbors.html#classification
"""

import numpy as np
from sklearn import neighbors, datasets
import decision_boundary

# Number of nearest training points that vote on each prediction.
N_NEIGHBORS = 10

# Load the iris data set. Keep only the first two features so the decision
# boundary can be drawn in the plane.
iris = datasets.load_iris()
X = iris.data[:, :2]  # take the first two features.
y = iris.target

# The parameters of the scikit-learn nearest neighbor classifier:
#   sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,
#       weights='uniform', algorithm='auto', leaf_size=30, p=2,
#       metric='minkowski')
# 'weights' controls how each neighbor's vote is weighted.
# 'algorithm' selects the data structure used to store the training data
#     and search for neighbors: 'brute', 'ball_tree', 'kd_tree', or 'auto'.
# With metric='minkowski', p=2 is the ordinary Euclidean distance.
# Complete description of the available metrics:
# https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.DistanceMetric.html#sklearn.neighbors.DistanceMetric
# NOTE: newer scikit-learn releases document this class as
# sklearn.metrics.DistanceMetric; the link above may need updating.

classifier = neighbors.KNeighborsClassifier(n_neighbors=N_NEIGHBORS)

decision_boundary.plot_boundary(classifier, X, y)
 +
 +</file>
  
code/nearest_neighbors.txt · Last modified: 2016/11/07 19:53 by asa