1.算法原理
决策树是一棵树,它的每个节点都是一次决策,该节点的子树分别代表不同的决策,叶子节点表示所有数据已经属于同一类型,无法再分。 因此构造决策树只需要做一件事,找出划分当前数据集的最优特征,之后递归子树即可把决策树构造出来。
1.1 找出最优特征
那如何找出最优特征呢,可以从信息论的方向出发,在划分数据前后使用信息论量化度量信息的内容,选取信息增益最高的特征作为当前的选择。
熵定义为信息的期望值,因此需要找熵最大的划分。
这里采用ID3算法去计算熵
总结一下,我们需要对数据集的每个特征都尝试划分一次,并计算熵,最后选取使得熵最大的特征划分为当前数据集的划分。
1.2 递归子树
递归的终止条件有两个:
每个分支下的所有实例都具有相同分类遍历完所有划分数据集的属性
情况1,所有实例都具有相同分类,这就是一个叶子节点。 情况2,可以采用多数表决的方法,将标签出现频率最高的做为此时的分类
2. 代码
import operator
from math
import log
import pickle
"""
函数说明:
计算数据集的香农熵
公式:H = - ( for i in range(n): p(xi) * log(p(xi),2) )
参数:
dataSet: 数据集
返回值:
shannonEnt: 香农熵
"""
def calcShannonEnt (dataSet
):
numEntries
= len(dataSet
)
labelCounts
= {}
for featVec
in dataSet
:
currentLabel
= featVec
[-1]
if currentLabel
not in labelCounts
.keys
():
labelCounts
[currentLabel
] = 0
labelCounts
[currentLabel
] += 1
shannonEnt
= 0.0
for key
in labelCounts
:
prob
= float(labelCounts
[key
]) / numEntries
shannonEnt
-= prob
* log
(prob
, 2)
return shannonEnt
"""
函数说明:
划分数据集
参数:
dataSet: 数据集
axis: 需要去掉特征的索引值
value: 需要返回特征值
返回值:
retDataSet: 返回划分后的结果集
"""
def splitDataSet(dataSet
, axis
, value
):
retDataSet
= []
for featVec
in dataSet
:
if featVec
[axis
] == value
:
reducedFeatVec
= featVec
[:axis
]
reducedFeatVec
.extend
(featVec
[axis
+1:])
retDataSet
.append
(reducedFeatVec
)
return retDataSet
"""
函数说明:
选择最优的数据集划分方式
参数:
dataset:数据集
返回值:
bestFeature:最优特征的索引值
"""
def chooseBestFeatureToSplit(dataSet
):
numFeatures
= len(dataSet
[0]) - 1
baseEntropy
= calcShannonEnt
(dataSet
)
bestInfoGain
= 0.0
bestFeature
= -1
for i
in range(numFeatures
):
featList
= [example
[i
] for example
in dataSet
]
uniqueVals
= set(featList
)
newEntropy
= 0.0
for value
in uniqueVals
:
subDataSet
= splitDataSet
(dataSet
, i
, value
)
prob
= len(subDataSet
) / float(len(dataSet
))
newEntropy
+= prob
* calcShannonEnt
(subDataSet
)
infoGain
= baseEntropy
- newEntropy
if (infoGain
> bestInfoGain
):
bestInfoGain
= infoGain
bestFeature
= i
return bestFeature
"""
函数说明:
测试熵的自建数据集
参数:
无
返回值:
dataSet: 数据集
labels: 标签
"""
def createDataSet():
dataSet
= [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels
= ['no surfacing', 'flippers']
return dataSet
, labels
"""
函数说明:
统计classList中出现次数最多的元素(类标签)
服务于递归第两个终止条件
参数:
classList:类标签列表
返回值:
sortedClassCount[0][0]:出现次数最多的元素(类标签)
"""
def majorityCnt(classList
):
classCount
= {}
for vote
in classList
:
if vote
not in classCount
.keys
():
classCount
[vote
] = 0
classCount
[vote
] += 1
sortedClassCount
= sorted(classCount
.items
(), key
=operator
.itemgetter
(1), reverse
=True)
return sortedClassCount
[0][0]
"""
函数说明:
创建决策树(ID3算法)
参数:
dataSet: 数据集
labels: 标签
返回值:
myTree: 决策树(字典表示)
"""
def createTree(dataSet
, labels
):
classList
= [example
[-1] for example
in dataSet
]
if classList
.count
(classList
[0]) == len(classList
):
return classList
[0]
if len(dataSet
[0]) == 1:
return majorityCnt
(classList
)
bestFeat
= chooseBestFeatureToSplit
(dataSet
)
bestFeatLabel
= labels
[bestFeat
]
myTree
= {bestFeatLabel
: {}}
del(labels
[bestFeat
])
featValues
= [example
[bestFeat
] for example
in dataSet
]
uniqueVals
= set(featValues
)
for value
in uniqueVals
:
subLabels
= labels
[:]
myTree
[bestFeatLabel
][value
] = createTree
(splitDataSet
(dataSet
, bestFeat
, value
), subLabels
)
return myTree
"""
函数说明:
使用决策树分类
参数:
inputTree - 已经生成的决策树
featLabels - 存储选择的最优特征标签
testVec - 测试数据列表,顺序对应最优特征标签
返回值:
classLabel - 分类结果
"""
def classify(inputTree
, featLabels
, testVec
):
firstStr
= inputTree
.keys
()[0]
secondDict
= inputTree
[firstStr
]
featIndex
= featLabels
.index
(firstStr
)
for key
in secondDict
.key
():
if testVec
[featIndex
] == key
:
if type(secondDict
[key
]).__name__
== 'dict':
classLabel
= classify
(secondDict
[key
], featLabels
, testVec
)
else:
classLabel
= secondDict
[key
]
return classLabel
def storeTree(inputTree
, filename
):
fw
= open(filename
, 'w')
pickle
.dump
(inputTree
, fw
)
fw
.close
()
def grabTree(filename
):
fr
= open(filename
)
return pickle
.load
(fr
)
def main():
fr
= open('lenses.txt')
lenses
= [inst
.strip
().split
('\t') for inst
in fr
.readlines
()]
lensesLabels
= ['age', 'prescript', 'astigmatic', 'tearRate']
mTree
= createTree
(lenses
, lensesLabels
)
print(mTree
)
if __name__
== '__main__':
main
()