본문 바로가기
개발/Python

결정 트리 (분기 전후의 지니 불순도 감소량)

by 피로물든딸기 2025. 8. 15.
반응형

전체 링크

 

지니 불순도 정의는 다음과 같다.

 

분기 전후의 지니 불순도 감소량은 다음과 같다.

 

위 공식을 이용하여 아래 데이터에 대해 첫 번째 분기 전후의 지니 불순도 감소량을 구하라.

import pandas as pd
import numpy as np

np.random.seed(1234)

df = pd.DataFrame({
    'feature1': np.random.normal(0, 1, 100),
    'feature2': np.random.normal(50, 20, 100),
    'target': np.random.choice([0, 1], size=100)
})

X = df[['feature1', 'feature2']]
y = df['target']

df.head()

from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier(random_state=1234, max_depth=4)
model.fit(X, y)

DecisionTreeClassifiertree_ 속성을 이용하면 분기 전후의 불순도 감소량을 구할 수 있다.

tree.value[i]는 i번째 노드의 샘플 분포가 된다.

for i in range(model.tree_.node_count):
    print(model.tree_.value[i], model.tree_.impurity[i])

 

위의 경우 0번째 샘플55, 45로 나뉘어지고 지니 불순도 정의에 의해 0.495임을 알 수 있다.

1 - (55 / 100) ** 2 - (45 / 100) ** 2

 

최초에 클래스 0이 55개, 1이 45개로 분기가 시작된다.

df['target'].value_counts()

 

[55, 45] → [5, 0] + [50, 45]

루트의 왼쪽 끝 노드는 [5, 0]으로 완전 순수하게 분류가 완료되었고,

루트의 오른쪽 끝 노드는 [50, 45]로 분류가 되는 것을 알 수 있다. 

(이후 [50, 45] → [40, 42] + [10, 3])

 

root_node의 번호는 0이기 때문에 tree의 속성을 이용해 아래와 같이 지니 불순도 감소량을 구할 수 있다.

 

children_left / right [node] → 왼쪽 / 오른쪽 노드 번호 획득

impurity[node] → 해당 노드의 불순도

n_node_samples[node] → 해당 노드의 샘플 개수

# 지니 불순도 감소량
tree_ = model.tree_

# 노드 정보
root_node = 0
left_child = tree_.children_left[root_node]
right_child = tree_.children_right[root_node]

# 노드별 impurity
impurity_root = tree_.impurity[root_node]
impurity_left = tree_.impurity[left_child]
impurity_right = tree_.impurity[right_child]

# 노드별 sample 수
n_root = tree_.n_node_samples[root_node]
n_left = tree_.n_node_samples[left_child]
n_right = tree_.n_node_samples[right_child]

# 지니 불순도 감소량 계산
weighted_impurity = (n_left / n_root) * impurity_left + (n_right / n_root) * impurity_right
gini_decrease = impurity_root - weighted_impurity

# 출력
print(f"Root impurity: {impurity_root:.4f}")
print(f"Left impurity: {impurity_left:.4f}, samples: {n_left}")
print(f"Right impurity: {impurity_right:.4f}, samples: {n_right}")
print(f"Impurity decrease after split: {gini_decrease:.4f}")

반응형

'개발 > Python' 카테고리의 다른 글

결정 트리 (부스팅)  (3) 2025.08.15
statsmodels ols 결과 출력하기  (1) 2025.08.15
일원분산분석 비교 (f_oneway vs anova_lm)  (2) 2025.08.15
데이터 범주화, 구간화  (1) 2025.08.15
정규성 검정  (2) 2025.08.15

댓글