This shows you the differences between two versions of the page.
— |
code:nearest_neighbors [2016/11/07 19:53] (current) asa created |
||
---|---|---|---|
Line 1: | Line 1: | ||
+ | |||
+ | ===== Nearest neighbor classification ===== | ||
+ | |||
+ | First some code for plotting the results: | ||
+ | |||
+ | <file python decision_boundary.py> | ||
+ | |||
+ | import numpy as np | ||
+ | import matplotlib.pyplot as plt | ||
+ | from matplotlib.colors import ListedColormap | ||
+ | |||
def plot_boundary(classifier, X, y, mesh_step=.02):
    """Fit ``classifier`` on ``(X, y)`` and plot its decision regions.

    The classifier is fitted in place, then asked to predict a label for
    every point of a regular grid covering the first two columns of ``X``
    (padded by one unit on each side).  The predictions are drawn as
    coloured regions, the training points are scattered on top, and the
    figure is displayed with ``plt.show()``.

    Parameters
    ----------
    classifier : object
        Any object exposing ``fit(X, y)`` and ``predict(points)``.
    X : array
        Training data; only columns 0 and 1 are used for the plot axes.
        The classifier is fitted on ``X`` as given and later queried with
        two-column grid points.
    y : array
        Training labels, also used to colour the scattered points.
    mesh_step : float, optional
        Spacing of the evaluation grid along both axes.

    Notes
    -----
    Both colour maps list exactly three colours, so the plot is presumably
    meant for three-class problems such as the iris data.
    """
    classifier.fit(X, y)

    # Light colours for the predicted regions, bold ones for the samples.
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

    # Grid over the bounding box of the data, padded by one unit.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step),
                         np.arange(y_min, y_max, mesh_step))

    # Predict one label per grid point, then restore the grid shape so the
    # result can be drawn as a colour mesh.
    Z = classifier.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

    # Plot also the training points
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.show()
+ | |||
+ | </file> | ||
+ | |||
+ | <file python nearest_neighbors.py> | ||
+ | |||
+ | """ | ||
+ | Nearest neighbor classification with scikit-learn | ||
+ | full details at: | ||
+ | http://scikit-learn.org/stable/modules/neighbors.html#classification | ||
+ | """ | ||
+ | |||
+ | import numpy as np | ||
+ | from sklearn import neighbors, datasets | ||
+ | import decision_boundary | ||
+ | |||
# Load the iris data set; only the first two features are kept so that
# the decision regions can be drawn in the plane.
iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target

# The parameters of the scikit-learn nearest neighbor classifier:
#   sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,
#       weights='uniform', algorithm='auto', leaf_size=30, p=2,
#       metric='minkowski')
# 'weights' refers to how to weight each example.
# 'algorithm' is the choice of algorithm for storing the
# training data ('brute', 'ball_tree', 'kd_tree').
# Complete description of the available metrics:
# http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.DistanceMetric.html#sklearn.neighbors.DistanceMetric

classifier = neighbors.KNeighborsClassifier(n_neighbors=10)

# Fits the classifier on (X, y) and displays its decision regions.
decision_boundary.plot_boundary(classifier, X, y)
+ | |||
+ | </file> | ||