Machine Learning in Action CART regression tree source code issue: TypeError: list indices must be integers or slices, not tuple

Published: 2023-12-26 17:40:52

Code from the book (1):

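from numpy import *  # NumPy names used by the book's code (nonzero, mean, var, shape, inf, mat)

def binSplitDataSet(dataSet, feature, value):
    # dataSet[:,feature] is NumPy 2-D indexing; on a plain Python list it raises
    # TypeError: list indices must be integers or slices, not tuple
    # Note: the trailing [0] keeps only the first matching row and is commonly removed
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
    return mat0,mat1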
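
The TypeError in the title comes from expressions like dataSet[:,feature]. That is NumPy two-dimensional indexing, and a plain Python list only accepts an integer or a slice as an index. In the book's interactive examples the loaded data is converted with mat(...) first. Passing the list straight from the loader therefore triggers the error. A minimal pure-Python sketch reproducing it follows. The toy data and the helpers try_numpy_style() and column() are hypothetical, for illustration only:

```python
# Hypothetical toy data: two features plus a target column, as a plain list of lists
data = [[0.1, 1.0, 3.0],
        [0.7, 2.0, 5.0],
        [0.4, 1.0, 4.0]]

def try_numpy_style(rows, feature):
    """Attempt the book's NumPy-style column access on a plain list; return the error text."""
    try:
        return rows[:, feature]
    except TypeError as err:
        return str(err)

def column(rows, feature):
    """List-friendly alternative: read one column row by row."""
    return [row[feature] for row in rows]

print(try_numpy_style(data, 0))  # list indices must be integers or slices, not tuple
print(column(data, 2))           # [3.0, 5.0, 4.0]
```

Converting the loaded data with mat(...) before calling the book's functions also avoids the error. The rewrites in this post instead keep the list and index it row by row.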

Changed to:

def binSplitDataSet(dataSet, feature, value):
    # dataSet is a plain list of rows; split it on one feature without NumPy indexing
    mat0 = []
    mat1 = []
    for featVec in dataSet:
        # Append the row itself rather than looking it up by value,
        # so rows that share a feature value are all kept exactly once
        if featVec[feature] > value:
            mat0.append(featVec)
        else:
            mat1.append(featVec)
    return mat0, mat1
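
When splitting a plain list, the safe approach is to walk the rows themselves. Looking a row up by its feature value with list.index returns only the first row holding that value. Rows that share a value then get duplicated while others go missing. The small sketch below contrasts the two approaches on hypothetical data. Both function names are illustrative:

```python
# Hypothetical rows: one feature, then the target; the first two share a feature value
rows = [[1.0, 10.0],
        [1.0, 20.0],
        [2.0, 30.0]]

def split_by_value_lookup(data, feature, value):
    """Locate each row with list.index on the feature values (breaks on duplicates)."""
    feats = [row[feature] for row in data]
    left, right = [], []
    for f in feats:
        target = left if f > value else right
        target.append(data[feats.index(f)])  # always the FIRST row with value f
    return left, right

def split_by_rows(data, feature, value):
    """Walk the rows directly, so every row lands on exactly one side."""
    left, right = [], []
    for row in data:
        (left if row[feature] > value else right).append(row)
    return left, right

print(split_by_value_lookup(rows, 0, 1.5)[1])  # [[1.0, 10.0], [1.0, 10.0]]
print(split_by_rows(rows, 0, 1.5)[1])          # [[1.0, 10.0], [1.0, 20.0]]
```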

Code from the book (2):

def regLeaf(dataSet):
    return mean(dataSet[:,-1])

Changed to:

from numpy import mean

def regLeaf(dataSet):
    # Collect the target (last) column row by row, then average it
    valueList = []
    for featVec in dataSet:
        valueList.append(featVec[-1])
    return mean(valueList)
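
In a regression tree, each leaf stores the mean of the target values that reach it, which is what regLeaf computes. Below is an equivalent pure-Python sketch using the standard library's statistics.mean. It assumes each row is a list whose last element is the target. The name reg_leaf is hypothetical:

```python
from statistics import mean

def reg_leaf(rows):
    """Leaf value: average of the last column (the target)."""
    return mean(row[-1] for row in rows)

# Hypothetical rows: one feature, then the target
rows = [[0.1, 3.0], [0.4, 4.0], [0.7, 8.0]]
print(reg_leaf(rows))  # 5.0
```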

Code from the book (3):

def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]

Changed to:

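from numpy import shape

def regErr(dataSet):
    # Total squared error of the target column: variance times number of rows
    valueList = []
    for featVec in dataSet:
        valueList.append(featVec[-1])
    # Local names chosen so they do not shadow NumPy's mean/var
    meanVal = sum(valueList) / len(valueList)
    sqErr = 0
    for value in valueList:
        sqErr += (meanVal - value) ** 2
    return sqErr / len(valueList) * shape(dataSet)[0]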
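
regErr returns variance × number of rows. That is simply the sum of squared deviations of the targets from their mean. NumPy's var defaults to the population variance (ddof=0), so statistics.pvariance is the matching standard-library function. Here is a small worked check on hypothetical data (function names are illustrative):

```python
import math
from statistics import pvariance

def reg_err(rows):
    """Population variance of the target column times the number of rows."""
    targets = [row[-1] for row in rows]
    return pvariance(targets) * len(targets)

def reg_err_direct(rows):
    """Sum of squared deviations of the targets from their mean."""
    targets = [row[-1] for row in rows]
    m = sum(targets) / len(targets)
    return sum((t - m) ** 2 for t in targets)

# Hypothetical rows: one feature, then the target (mean 5, deviations -2, -1, 3)
rows = [[0.1, 3.0], [0.4, 4.0], [0.7, 8.0]]
print(reg_err_direct(rows))                               # 14.0
print(math.isclose(reg_err(rows), reg_err_direct(rows)))  # True
```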

Code from the book (4):

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        # With a NumPy matrix this line can raise TypeError: unhashable type: 'matrix';
        # set(dataSet[:,featIndex].T.tolist()[0]) is the usual workaround
        for splitVal in set(dataSet[:,featIndex]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:
        return None, leafType(dataSet) #exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit cond 3
        return None, leafType(dataSet)
    return bestIndex,bestValue

Changed to:

from numpy import shape, inf

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]  # minimum error reduction required to split
    tolN = ops[1]  # minimum number of rows on each side of a split
    valueList = []
    for featVec in dataSet:
        valueList.append(featVec[-1])
    if len(set(valueList)) == 1:  # exit cond 1: all targets equal
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):
        # Distinct values of this feature, gathered without NumPy column indexing
        valueList = []
        for featVec in dataSet:
            valueList.append(featVec[featIndex])
        for splitVal in set(valueList):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:  # exit cond 2
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # exit cond 3
        return None, leafType(dataSet)
    return bestIndex, bestValue
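
chooseBestSplit is a greedy search. Every distinct value of every feature is tried as a threshold, candidate splits that leave fewer than tolN rows on either side are skipped, and the split with the smallest summed squared error wins. The node becomes a leaf when all targets are equal or when the best split lowers the error by less than tolS. Below is a compact pure-Python sketch of the same search on hypothetical data. The names sq_err, split and best_split are illustrative, not from the book. It drops the third exit check, which is redundant once every candidate has already passed the tolN test:

```python
def sq_err(rows):
    """Total squared error of the target (last) column; 0.0 for an empty side."""
    if not rows:
        return 0.0
    targets = [r[-1] for r in rows]
    m = sum(targets) / len(targets)
    return sum((t - m) ** 2 for t in targets)

def split(rows, feature, value):
    """Rows with feature > value go left, the rest go right (as in binSplitDataSet)."""
    left = [r for r in rows if r[feature] > value]
    right = [r for r in rows if r[feature] <= value]
    return left, right

def best_split(rows, tol_s=1.0, tol_n=4):
    """Greedy CART-style search: return (feature, value), or (None, leaf_mean) to stop."""
    targets = [r[-1] for r in rows]
    leaf = sum(targets) / len(targets)
    if len(set(targets)) == 1:            # exit 1: all targets identical
        return None, leaf
    base_err = sq_err(rows)
    best_err, best_feat, best_val = float("inf"), 0, 0.0
    for feat in range(len(rows[0]) - 1):  # last column is the target
        for val in set(r[feat] for r in rows):
            left, right = split(rows, feat, val)
            if len(left) < tol_n or len(right) < tol_n:
                continue                  # too few rows on one side
            err = sq_err(left) + sq_err(right)
            if err < best_err:
                best_err, best_feat, best_val = err, feat, val
    if base_err - best_err < tol_s:       # exit 2: gain too small (or no valid split)
        return None, leaf
    return best_feat, best_val

# Hypothetical data: one feature, then the target; the target jumps between 0.3 and 0.8
rows = [[0.1, 1.0], [0.2, 1.2], [0.3, 0.9],
        [0.8, 5.0], [0.9, 5.1], [1.0, 4.9]]
print(best_split(rows, tol_n=2))             # (0, 0.3)
print(best_split([[0.1, 2.0], [0.5, 2.0]]))  # (None, 2.0)
```

With the book's default ops=(1, 4), these six hypothetical rows cannot be split at all, because no threshold leaves four rows on each side. That is why the example passes tol_n=2.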

Run result: (screenshot not preserved)

 
