深入理解决策树模型

在前一篇文章中,我们简单了解了决策树分类模型。从算法的伪代码中,我们看到决策树有3个基本操作:

  1. 判断停止条件 stop_critera
  2. 选择分裂的属性和划分的阈值,将数据集一分为二 select_attr_and_spliter
  3. 递归地对这两个子集使用决策树生成算法
# 决策树算法伪代码
def generate_decision_tree(D):
    T = TreeNode()  # 决策树

    # 停止条件
    if stop_critera(D):
        return T

    # 选择一个属性和对应的划分阈值,将数据集D分为D1和D2两个子集
    attr, spliter, D1, D2 = select_attr_and_spliter(D)  
    T.value = (attr, spliter)

    # 左子树
    T1 = generate_decision_tree(D1)
    # 右子树
    T2 = generate_decision_tree(D2)

    T.left = T1
    T.right = T2
      
    return T

停止准则

在前面的文章中,我们判断如果集合D中只有1种类别,就停止分裂,这是一种常见的停止准则。此外,还有很多别的停止准则。以下是决策树算法常见的停止准则:

  1. 最小样本数:当节点中的样本数量低于设定的阈值时,停止分裂。
  2. 纯度达到一定程度:如节点中的样本几乎都属于同一类别,达到很高的纯度时停止。
  3. 达到最大深度:规定决策树生长的最大深度,到达该深度后不再继续分裂。

sklearn中的 DecisionTreeClassifier 模型中,min_samples_leaf 是叶子节点最小样本数,对应最小样本数准则;max_depth 限定决策树的最大深度,是最大深度准则。

分裂准则

第二步是选择分裂的属性和划分的阈值,将数据集一分为二 select_attr_and_spliter。在上一篇文章中,我们通过数据分析,找到最佳属性和分裂阈值,那么怎么让算法自动找到属性和分裂阈值呢?

很简单,我们遍历每一个属性以及每个属性里的每一个可能的分裂阈值,找到分裂的价值最大的那个属性和阈值即可(如下面的伪代码所示)。

下面是寻找分裂属性和阈值的算法,两个for训练依次遍历属性和分裂阈值,get_all_values也就是返回属性attr所有不相同的数值,即对每个值都尝试做分裂阈值。

calc_split_value是计算按照属性attr,阈值threshold 将数据集D划分成 attr<=threshold 和 attr > threshold的两个集合的价值。但是,这里面有个问题,怎么评估分离的价值?也就是算法中的 calc_split_value 函数如何设计?这个我们后面再说。

最后,当我们找到最佳的分裂属性和阈值后,就可以将数据集D一分为二,D1是满足属性attr<=threshold的数据,D2是剩下的数据。

def find_spliter_attr_and_threshold(D):
    # 记录最佳分裂的价值、属性、分裂阈值
    max_value, max_attr, max_threshold = -inf, None, None

    for attr in  D.attr_list:
        for threshold in get_all_values(D[attr]):
            # 计算用属性attr,阈值threshold分裂,的价值
            value = calc_split_value(D, D[attr], threshold)

            if value > max_value:
                max_value, max_attr, max_threshold = value, attr, value
    
    D1, D2 = split(D, D[max_attr], max_threshold)
    return max_attr, max_threshold, D1, D2

那么,怎么评估分裂的价值呢?或者说,我们通过什么样的一个指标,来评估将数据集D划分成D1和D2后,对我们最终分类效果的价值。

既然我们期望分裂的目标是分类效果,那么我们就可以直接用分类效果来评估。我们假设D1和D2是决策树最终的叶子节点,那么我们就认为D1预测的类别是D1中样本最多的类别。

假设 D1 中三种类别的数目分别是 [9, 1, 0],那么D1预测成数目最多的第一种类别,那么9个分类正确,1个分类错误,准确率是r1=90%。我们定义叶子节点的准确率=分类正确的数目/总的样本数目。

同样我们也可以计算出D2的准确率,假设为r2=80%。那么,我们可以计算平均准确率为:

$$ |D1| \times r1 + |D2| \times r2 $$

|D1|代表集合D1中样本数目。于是,我们有以下计算分裂价值的算法

def calc_split_value(D, D[attr], threshold):
    # 按照属性attr将D划分成两个子集
    D1 = [d for d in D if d[attr] <= threshold]
    D2 = [d for d in D if d[attr] > threshold]

    # 计算D1的准确率
    t1 = mode([d.target for d in D1])  # 计算D1中类别的众数,作为D1集合预测结果
    r1 = avg([d.target == t1 for d in D1])  # 计算D1集合的准确率

    # 计算D2的准确率
    t2 = mode([d.target for d in D2]) 
    r1 = avg([d.target == t2 for d in D2])

   # 计算平均准确率
   r = (r1 * size(D1) + r2 * size(D2))/size(D)
   
   return r

其他的一些准则还有信息增益,gini系数等,信息增益是来衡量分裂前后信息量的提升情况,gini系数则是衡量每个集合的纯度,属于同一类比例越高,gini系数越大。

它们都是用不同的方法来衡量分裂集合的某种价值的手段。想了解的可以阅读相关书籍和文档进一步了解。

sklearn中的 DecisionTreeClassifier 模型中,参数 criterion 用来指定使用哪种准则,默认值是gini系数,我们通常使用默认值即可。

动手实践:手写数字识别

好了,到目前为止,你已学会了决策树模型生成算法的所有细节了。接下来我们尝试使用sklearn中的 DecisionTreeClassifier 模型来解决一些实际问题吧。

sklearn中的sklearn.datasets.load_digits 包含了一个手写数字数据集。里面有1797个样本,每个样本是64维向量,对应一个8×8的图片,比如下面是其中一个样本的图片。这个样本的64维是图中每个像素点的亮度值,值越大表示越亮。现在,请动手实现决策树模型来识别这些数字。

好了,到现在为止,你已经学会决策树的原理,以及怎么使用决策树模型去做分类了。接下来你可以尝试用这个模型去做任何要分类的任务了,展开你的想想吧!

0 0 投票数
文章评分
订阅评论
提醒
guest
0 评论
内联反馈
查看所有评论
0
希望看到您的想法,请您发表评论x