1 분 소요

원본 사이트: https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html


최근접 이웃 분류

최근접 이웃 분류의 샘플 사용법입니다.

각 클래스에 대한 결정 경계를 표시합니다.

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns     # 좀 더 좋은 시각화를 위한 seaborn
from matplotlib.colors import ListedColormap
from sklearn import neighbors, datasets

n_neighbors = 15    # 15 개의 최근접 이웃

# 다룰 데이터 세트를 불러옵니다.
iris = datasets.load_iris()

# 처음 2 개의 특성만 다룹니다. 
# 2 차원 데이터 세트 슬라이싱
X = iris.data[:, :2]
y = iris.target

h = 0.02  # 나타낼 데이터 간의 그리드 간격

# 컬러 맵 설정
cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
cmap_bold = ["darkorange", "c", "darkblue"]

for weights in ["uniform", "distance"]:   # 예측에 사용할 2 가지의 가중 함수
    # 최근접 이웃 분류기를 만들고 데이터를 학습시킵니다.
    clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights)
    clf.fit(X, y)

    # 결정 경계를 표시하고 [x_min, x_max] x [y_min, y_max] 그리드의
    # 각 포인트에 색을 할당합니다.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # [x_min, x_max] x [y_min, y_max]에 해당하는 모든 점의 좌표를 h 의 간격으로 구함
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])  # 구한 모든 점에 대해 예측

    # color plot에 결과 나타내기
    Z = Z.reshape(xx.shape)
    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z, cmap=cmap_light)

    # seaborn 으로 각 클래스에 따라 다르게 표현합니다.
    sns.scatterplot(
        x=X[:, 0],
        y=X[:, 1],
        hue=iris.target_names[y],
        palette=cmap_bold,
        alpha=1.0,
        edgecolor="black"
    )
    plt.xlim(xx.min(), xx.max())    # x 축 값의 범위: xx.min() ~ xx.max()
    plt.ylim(yy.min(), yy.max())    # y 축 값의 범위: yy.min() ~ yy.max()
    plt.title(
        "3-Class classification (k = %i, weights = '%s')" % (n_neighbors, weights)
    )
    plt.xlabel(iris.feature_names[0])
    plt.ylabel(iris.feature_names[1])

plt.show()

© 2007 - 2021, scikit-learn developers (BSD License). Show this page source

댓글남기기