1 분 소요

원본 사이트: https://scikit-learn.org/stable/auto_examples/svm/plot_svm_margin.html

SVM(Support Vector Machine) 마진 예시

아래 그림은 파라미터 C가 분리 선에 미치는 영향을 보여줍니다. C 값이 크면 기본적으로 우리 모델은 데이터 분포에 대한 신뢰도가 크지 않으며, 분리 선에 가까운 점만 고려한다는 것을 알 수 있습니다.


C값이 작으면 더 많은/모든 관측치들이 포함되므로 해당 영역의 모든 데이터를 사용하여 마진을 계산할 수 있습니다.

NOTE: 마진이란 결정 경계와 서포트 벡터 사이의 거리를 의미합니다.

image.png

image.png

# Code source: Gaël Varoquaux
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from sklearn import svm

# 분리가 가능한 40개의 데이터 포인트들을 생성
np.random.seed(0)
X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]     # (40, 2)
Y = [0] * 20 + [1] * 20

# 그래프 번호
fignum = 1

# 모델 학습하기
for name, penalty in (("unreg", 1), ("reg", 0.05)):

    clf = svm.SVC(kernel="linear", C=penalty)
    clf.fit(X, Y)

    # 분리 초평면을 구합니다.
    w = clf.coef_[0]
    a = -w[0] / w[1]
    xx = np.linspace(-5, 5)
    yy = a * xx - (clf.intercept_[0]) / w[1]

    # 서포트 벡터를 통과하는 초평면(초평면에 수직인 방향으로 초평면에서 멀어지는 마진)을 그립니다.
    # 이것은 2차원에서 결정경계로부터 수직으로 sqrt(1+a^2) 만큼 떨어져 있습니다.
    margin = 1 / np.sqrt(np.sum(clf.coef_ ** 2))
    yy_down = yy - np.sqrt(1 + a ** 2) * margin
    yy_up = yy + np.sqrt(1 + a ** 2) * margin

    # 평면에 선, 데이터 포인트 및 결정 경계에서 가장 가까운 벡터를 그립니다.
    plt.figure(fignum, figsize=(4, 3))
    plt.clf()
    plt.plot(xx, yy, "k-")
    plt.plot(xx, yy_down, "k--")
    plt.plot(xx, yy_up, "k--")

    plt.scatter(
        clf.support_vectors_[:, 0],
        clf.support_vectors_[:, 1],
        s=80,
        facecolors="none",
        zorder=10,
        edgecolors="k",
        cmap=cm.get_cmap("RdBu"),
    )
    plt.scatter(
        X[:, 0], X[:, 1], c=Y, zorder=10, cmap=cm.get_cmap("RdBu"), edgecolors="k"
    )

    plt.axis("tight")
    x_min = -4.8
    x_max = 4.2
    y_min = -6
    y_max = 6

    XX, YY = np.meshgrid(xx, yy)
    xy = np.vstack([XX.ravel(), YY.ravel()]).T
    Z = clf.decision_function(xy).reshape(XX.shape)

    # 결과를 등고선 그래프에 표현합니다.
    plt.contourf(XX, YY, Z, cmap=cm.get_cmap("RdBu"), alpha=0.5, linestyles=["-"])

    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)

    plt.xticks(())
    plt.yticks(())
    fignum = fignum + 1

plt.show()
<Figure size 288x216 with 1 Axes>
<Figure size 288x216 with 1 Axes>

ⓒ 2007 - 2021, scikit-learn developers (BSD License). Show this page source

댓글남기기