import numpy as np
def loadDataSet(fileName):
    """Load a tab-separated numeric data file into a list of float rows.

    Each non-blank line is stripped, split on tab characters, and every
    field is converted with ``float``. The tree-building functions in this
    module treat the last column of each row as the target value.

    :param fileName: path of the text file to read.
    :return: list of rows, each row a list of floats.
    :raises ValueError: if a field cannot be parsed as a float.
    """
    dataMat = []
    # ``with`` guarantees the file handle is closed; the previous version
    # opened the file and never closed it.
    with open(fileName) as fr:
        for line in fr:
            curLine = line.strip()
            if not curLine:
                # Skip blank lines (e.g. a trailing newline at EOF), which
                # would otherwise make float('') raise ValueError.
                continue
            dataMat.append([float(field) for field in curLine.split('\t')])
    return dataMat
def binSplitDataSet(dataSet, feature, value):
    """Partition the rows of *dataSet* on a single feature column.

    :param dataSet: 2-D numpy matrix or array.
    :param feature: index of the column to compare.
    :param value: threshold to compare that column against.
    :return: ``(greater, not_greater)`` -- the rows whose feature is
        strictly greater than *value*, then the rows whose feature is less
        than or equal to *value*.
    """
    column = dataSet[:, feature]
    # np.nonzero(...)[0] gives the row indices where the comparison holds.
    above_rows = np.nonzero(column > value)[0]
    other_rows = np.nonzero(column <= value)[0]
    greater = dataSet[above_rows, :]
    not_greater = dataSet[other_rows, :]
    return greater, not_greater
def regLeaf(dataSet):
    """Return the leaf value for a regression tree node.

    The leaf value is the mean of the last (target) column of *dataSet*.
    """
    targets = dataSet[:, -1]
    return targets.mean()
def regErr(dataSet):
    """Return the total squared error of the target column.

    This is the population variance of the last column multiplied by the
    number of rows, i.e. the sum of squared deviations from the mean.
    """
    targets = dataSet[:, -1]
    row_count = dataSet.shape[0]
    return targets.var() * row_count
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best binary split of *dataSet* for a regression tree.

    Every distinct value of every feature column (all columns except the
    last, which holds the target) is tried as a threshold via
    ``binSplitDataSet``. The split minimising
    ``errType(left) + errType(right)`` wins.

    :param dataSet: 2-D numpy matrix or array; the last column is the target.
    :param leafType: function that builds a leaf value from a data set
        (default ``regLeaf``: mean of the target column).
    :param errType: function that measures the error of a data set
        (default ``regErr``: total squared deviation of the target column).
    :param ops: ``(tolS, tolN)`` -- tolS is the minimum error reduction a
        split must achieve; tolN is the minimum number of rows allowed on
        each side of a split.
    :return: ``(featureIndex, splitValue)`` for the best split, or
        ``(None, leafType(dataSet))`` when the node should become a leaf.
    """
    tolS = ops[0]
    tolN = ops[1]
    # All target values identical: splitting cannot help, so make a leaf.
    # np.unique on the target column works for both np.matrix and plain
    # 2-D ndarray input (``.T.tolist()[0]`` only worked for np.matrix).
    if np.unique(np.asarray(dataSet[:, -1])).size == 1:
        return None, leafType(dataSet)
    _, n = dataSet.shape
    S = errType(dataSet)  # error of this node if it is left unsplit
    bestS = np.inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):  # the last column is the target; skip it
        for splitVal in np.unique(np.asarray(dataSet[:, featIndex])):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Reject splits leaving fewer than tolN rows on either side.
            if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Improvement too small. For finite S this also covers "no admissible
    # split was found", because bestS is still inf and S - inf is -inf.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    # Defensive re-check of the chosen split against the size threshold.
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a CART regression tree.

    :param dataSet: 2-D numpy matrix or array whose last column is the target.
    :param leafType: leaf-value function passed through to ``chooseBestSplit``.
    :param errType: error function passed through to ``chooseBestSplit``.
    :param ops: ``(tolS, tolN)`` stopping thresholds passed through to
        ``chooseBestSplit``.
    :return: either a leaf value (whatever ``leafType`` returns) or a dict
        with keys ``'spInd'`` (split feature index), ``'spVal'`` (split
        threshold), ``'left'`` (subtree built from rows whose feature is
        greater than the threshold) and ``'right'`` (subtree built from rows
        whose feature is less than or equal to the threshold).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # chooseBestSplit signals "make a leaf" with feat None; compare to the
    # singleton by identity (PEP 8) rather than with ``==``.
    if feat is None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
if __name__ == '__main__':
    # Demo: build a regression tree from the tab-separated file ex00.txt.
    myDat = loadDataSet('ex00.txt')
    # np.mat was an alias of np.asmatrix and was removed in NumPy 2.0;
    # np.asmatrix behaves identically on all NumPy versions.
    myMat = np.asmatrix(myDat)
    print(createTree(myMat))
# 转载请注明原文地址 (reprint notice: please cite the original source): https://lol.8miu.com/read-1645.html