from math import log
# 初始化数据集
def createDataSet():
    """Build the small sample dataset used to demo tree construction.

    Returns:
        A ``(rows, featureNames)`` pair. Each row holds two 0/1 feature
        values followed by a 'yes'/'no' class label in the last position;
        ``featureNames`` gives the name of each feature column, in order.
        Fresh lists are created on every call.
    """
    featureNames = ["without water", "flippers"]
    rows = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    return rows, featureNames
# Compute the Shannon entropy of the class labels in a dataset
def calShannonEnt(dataset):
    """Return the Shannon entropy, in bits, of the class labels in ``dataset``.

    The class label of each sample is its last element. The entropy is
    ``-sum(p * log2(p))`` over every distinct label, where ``p`` is the
    label's relative frequency.

    Args:
        dataset: Sequence of samples; each sample is indexable and its last
            element is the class label.

    Returns:
        The entropy as a float; 0.0 for an empty or single-class dataset.
    """
    labelCount = {}
    for featVec in dataset:
        label = featVec[-1]
        labelCount[label] = labelCount.get(label, 0) + 1
    numEntries = len(dataset)
    shannonEnt = 0.0
    for count in labelCount.values():
        prob = float(count) / numEntries
        # Accumulate the term for every class; entropy is the sum, not the
        # contribution of a single class.
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
# 切割数据集
def splitDataSet(dataset, axis, value):
    """Select the rows whose column ``axis`` equals ``value``, minus that column.

    Args:
        dataset: List of rows (lists); the input is not modified.
        axis: Index of the column to test and then remove.
        value: Value the column must equal for a row to be kept.

    Returns:
        A new list of new row lists, each with column ``axis`` dropped.
    """
    matches = []
    for row in dataset:
        if row[axis] != value:
            continue
        # Slicing yields a fresh list, so extending it leaves ``row`` intact.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        matches.append(trimmed)
    return matches
# 选择信息增益最大的特征
def chooseBestFeatureSplit(dataset):
    """Return the index of the feature with the largest information gain.

    Each row of ``dataset`` holds feature values followed by the class label
    in its last position. Information gain is the entropy of the whole
    dataset minus the size-weighted entropy of the partitions obtained by
    splitting on one feature.

    Returns:
        The winning feature index, or -1 when no feature has strictly
        positive gain (including when rows contain only the class label).
    """
    total = float(len(dataset))
    baseEntropy = calShannonEnt(dataset)
    bestIndex, bestGain = -1, 0.0
    for axis in range(len(dataset[0]) - 1):
        # Weighted entropy of the subsets produced by splitting on ``axis``.
        conditional = 0.0
        for value in {row[axis] for row in dataset}:
            subset = splitDataSet(dataset, axis, value)
            conditional += len(subset) / total * calShannonEnt(subset)
        gain = baseEntropy - conditional
        # Strict comparison: on ties the lowest feature index is kept.
        if gain > bestGain:
            bestIndex, bestGain = axis, gain
    return bestIndex
# 递归的创建树
def _majorityClass(classList):
    """Return the most frequent class in ``classList``; the first seen wins ties."""
    counts = {}
    for cls in classList:
        counts[cls] = counts.get(cls, 0) + 1
    # max() returns the first maximal key in insertion order.
    return max(counts, key=counts.get)


def createTree(dataset, labels):
    """Recursively build a decision tree by information-gain feature selection.

    Args:
        dataset: Non-empty list of rows; each row holds feature values
            followed by the class label in its last position.
        labels: Names of the feature columns, parallel to the feature
            positions of each row. It is NOT modified.

    Returns:
        A class label when the subset is a leaf; otherwise a nested dict
        ``{featureName: {featureValue: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataset]
    # Every sample has the same class: stop splitting.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column remains, so there is nothing left to split on.
    if len(dataset[0]) == 1:
        return _majorityClass(classList)
    bestFeature = chooseBestFeatureSplit(dataset)
    # -1 means no feature reduces entropy; splitting further would index the
    # class column as if it were a feature.
    if bestFeature == -1:
        return _majorityClass(classList)
    bestFeatLabel = labels[bestFeature]
    myTree = {bestFeatLabel: {}}
    # Build the remaining labels as a new list so the caller's list is untouched.
    subLabels = list(labels[:bestFeature]) + list(labels[bestFeature + 1:])
    for value in set(example[bestFeature] for example in dataset):
        # Each branch recurses on the smaller subset that has this feature value.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataset, bestFeature, value), subLabels)
    return myTree
# 测试主函数
if __name__ == '__main__':
    # Build the sample data and print the resulting tree.
    sampleData, featureLabels = createDataSet()
    decisionTree = createTree(sampleData, featureLabels)
    print(decisionTree)