AI/MachineLearning

SVM(Support Vector Machine)

향식이 2021. 6. 15. 14:23

서포트 벡터 머신은 높은 성능을 보여주는 대표적인 분류 알고리즘입니다.

특히 이진 분류를 위해 주로 사용되는 알고리즘으로, 각 클래스의 가장 외곽의 데이터들 즉, 서포트 벡터들이 가장 멀리 떨어지도록 합니다. 

SVM을 위한 사이킷런 함수/라이브러리

  • from sklearn.svm import SVC: SVM 모델을 불러옵니다.
  • SVC(): SVM 모델을 정의합니다. 
  • [Model].fit(x, y): (x,y) 데이터 셋에 대해서 모델을 학습시킵니다.
  • [Model].predict(x): x 데이터를 바탕으로 예측되는 값을 출력합니다.
import pandas as pd  
import numpy as np  
import matplotlib.pyplot as plt  

import warnings
warnings.filterwarnings(action='ignore')

from sklearn.model_selection import train_test_split  
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix 

def load_data():
    
    data = pd.read_csv('data/dataset.csv')
    
    X = data.drop('Class', axis=1)
    y = data['Class']
    
    train_X, test_X, train_y, test_y = train_test_split(X, y, test_size = 0.2, random_state = 0)
    print(X, y)
    return train_X, test_X, train_y, test_y
    
def SVM(train_X, test_X, train_y, test_y):
    
    svm = SVC()
    
    svm.fit(train_X, train_y)
    
    pred_y = svm.predict(test_X)
    
    return pred_y
    
# 데이터를 불러오고, 모델 예측 결과를 확인하는 main 함수입니다.
def main():
    
    train_X, test_X, train_y, test_y = load_data()
    
    pred_y = SVM(train_X, test_X, train_y, test_y)
    
    # SVM 분류 결과값을 출력합니다.
    print("\nConfusion matrix : \n",confusion_matrix(test_y,pred_y))  
    print("\nReport : \n",classification_report(test_y,pred_y)) 

if __name__ == "__main__":
    main()

 

출처: 앨리스 교육

반응형