혼자 공부하는 머신러닝 딥러닝

훈련 세트와 테스트 세트 추출하기

temporubato108 2024. 12. 3. 18:52

fit()에 사용한 데이터로 score()을 수행하면 모델은 데이터를 100% 판별할 수 밖에 없다.

데이터를 훈련 세트와 테스트 세트로 나누어 모델을 훈련하고, 평가해보자.


 
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
 

필요한 라이브러리 및 클래스를 import 한다.

 

 
# 도미와 빙어 데이터 준비
fish_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0,
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0,
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0, 9.8,
                10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
fish_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0,
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0,
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0, 6.7,
                7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

# 2차원 리스트 생성
fish_data = [[l,w] for l,w in zip(fish_length, fish_weight)]
fish_target = [1]*35 + [0]*14
 

저번과 동일하므로 생략



# 슬라이싱을 사용하여 train test 데이터 준비
train_input = fish_data[:35]
train_target = fish_target[:35]
test_input = fish_data[35:]
test_target = fish_target[35:]

# 모델 훈련 후 평가 (샘플이 섞여있지 않으므로 0점이 출력된다)
kn = KNeighborsClassifier()
kn.fit(train_input, train_target)
score = kn.score(test_input, test_target)
print(score) # 0.0

파이썬 리스트의 '슬라이싱'기능을 사용하여 데이터를 분리해본다.

데이터 중 처음 35개를 훈련 세트, 나머지 14개를 테스트 세트로 선택한 후,

모델 훈련과 평가를 했더니 0점이 출력된다.

 

도미 35개, 빙어 14개가 순서대로 들어있는 데이터에서 앞 35개만 추출해서 데이터가 섞이지 않은 것이다.

모델은 모든 데이터를 도미로 예측하여(도미 데이터만 훈련했으니 당연하다)

정작 빙어만 들어있는 테스트 데이터는 하나도 맞추지 못한 것이다.

 

이렇게 훈련 세트와 테스트 세트에 샘플이 골고루 섞이지 않은 상태를 '샘플링 편향'이라고 한다.

 



# 넘파이를 사용하여 배열로 변환
# 넘파이는 2차원 데이터를 array(행과 열)로 변환하여 출력해준다.
input_arr = np.array(fish_data)
target_arr = np.array(fish_target)
print(input_arr.shape) # (49, 2)

# 랜덤한 인덱스를 생성
# arange() 함수로 0~48의 인덱스를 만들고, 랜덤하게 섞는다.
np.random.seed(42)
index = np.arange(49)
np.random.shuffle(index)

print(index)
"""
[13 45 47 44 17 27 26 25 31 19 12  4 34  8  3  6 40 41 46 15  9 16 24 33
 30  0 43 32  5 29 11 36  1 21  2 37 35 23 39 10 22 18 48 20  7 42 14 28
 38]
"""
 

랜덤 데이터를 추출하기 위하여 numpy를 사용한다.

넘파이는 2차원 데이터를 array형태로 출력하며,

출력 결과에서 행과 열을 구분하여 데이터의 구조를 쉽게 파악할 수 있다.

넘파이의 shape속성을 사용하면 array의 크기를 간단히 확인할 수 있다.

 

우선 랜덤 인덱스를 생성하기 위하여,

넘파이의 arange() 함수와 random패키지의 shuffle() 함수를 사용한다.

만들어진 랜덤 인덱스를 index로 정의하고 print하면 0에서 48까지 정수가 랜덤하게 출력된다.



# 넘파이의 배열 인덱싱을 사용하여 한 번에 여러 원소를 추출한다.
# 넘파이 array와 랜덤index를 활용하여 랜덤한 훈련 세트와 테스트 세트를 추출한다.
train_input = input_arr[index[:35]]
train_target = target_arr[index[:35]]
test_input = input_arr[index[35:]]
test_target = target_arr[index[35:]]

# 랜덤 index의 첫 번째 값이 13이므로 train_input의 첫 번째 원소는 input_arr[13]과 일치한다.
print(input_arr[13], train_input[0]) # [ 32. 340.] [ 32. 340.]
 

이번에는 넘파이의 '배열 인덱싱' 기능을 사용하여, 여러 원소를 한번에 추출한다.

이 때, 아까 만들어두었던 랜덤 인덱스를 활용하면, 랜덤으로 데이터를 추출하는 것이 가능해진다.

 

index 배열의 처음 35개를 전달하여 훈련 세트로 만든다.

그리고 나머지 14개를 테스트 세트로 만든다.



# 훈련 세트와 테스트 세트에 데이터가 잘 섞여 있는지 확인
# 현재 train_input은 array형태(2차원)이다.
plt.scatter(train_input[:,0], train_input[:,1])
plt.scatter(test_input[:,0], test_input[:,1])
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

산점도로 데이터의 분포를 확인해본다.

 

파란색이 훈련 세트, 주황색이 테스트 세트이다.

양쪽 모두 잘 섞여 있는 것을 확인할 수 있다.

 



# 새로운 KNeighborsClassifier객체 kn2를 만든다.
# 모델 훈련 후 평가 (점수는 1.0이 출력된다.)
kn2 = KNeighborsClassifier()
kn2.fit(train_input, train_target)
score2 = kn2.score(test_input, test_target)
print(score2) # 1.0

새로운 KNeighborsClassifier 객체인 kn2를 생성한다.

데이터를 섞어 만든 훈련 세트로 모델을 훈련시킨다.

그 후 테스트 데이터로 모델을 테스트한다.

 

모델 점수는 1.0이 출력된다.

 



# 마지막으로 predict()로 테스트 세트의 예측 결과와 실제 target을 비교한다.
predict = kn2.predict(test_input)
print(predict) # [0 0 1 0 1 1 1 0 1 1 0 1 1 0]
print(test_target) # [0 0 1 0 1 1 1 0 1 1 0 1 1 0]
print(predict == test_target)
"""
[ True  True  True  True  True  True  True  True  True  True  True  True
  True  True]
"""
 

predict() 메서드를 사용해서 kn2모델의 예측 결과를 출력하여 실제 타겟과 비교해본다.

전부 일치하는 것을 확인할 수 있다.