决策树ID3简单实现

it · 2025-03-31 · 22

from collections import Counter
from math import log


def createDataSet():
    """Return a small toy dataset and the names of its feature columns.

    Each row is ``[feature_0, feature_1, class_label]``; the class label is
    always the last element of a row.

    Returns:
        tuple: ``(dataSet, labels)`` where ``labels[i]`` names column ``i``.
    """
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ["without water", "flippers"]
    return dataSet, labels


def calShannonEnt(dataset):
    """Return the base-2 Shannon entropy of the class labels in *dataset*.

    The class label of each row is its last element. An empty dataset has
    entropy 0.0.
    """
    numEntries = len(dataset)
    labelCount = Counter(featVec[-1] for featVec in dataset)
    shannonEnt = 0.0
    for count in labelCount.values():
        prob = float(count) / numEntries
        # Entropy is the SUM of -p*log2(p) over every class, so accumulate
        # rather than overwrite on each iteration.
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataset, axis, value):
    """Return the rows whose column *axis* equals *value*, minus that column.

    The input rows are not modified; every returned row is a new list.
    Rows may be lists or tuples.
    """
    return [
        list(featVec[:axis]) + list(featVec[axis + 1:])
        for featVec in dataset
        if featVec[axis] == value
    ]


def chooseBestFeatureSplit(dataset):
    """Return the index of the feature with the largest information gain.

    Returns -1 when no feature yields a strictly positive gain. This includes
    the cases where *dataset* is empty or has no feature columns left (only
    the class label).
    """
    if not dataset:
        return -1
    featureCount = len(dataset[0]) - 1
    baseEntropy = calShannonEnt(dataset)
    numRows = float(len(dataset))
    bestFeature = -1
    bestInfo = 0.0
    for i in range(featureCount):
        uniqueVals = set(example[i] for example in dataset)
        # Weighted average entropy of the branches produced by splitting on i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataset, i, value)
            newEntropy += len(subDataSet) / numRows * calShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfo:
            bestInfo = infoGain
            bestFeature = i
    return bestFeature


def _majorityClass(classList):
    """Return the most frequent label in *classList* (first seen wins ties)."""
    return Counter(classList).most_common(1)[0][0]


def createTree(dataset, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataset: non-empty sequence of rows whose last element is the class
            label.
        labels: feature names, one per feature column of *dataset*. The
            caller's sequence is not modified.

    Returns:
        Either a class label (a leaf) or a nested dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataset]
    # Every row has the same class: this node is a pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    bestFeature = chooseBestFeatureSplit(dataset)
    # No feature columns remain, or none reduces entropy. Splitting on
    # index -1 would split on the class column itself, so stop here with
    # a majority vote.
    if bestFeature == -1:
        return _majorityClass(classList)
    bestFeatLabel = labels[bestFeature]
    myTree = {bestFeatLabel: {}}
    # Build the remaining feature names as a new list instead of deleting
    # from the caller's list in place.
    subLabels = list(labels[:bestFeature]) + list(labels[bestFeature + 1:])
    uniqueVals = set(example[bestFeature] for example in dataset)
    for value in uniqueVals:
        # Each branch recurses on a smaller dataset with one feature fewer.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataset, bestFeature, value), subLabels)
    return myTree


# Demo entry point
if __name__ == '__main__':
    dataset, label = createDataSet()
    print(createTree(dataset, label))

 

最新回复(0)