记录下决策树的原理和简单Demo;
文章目录
一、决策树原理二、决策树优缺点三、决策树算法种类四、信息增益五、Demo
一、决策树原理
决策树能够生成清晰的基于特征选择不同预测结果的树状结构;希望更好的理解手上的数据的时候,往往可以使用决策树。 在实际应用中,受限于它的简单性,决策树更大的用处是作为一些更有用的算法的基石,例如随机森林。
二、决策树优缺点
优点:计算复杂度不高,输出结果易于理解;对中间值的缺失不敏感;可以处理不相关的特征数据;缺点:可能会产生过度匹配问题;适用数据类型:数值型、标称型;
三、决策树算法种类
ID3:以信息增益作为树的分裂准则;C4.5:以基于信息增益的增益率作为树的分裂准则,解决了ID3的偏向与多值属性问题;CART:ID3和C4.5只能处理分类问题,而CART可以处理分类和回归问题;
四、信息增益
信息增益用于度量一个随机变量中包含的关于另一个随机变量的信息量,或者说是一个随机变量由于另一个随机变量而减少的不肯定性,也可以简单认为一个随机变量的引入导致了另一个随机变量的混乱性变化(约束)。
g(D,A) = H(D) - H(D|A)
对于决策树来说,信息增益越大,特征对最终的分类结果影响也越大。
五、Demo
from math
import log
import operator
"""
Parameters:
无
Returns:
dataSet - 数据集
labels - 特征标签
"""
def createDataSet():
dataSet
= [[0, 0, 0, 0, 'no'],
[0, 0, 0, 1, 'no'],
[0, 1, 0, 1, 'yes'],
[0, 1, 1, 0, 'yes'],
[0, 0, 0, 0, 'no'],
[1, 0, 0, 0, 'no'],
[1, 0, 0, 1, 'no'],
[1, 1, 1, 1, 'yes'],
[1, 0, 1, 2, 'yes'],
[1, 0, 1, 2, 'yes'],
[2, 0, 1, 2, 'yes'],
[2, 0, 1, 1, 'yes'],
[2, 1, 0, 1, 'yes'],
[2, 1, 0, 2, 'yes'],
[2, 0, 0, 0, 'no']]
labels
= ['年龄', '有工作', '有自己的房子', '信贷情况']
return dataSet
, labels
"""
Parameters:
dataSet - 数据集
Returns:
shannonEnt - 经验熵(香农熵)
"""
def calcShannonEnt(dataSet
):
numEntires
= len(dataSet
)
labelCounts
= {}
for featVec
in dataSet
:
currentLabel
= featVec
[-1]
if currentLabel
not in labelCounts
.keys
():
labelCounts
[currentLabel
] = 0
labelCounts
[currentLabel
] += 1
shannonEnt
= 0.0
for key
in labelCounts
:
prob
= float(labelCounts
[key
]) / numEntires
shannonEnt
-= prob
* log
(prob
, 2)
return shannonEnt
"""
Parameters:
dataSet - 待划分的数据集
axis - 划分数据集的特征
value - 需要返回的特征的值
Returns:
无
"""
def splitDataSet(dataSet
, axis
, value
):
retDataSet
= []
for featVec
in dataSet
:
if featVec
[axis
] == value
:
reducedFeatVec
= featVec
[:axis
]
reducedFeatVec
.extend
(featVec
[axis
+1:])
retDataSet
.append
(reducedFeatVec
)
return retDataSet
"""
Parameters:
dataSet - 数据集
Returns:
bestFeature - 信息增益最大的(最优)特征的索引值
"""
def chooseBestFeatureToSplit(dataSet
):
numFeatures
= len(dataSet
[0]) - 1
baseEntropy
= calcShannonEnt
(dataSet
)
bestInfoGain
= 0.0
bestFeature
= -1
for i
in range(numFeatures
):
featList
= [example
[i
] for example
in dataSet
]
uniqueVals
= set(featList
)
newEntropy
= 0.0
for value
in uniqueVals
:
subDataSet
= splitDataSet
(dataSet
, i
, value
)
prob
= len(subDataSet
) / float(len(dataSet
))
newEntropy
+= prob
* calcShannonEnt
(subDataSet
)
infoGain
= baseEntropy
- newEntropy
if (infoGain
> bestInfoGain
):
bestInfoGain
= infoGain
bestFeature
= i
return bestFeature
"""
Parameters:
classList - 类标签列表
Returns:
sortedClassCount[0][0] - 出现此处最多的元素(类标签)
"""
def majorityCnt(classList
):
classCount
= {}
for vote
in classList
:
if vote
not in classCount
.keys
():classCount
[vote
] = 0
classCount
[vote
] += 1
sortedClassCount
= sorted(classCount
.items
(), key
= operator
.itemgetter
(1), reverse
= True)
return sortedClassCount
[0][0]
"""
Parameters:
dataSet - 训练数据集
labels - 分类属性标签
featLabels - 存储选择的最优特征标签
Returns:
myTree - 决策树
"""
def createTree(dataSet
, labels
, featLabels
):
classList
= [example
[-1] for example
in dataSet
]
if classList
.count
(classList
[0]) == len(classList
):
return classList
[0]
if len(dataSet
[0]) == 1:
return majorityCnt
(classList
)
bestFeat
= chooseBestFeatureToSplit
(dataSet
)
bestFeatLabel
= labels
[bestFeat
]
featLabels
.append
(bestFeatLabel
)
myTree
= {bestFeatLabel
:{}}
del(labels
[bestFeat
])
featValues
= [example
[bestFeat
] for example
in dataSet
]
uniqueVals
= set(featValues
)
for value
in uniqueVals
:
myTree
[bestFeatLabel
][value
] = createTree
(splitDataSet
(dataSet
, bestFeat
, value
), labels
, featLabels
)
return myTree
if __name__
== '__main__':
dataSet
, labels
= createDataSet
()
featLabels
= []
myTree
= createTree
(dataSet
, labels
, featLabels
)
print(myTree
)