데이터 분석

트리 기반 앙상블 모델에서 feature importance 란?

hoonhoon04 2023. 2. 2. 09:52

xgboost, lightgbm과 같은 트리 기반 앙상블 모델은 크게 classification(분류), regression(회귀)의 문제로 나뉜다. 각각의 경우 어떻게 앙상블 모델의 feature importance를 구하는지 알아보겠다. 그리고 더 나아가 기존의 built in feature importance의 단점을 보완한 permutation importance에 대해서도 다뤄보겠다.

 

classification(분류)를 위한 트리 기반 앙상블 모델

우리가 위의 그림처럼 이진분류를 한다고 가정해보자. 분류가 한쪽으로 몰릴수록 잘 분류된 것이므로 impurity(불순도)라는 개념을 도입한다.

 

1에서 비율의 제곱합들을 빼주는데 (0.5,0.5) 비율로 나눠질 때 가장 큰 값을 가지고 (1,0) 비율로 나눠질 때 가장 작아집니다. 그래서 결국 불순도를 가장 크게 감소시키는 feature를 분류기준으로 계속 선택해 나간다고 보면 된다.

 

regression(회귀)를 위한 트리 기반 앙상블 모델

 

분류의 경우를 보면 비율을 통해 구하는 것을 알 수 있는데 그럼 회귀는 어떻게 하지 생각할 수 있다. 보통 회귀를 위한 트리 기반 앙상블 모델은 한 노드에 속한 모든 값들의 평균 값으로 예측값을 계산한다.

 

위 그림을 보면 squared_error가 적혀있는데 regression의 경우는 불순도를 MSE(mean squared error)로 사용한다. 즉 MSE를 가장 크게 줄이는 feature를 찾아나가는데 보통 MSE는 실제값에서 예측값을 빼고 제곱하여 평균을 낸다. (참고로 여기서 말한 예측값은 한 노드에 속한 데이터들의 평균으로 모두 같다.)

 

정확하게 수식을 적으면 아래와 같이 나타낼 수 있다

.

위 예시의 MedInc feature의 importance는 아래 식처럼 구해진다.

 

이것을 모든 feature들이 분리될 때마다 더해주게 되는데 보다보면 의문점이 하나 들 수 있다.

 

❓ 많이 등장하면 등장할수록 feature importance 가 커지는 것이 아닌가?

 

실제로 구하는 방식을 보면 당연히 분류의 기준으로 많이 등장할수록 feature importance(정확히 말하면 모델의 built in feature importance)는 커질 수밖에 없다. 이러한 배경지식을 가진 상태로 kaggle의 Predict Future Sales 대회에 참가하여 사용하였던 우리의 모델을 예시로 한 번 살펴보겠다.

 

일단 우리의 모델의 built in feature importance를 그려보기 위해 아래처럼 함수를 정의하고 출력해보았다.

from xgboost import plot_importance

# lightgbm 모델의 feature importance
def plot_features_lgb(booster, figsize):    
    fig, ax = plt.subplots(1,1,figsize=figsize)
    return lgb.plot_importance(booster=booster, ax=ax)

# xgboost 모델의 feature importance
def plot_features_xgb(booster, figsize):    
    fig, ax = plt.subplots(1,1,figsize=figsize)
    return plot_importance(booster=booster, ax=ax)
# lightgbm 의 feature importance 출력
plot_features_lgb(lgb_model, (10,14))
# xgboost 의 feature importance 출력
plot_features_xgb(xgb, (10,14))

 

위에서 이야기했듯이 xgboost와 lightgbm 모델을 앙상블 했기 때문에 각각 모델의 feature importance를 구해보았다. feature들이 상당히 많으므로 상위 5개와 하위 5개만 살펴보도록 하겠다.

 

<xgboost>

상위 5개

  1. item_id : 20494
  2. item_category_id : 83
  3. category_proportion : 2297
  4. date_block_num : 32
  5. item_id_mean_sales_lag1 : 2071

