🍉 The "Watermelon Book": Computing the Information Gain of Density with Decision-Tree Formula 4.8

Handling continuous values

The formula is:

\[\begin{align} \textup{Gain}(D, a) &= \underset {t \in T_a} {\max} ~\textup{Gain}(D, a, t)\tag{4.8a} \\ &=\underset {t \in T_a} {\max}~ \textup{Ent}(D) - \sum_{\lambda \in\{-,+\}} \frac {|D_{t}^{\lambda}|} {|D|} \textup{Ent}(D_{t}^{\lambda}), \tag{4.8b} \label {4.8} \end{align}\]

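$T_a$ is the set of candidate split points: sort the distinct values of attribute $a$ and take the midpoint of every pair of adjacent values (formula 4.7 in the book). $t$ is one specific split point; $D_t^-$ holds the samples whose value of $a$ is not greater than $t$, and $D_t^+$ holds those whose value is greater than $t$. The expression $\underset {t \in T_a} {\max} ~\textup{Gain}(D, a, t)$ keeps the split point with the largest information gain, which means every candidate split point has to be plugged in and evaluated before the result is known.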
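A minimal Python sketch of how $T_a$ is built, assuming the 17 density values of watermelon dataset 3.0 in sample order 1 to 17 (my own transcription of the book's table). `DENSITY` and `candidate_splits` are illustrative names, not from the book:

```python
# Density values of watermelon dataset 3.0, samples 1..17 (assumed transcription of the book's table)
DENSITY = [0.697, 0.774, 0.634, 0.608, 0.556, 0.403, 0.481, 0.437, 0.666,
           0.243, 0.245, 0.343, 0.639, 0.657, 0.360, 0.593, 0.719]


def candidate_splits(values):
    """Return the midpoints of adjacent sorted distinct values (formula 4.7)."""
    vs = sorted(set(values))
    return [(lo + hi) / 2 for lo, hi in zip(vs, vs[1:])]


T = candidate_splits(DENSITY)
print(len(T))  # 16: 17 distinct values give 16 adjacent pairs
print([f"{t:.4f}" for t in T])
```

The unrounded midpoints include values such as 0.3515 and 0.7465; the book lists them to three decimals as 0.351 and 0.746.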

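Computing the information gain of density

Below, using the density attribute from the book as an example, we compute the information gain of "density". Its candidate split points are: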

\[\begin{align} T_{\text{density}} = \{ & 0.244, 0.294, 0.351, 0.381,\\ & 0.420, 0.459, 0.518, 0.574,\\ & 0.600, 0.621, 0.636, 0.648,\\ & 0.661, 0.681, 0.708, 0.746\}. \end{align}\]

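Start with the first split point, 0.244, i.e. $\textup{Gain}(D,\text{density},0.244)$. In watermelon dataset 3.0 the only sample with density below 0.244 is {10}, and the samples with density above 0.244 are {1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17}. So $D_t^{\lambda}$ is:

\[\begin{align} D_{0.244}^- &= \{10\}\\ D_{0.244}^+ &= \{1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17\} \end{align}\]

Sample 10 is a bad melon, so $D_{0.244}^-$ contains 0 good melons and 1 bad one, while $D_{0.244}^+$ contains all 8 good melons (samples 1 to 8) and the remaining 8 bad ones.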
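The partition can be checked with a short sketch. It assumes my transcription of the density column and the good/bad labels of watermelon dataset 3.0, and follows the book's convention that $D_t^-$ takes values not greater than $t$; `DATA` and `split_at` are hypothetical names used only for illustration:

```python
# Watermelon dataset 3.0: sample id -> (density, is_good_melon); assumed transcription of the book's table
DATA = {
    1: (0.697, True), 2: (0.774, True), 3: (0.634, True), 4: (0.608, True),
    5: (0.556, True), 6: (0.403, True), 7: (0.481, True), 8: (0.437, True),
    9: (0.666, False), 10: (0.243, False), 11: (0.245, False), 12: (0.343, False),
    13: (0.639, False), 14: (0.657, False), 15: (0.360, False), 16: (0.593, False),
    17: (0.719, False),
}


def split_at(data, t):
    """Split sample ids into D_t^- (density <= t) and D_t^+ (density > t)."""
    minus = sorted(i for i, (d, _) in data.items() if d <= t)
    plus = sorted(i for i, (d, _) in data.items() if d > t)
    return minus, plus


d_minus, d_plus = split_at(DATA, 0.244)
print("D- =", d_minus)
print("D+ =", d_plus)
good = sum(DATA[i][1] for i in d_plus)
print("good/bad in D+:", good, "/", len(d_plus) - good)
```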

Substitute these sets into formula $\ref {4.8}$ ($D$ has 17 samples: 8 good melons and 9 bad ones):

\[\begin{align} &\textup{Gain}(D, \text{density}, 0.244) \\&= \textup{Ent}(D)-\left(\frac {|D_{0.244}^-|} {|D|} \textup{Ent}(D_{0.244}^-) +\frac {|D_{0.244}^+|} {|D|} \textup{Ent}(D_{0.244}^+)\right)\\ &=-\left(\frac {8} {17} \log_2 \frac {8} {17} + \frac {9} {17} \log_2 \frac {9} {17}\right) -\left(\frac {1} {17} \textup{Ent}(D_{0.244}^-) + \frac {16} {17}\textup{Ent}(D_{0.244}^+)\right)\\ &\approx 0.9975-\left[\frac {1} {17}\times\left(-\left(\frac {0} {1} \log_2 \frac {0} {1} + \frac {1} {1} \log_2 \frac {1} {1}\right)\right)+\frac {16} {17}\times\left(-\left(\frac {8} {16}\log_2 \frac {8} {16} + \frac {8} {16} \log_2 \frac {8} {16}\right)\right)\right]\\ &= 0.9975 - \left[0 + \frac {16} {17} \times 1\right]\\ &\approx 0.9975-0.9412\\ &\approx 0.056 \end{align}\]

Here $\frac{0}{1}\log_2\frac{0}{1}$ is taken as 0 by the convention $0\log_2 0 = 0$, so $\textup{Ent}(D_{0.244}^-) = 0$.

I computed the results in Python. Here is the code for $\textup{Ent}(D)$:

```python
import math
# Ent(D): D has 8 good and 9 bad melons, 17 samples in total
ent_D = -((8/17) * math.log2(8/17) + (9/17) * math.log2(9/17))
print("ent_D =", round(ent_D, 3))
```
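The snippet above only covers $\textup{Ent}(D)$. As a sketch, the whole calculation at $t = 0.244$ can be done from the (good, bad) class counts of $D$, $D_{0.244}^-$ and $D_{0.244}^+$; `entropy` and `split_gain` are hypothetical helper names, and the `c > 0` filter implements the $0\log_2 0 = 0$ convention:

```python
import math


def entropy(counts):
    """Ent from class counts, using the convention 0 * log2(0) = 0."""
    n = sum(counts)
    return -sum(c / n * math.log2(c / n) for c in counts if c > 0)


def split_gain(parent, minus, plus):
    """Gain(D, a, t) from the (good, bad) counts of D, D_t^- and D_t^+."""
    n = sum(parent)
    weighted = sum(sum(part) / n * entropy(part) for part in (minus, plus))
    return entropy(parent) - weighted


# t = 0.244: D is 8 good / 9 bad, D^- = {10} is 0 good / 1 bad, D^+ is 8 good / 8 bad
g = split_gain((8, 9), (0, 1), (8, 8))
print(round(g, 3))  # 0.056
```

Carrying full precision gives about 0.0563. Subtracting the rounded intermediates 0.998 and 0.941 instead gives 0.057, which is where a difference in the third decimal can come from.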

The other split points are computed the same way:

\[\begin{align} \textup{Gain}(D, \text{density}, 0.244) &= 0.056\\ \textup{Gain}(D, \text{density}, 0.294) &= 0.118\\ \textup{Gain}(D, \text{density}, 0.351) &= 0.186\\ \textup{Gain}(D, \text{density}, 0.381) &= 0.262\\ \textup{Gain}(D, \text{density}, 0.420) &= 0.093\\ \textup{Gain}(D, \text{density}, 0.459) &= 0.030\\ \textup{Gain}(D, \text{density}, 0.518) &= 0.004\\ \textup{Gain}(D, \text{density}, 0.574) &= 0.002\\ \textup{Gain}(D, \text{density}, 0.600) &= 0.002\\ \textup{Gain}(D, \text{density}, 0.621) &= 0.004\\ \textup{Gain}(D, \text{density}, 0.636) &= 0.030\\ \textup{Gain}(D, \text{density}, 0.648) &= 0.006\\ \textup{Gain}(D, \text{density}, 0.661) &= 0.001\\ \textup{Gain}(D, \text{density}, 0.681) &= 0.024\\ \textup{Gain}(D, \text{density}, 0.708) &= 0.000\\ \textup{Gain}(D, \text{density}, 0.746) &= 0.067 \end{align}\]
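To reproduce every value in the list and pick the maximum, here is a sketch that sweeps all candidates in $T_{\text{density}}$. It again assumes my transcription of watermelon dataset 3.0 (samples 1 to 8 good, 9 to 17 bad), and `gain_at` / `best_split` are illustrative names, not an API from the book or any library:

```python
import math

# Watermelon dataset 3.0 (samples 1..17): density, and whether the melon is good; assumed transcription
DENSITY = [0.697, 0.774, 0.634, 0.608, 0.556, 0.403, 0.481, 0.437, 0.666,
           0.243, 0.245, 0.343, 0.639, 0.657, 0.360, 0.593, 0.719]
GOOD = [True] * 8 + [False] * 9


def entropy(labels):
    """Ent of a list of boolean labels (0 * log2(0) taken as 0)."""
    n = len(labels)
    if n == 0:
        return 0.0
    p = sum(labels) / n
    return -sum(q * math.log2(q) for q in (p, 1 - p) if q > 0)


def gain_at(values, labels, t):
    """Gain(D, a, t) with D^- = {value <= t} and D^+ = {value > t}."""
    minus = [y for v, y in zip(values, labels) if v <= t]
    plus = [y for v, y in zip(values, labels) if v > t]
    n = len(labels)
    return entropy(labels) - (len(minus) / n * entropy(minus) + len(plus) / n * entropy(plus))


def best_split(values, labels):
    """Try every midpoint in T_a and keep the one with the largest gain (formula 4.8)."""
    vs = sorted(set(values))
    candidates = [(lo + hi) / 2 for lo, hi in zip(vs, vs[1:])]
    gains = {t: gain_at(values, labels, t) for t in candidates}
    best = max(gains, key=gains.get)
    return best, gains


best, gains = best_split(DENSITY, GOOD)
for t, g in gains.items():
    print(f"Gain(D, density, {t:.4f}) = {g:.3f}")
print("best split:", round(best, 4), "gain:", round(gains[best], 3))
```

Because $\textup{Ent}(D)$ is the same for every $t$, maximizing the gain is equivalent to minimizing the weighted entropy of the two parts; the sketch still computes the full gain so its output matches the list term by term.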

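According to the book, the split point 0.381 gives the largest information gain, so $\textup{Gain}(D, \text{density}) = \textup{Gain}(D, \text{density}, 0.381) = 0.262$.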

Code for computing $\textup{Gain}(D, \text{density}, 0.381)$ (it reuses `math` and `ent_D` from the $\textup{Ent}(D)$ code above):

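```python
# D^- = {10, 11, 12, 15} is all bad melons, so Ent(D^-) = 0; only the D^+ term (8 good, 5 bad) remains
ent_x = -(13/17) * (8/13 * math.log2(8/13) + 5/13 * math.log2(5/13))
print("res =", round(ent_D - ent_x, 3))
```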