回归树构建

it2023-02-05  45

import numpy as np


def loadDataSet(fileName):
    """Load a tab-separated numeric data file into a list of float rows.

    :param fileName: path of a text file with one sample per line and
        tab-separated numeric fields.
    :return: list of rows, each row being a list of floats.
    :raises ValueError: if a non-blank field cannot be parsed as a float.
    """
    dataMat = []
    # The context manager guarantees the handle is closed even if parsing fails.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no sample.
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat


def binSplitDataSet(dataSet, feature, value):
    """Split the rows of ``dataSet`` in two by thresholding one column.

    :param dataSet: 2-D ``np.matrix`` or ``np.ndarray`` of samples.
    :param feature: index of the column to threshold.
    :param value: threshold value.
    :return: ``(mat0, mat1)`` where ``mat0`` holds the rows whose ``feature``
        column is ``> value`` and ``mat1`` the rows where it is ``<= value``.
    """
    column = dataSet[:, feature]
    # np.nonzero(...)[0] yields the row indices for both matrix and ndarray.
    mat0 = dataSet[np.nonzero(column > value)[0], :]
    mat1 = dataSet[np.nonzero(column <= value)[0], :]
    return mat0, mat1


def regLeaf(dataSet):
    """Return the leaf value: the mean of the last (target) column."""
    return np.mean(dataSet[:, -1])


def regErr(dataSet):
    """Return the total squared error of the last column (variance * rows)."""
    return np.var(dataSet[:, -1]) * dataSet.shape[0]


def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Pick the feature/value pair whose split minimises the total error.

    :param dataSet: 2-D ``np.matrix`` or ``np.ndarray``; the last column is
        the target, every other column is a feature.
    :param leafType: callable that builds a leaf value from a data set.
    :param errType: callable that returns the error of a data set.
    :param ops: ``(tolS, tolN)`` - ``tolS`` is the minimum error reduction
        required to split, ``tolN`` the minimum number of samples allowed
        on each side of a split.
    :return: ``(featureIndex, splitValue)`` for the best split, or
        ``(None, leafValue)`` when the node should become a leaf.
    """
    tolS = ops[0]  # do not split when the error reduction is below this
    tolN = ops[1]  # do not split when a side has fewer samples than this

    # All targets identical: nothing to gain from splitting.
    # np.asarray + np.unique works for both np.matrix and plain ndarray.
    if np.unique(np.asarray(dataSet[:, -1])).size == 1:
        return None, leafType(dataSet)

    n = dataSet.shape[1]
    S = errType(dataSet)  # error before splitting
    bestS = np.inf        # smallest post-split error seen so far
    bestIndex = 0         # feature index of the best split
    bestValue = 0         # threshold of the best split
    for featIndex in range(n - 1):  # every feature column
        # Every distinct value of the feature is a candidate threshold.
        for splitVal in np.unique(np.asarray(dataSet[:, featIndex])):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
                continue  # one side would be too small
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS

    # Error reduction too small (also covers "no admissible split found",
    # where bestS is still inf): make this node a leaf.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)

    # Defensive re-check of the chosen split against the size threshold.
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue


def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a regression tree.

    :param dataSet: 2-D ``np.matrix`` or ``np.ndarray``; last column is the
        target.
    :param leafType: callable that builds a leaf value from a data set.
    :param errType: callable that returns the error of a data set.
    :param ops: ``(tolS, tolN)`` thresholds forwarded to ``chooseBestSplit``.
    :return: a leaf value when no split is made, otherwise a dict with keys
        ``'spInd'`` (feature index), ``'spVal'`` (threshold), ``'left'``
        (subtree for rows ``> spVal``) and ``'right'`` (rows ``<= spVal``).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:
        # No split chosen: ``val`` is already the leaf value.
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }


if __name__ == '__main__':
    myDat = loadDataSet('ex00.txt')
    # np.mat was removed in NumPy 2.0; np.asmatrix is the supported spelling.
    myMat = np.asmatrix(myDat)
    print(createTree(myMat))
最新回复(0)