하위 5개

  1. shop_group_cluster : 4
  2. item_count_lag2 : 387
  3. shop_period_lag3 : 14
  4. item_count_lag3 : 385
  5. brand_new : 2

<lightgbm>

상위 5개

  1. item_id : 20494
  2. date_block_num : 32
  3. item_count_days_lag1 : 458
  4. item_id_mean_sales_lag1 : 2071
  5. category_proportion : 2297

하위 5개

  1. shop_period_lag2 : 13
  2. item_count_lag2 : 387
  3. item_count_lag3 : 385
  4. shop_period_lag3 : 14
  5. brand_new : 2

feature 옆에 적힌 숫자들은 각각 cardinality(종류의 개수)를 나타내는데 상위권과 하위권의 확실한 개수의 차이를 볼 수 있다. 전체 feature들의 cardinality를 살펴보면

cardinality가 많은 건 40000개, 10000개 같은 것들이 있는데 하나도 상위 5위안에 들지 않아서 우리 모델이 과적합이 많이 일어나지 않은 모델이라는 점을 알 수 있었다. 또한, date_block_num 같은 것이 두 모델 모두 상위권을 차지했던 점은 우리가 lagging을 통해 시계열의 흐름을 잘 파악했다는 것으로도 판단할 수 있었다.

 

이제 하위 5개들을 보면서 어떤 feature가 불필요하게 성능을 저하시켰는지 확인해 보겠다. 그전에 앞에서도 말했듯이 built in feature importance는 cardinality에 편향되므로 다른 판단기준이 필요하였다. 따라서 permutation feature importance를 알아보도록 하겠다.

 

permutation feature importance 란?

permutation importance는 특정 한 feature를 골라서 그 feature만 막 섞고 모델 성능의 변화를 살펴보는 것이다.

 

쉽게 예시를 들면 어느 축구팀에서 에이스라고 생각되는 A라는 선수가 존재한다고 하자. 어느 날 A가 갑자기 부상을 당해서 다른 선수가 출전하였는데 팀의 경기력이 비슷하거나 훨씬 좋아졌다면 사실 A는 에이스가 아니라고 판단할 수 있다.

이런 느낌으로 특정 feature가 막 섞였을 때 성능이 변화가 없거나 증가했다면 그 feature는 사실 중요하지 않았던 것이고 반대로 성능저하가 크게 나타났다면 중요한 feature로 작용했음을 알 수 있다.

 

우리가 시도한 xgboost 모델의 permutation importance를 구해보면 아래와 같다.

from sklearn.inspection import permutation_importance

features = dict()

# xgboost 모델의 permutation importance
r = permutation_importance(xgb, X_valid, y_valid, n_repeats=10, random_state=0)

for i in r.importances_mean.argsort()[::-1]:
    if r.importances_mean[i] - 2 * r.importances_std[i] > 0:
        features[xgb.get_booster().feature_names[i]] = r.importances[i]   
        print(f"{xgb.get_booster().feature_names[i]} : " f"{r.importances_mean[i]:.3f}"
               f" +/- {r.importances_std[i]:.3f}")

built in feature importance와 비교해 보면 공통적으로 3개월이 lagging 된 lag3 feature들이 전체적으로 낮은 성능을 보였고 반대로 brand_new 는 순위차이가 상당하게 난 것을 볼 수 있었다. 그래서 각각 따로 제거했을 때 원래 모델에 비해 성능이 어떻게 되는지 살펴보았고 두 경우 모두 원래보다 성능이 떨어지는 것으로 보아 모델을 그냥 두는 것이 최선임을 알 수 있었다.

 

또한 1개월 lagging된 feature들이 상위권에 주로 분포되어 있음을 볼 수 있는데 우리의 데이터는 계절성(주기성)에 영향을 많이 받기보다는 근접한 시간에 매출이 영향을 많이 받는다는 것도 알 수 있었다.