
Machine Learning in Action 03 – Decision Tree Classifier

A hands-on decision tree implementation for Machine Learning in Action, including drawing the tree with the graphviz and pygraphviz libraries to produce a figure like the one below.

The resulting decision tree is then used to classify data.

This involves computing the Shannon entropy. Let $x_i$ be one of the classes in the training data, let $S$ be the sample space of class labels in the training set, and let $p(x_i)$ be the frequency with which class $x_i$ occurs in $S$. The Shannon entropy $H$ is then

$$H = -\sum_{i=1}^{n} p(x_i)\log_{2} p(x_i)$$
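For example, with the six-sample toy data set defined in the code below (four rows labeled 男 and two labeled 女), the entropy works out to about 0.918 bits, which is what calcShannonEnt returns for that data:

$$H = -\frac{4}{6}\log_{2}\frac{4}{6} - \frac{2}{6}\log_{2}\frac{2}{6} \approx 0.918$$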

[Figure: example decision tree rendered with graphviz]

NOTE: the code below uses an improved way of visualizing the decision tree.

#!/usr/bin/env python3

import numpy as np
from math import log
import operator
import pygraphviz as pgv
import uuid

def main():
    myDat, labels = createDataSet()
    myDat2, labels2 = createDataSet_2()

    print(np.array(myDat))
    print(type(myDat))
    print(chooseBestFeatureToSplit(myDat))
    tree = createTree(myDat, labels[:])  # pass a copy: createTree deletes entries from the label list
    print(tree)
    tree2 = createTree(myDat2, labels2[:])
    print(tree2)
    picFileName = 'file.png'
    generatePicForTree(picFileName, tree)

    testVet = [0, 2]  # 长头发 == 2 never occurs in the training data, so classify returns 'Unknown'
    testClass = classify(tree, labels, testVet)
    print(f"testClass = {testClass}")

    # Read the tab-separated training data in lenses.txt and draw the decision
    # tree built from it.
    lensData, lensDataLabels = createLensesDataSet()
    lensTree = createTree(lensData, lensDataLabels[:])
    print(lensTree)
    picFileName = 'lenses.png'
    generatePicForTree(picFileName, lensTree)
    testVet = ['pre', 'myope', 'no', 'normal']  # a test sample; feed it into the tree to get its class
    classForTest = classify(lensTree, lensDataLabels, testVet)  # no lenses
    print(f"classForTest={classForTest}")

def calcShannonEnt(dataSet):
    """
    Compute the Shannon entropy of the data set.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # the class label is in the last column
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def splitDataset(dataSet, axis, value):
    """
    Return the rows whose value in column `axis` equals `value`, with that column removed.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])  # note: the returned rows no longer contain column `axis`
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    """
    Choose the best column to split the data set on, i.e. the feature that
    yields the largest information gain.
    :param dataSet:
    :return: index of the best feature
    """
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label, not a feature
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # list comprehension (basics tutorial, p. 83): all values of column i
        uniqueFeatureVals = set(featList)
        newEntropy = 0.0
        for value in uniqueFeatureVals:
            subDataSet = splitDataset(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted sum of sub-set entropies: the conditional entropy after splitting on column i
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
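
# Illustrative numbers, worked out by hand rather than printed by this script:
# for the createDataSet_2 data further below, the base entropy is about 0.971;
# splitting on column 0 ('no surfacing') leaves a weighted entropy of about
# 0.551, an information gain of about 0.420, while splitting on column 1
# ('flippers') only gains about 0.171, so chooseBestFeatureToSplit returns 0.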

def majorityCnt(classList):
    """
    Return the class that occurs most often in classList.
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    """
    Recursively build the decision tree as nested dicts.
    :param dataSet:
    :param labels: the name of each feature, i.e. the column names / header
    :return: a nested dict, or a class label for a leaf node
    """
    classList = [example[-1] for example in dataSet]  # the last column holds the class
    if classList.count(classList[0]) == len(classList):  # all rows share one class: return it as a leaf
        return classList[0]
    if len(dataSet[0]) == 1:  # only the class column is left: return the majority class
        return majorityCnt(classList)
    bestFeature = chooseBestFeatureToSplit(dataSet)
    bestFeatureLabel = labels[bestFeature]
    myTree = {bestFeatureLabel: {}}
    del labels[bestFeature]
    featureVals = [example[bestFeature] for example in dataSet]
    uniqueValues = set(featureVals)
    for value in uniqueValues:
        subLabels = labels[:]  # pass a copy of labels to the recursive call
        myTree[bestFeatureLabel][value] = createTree(splitDataset(dataSet, bestFeature, value), subLabels)
    return myTree

def generatePicForTree(picFileName, theTree):
    """
    Render the decision tree (a nested dict) and save it to the file given by picFileName.
    :param picFileName:
    :param theTree:
    :return:
    """
    G = pgv.AGraph(directed=True, rankdir='TB')  # top-to-bottom layout ('UD' is not a valid graphviz rankdir)
    G.graph_attr['epsilon'] = '0.001'
    buildTreeGraph(theTree, None, G)
    G.layout('dot')
    G.draw(picFileName)

def buildTreeGraph(myTree, parent, theGraph):
    """
    Draw the tree with the pygraphviz library (which wraps graphviz).
    :param myTree: the decision tree built from the training data, a nested dict
    :param parent: the graphviz node ID of the parent of the subtree being processed
    :param theGraph: the pygraphviz AGraph object being built
    :return: None
    """
    currentGraph = theGraph
    for k in myTree.keys():
        v = myTree[k]
        keyNodeId = uuid.uuid1()  # unique node ID; the visible text comes from the label attribute
        currentGraph.add_node(keyNodeId, label=k)
        if parent:
            currentGraph.add_edge(parent, keyNodeId)
        if isinstance(v, dict):
            buildTreeGraph(v, keyNodeId, currentGraph)
        else:
            valueNodeId = uuid.uuid1()
            currentGraph.add_node(valueNodeId, label=v)
            currentGraph.add_edge(keyNodeId, valueNodeId)

def classify(inputTree, featLabels, testVec):
    """
    Classify one test vector by walking the tree.
    :param inputTree: the decision tree, a nested dict such as
                      {'有胡子': {0: {'长头发': {0: '女', 1: '女'}}, 1: '男'}},
                      whose keys at feature level are values from featLabels
    :param featLabels: the feature (column) names, e.g. ['有胡子', '长头发']
    :param testVec: one row of feature values matching featLabels, e.g. [1, 0]
    :return: the predicted class, or 'Unknown' if the test vector contains a
             feature value that never occurred in the training data
    """
    classLabel = 'Unknown'  # returned when no branch matches the test value
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
            break  # matching branch found; don't let a later iteration overwrite the result
    return classLabel

def createLensesDataSet():
    """
    Load the tab-separated lenses.txt training data.
    """
    with open("lenses.txt") as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    return lenses, lensesLabels

def createDataSet_2():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def createDataSet():
    dataSet = [[1, 0, '男'],
               [1, 0, '男'],
               [1, 1, '男'],
               [0, 1, '女'],
               [0, 1, '男'],
               [0, 0, '女']]
    labels = ['有胡子', '长头发']  # 'has beard', 'long hair'; class labels are '男' (male) / '女' (female)
    return dataSet, labels

if __name__ == '__main__':
    main()
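For reference, running the script on the two toy data sets above should produce nested dicts roughly like the following (the exact key order may vary), matching the tree shown in the classify docstring:

tree  = {'有胡子': {0: {'长头发': {0: '女', 1: '女'}}, 1: '男'}}
tree2 = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

file.png and lenses.png then hold the rendered graphs.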



