当前位置: 代码迷 >> 综合 >> Use matplotlib draw the tree
  详细解决方案

Use matplotlib draw the tree

热度:79   发布时间:2023-11-29 16:48:24.0

Plotting the tree in Python with Matplotlib annotations

Unfortunately, Python doesn't include a good tool for plotting trees, so we'll make our own.

这才是真正的工程师精神

Matplotlib has a great tool, called annotations, that can add text near data in a plot.

1. Plotting tree nodes with text annotations

利用文字注释功能来画树结点

import matplotlib.pyplot as plt

# Box styles for the two kinds of node, and the arrow style that links a
# node to its parent.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeText, centerPt, parentPt, nodeType):
    """Draw one tree node as a boxed text annotation linked to its parent.

    Args:
        nodeText: Label shown inside the node's box.
        centerPt: (x, y) where the text box is centred.
        parentPt: (x, y) of the parent node; the arrow is anchored here.
        nodeType: Box style dict, e.g. ``decisionNode`` or ``leafNode``.

    Both points are given in axes-fraction coordinates. The function draws
    on ``createPlot.ax1``, so ``createPlot`` must have created the axes first.
    """
    createPlot.ax1.annotate(
        nodeText,
        xy=parentPt,                  # point the arrow is attached to
        xycoords='axes fraction',
        xytext=centerPt,              # where the text box itself is placed
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=nodeType,                # box drawn around the text
        arrowprops=arrow_args,        # "<-": arrowhead at the text end
    )


def createPlot():
    """Create a blank figure and draw one sample decision node and leaf node."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # reuse figure 1 but start from a clean canvas
    # The axes are stored on the function object so plotNode can reach them.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('a decision node ', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


# Call the function so the figure is displayed on screen.
createPlot()

[createplot

2. A strategy for plotting the tree

Identifying the number of leaves in a tree and the depth

We need to know how many leaf nodes and how many levels the tree has in order to size the plot properly in the X and Y directions.

# Numleafs function
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree has the shape ``{feature_label: {feature_value: child}}``.
    Only the first key of ``myTree`` is used. A child that is a dict (or a
    dict subclass) is a subtree and is counted recursively; any other child
    counts as one leaf.

    Args:
        myTree: Non-empty dict describing the tree.

    Returns:
        The number of leaf nodes as an int.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    # The single top-level value holds the branches of the root split.
    secondDict = list(myTree.values())[0]
    numLeafs = 0
    for child in secondDict.values():
        # isinstance also recognises dict subclasses such as OrderedDict,
        # which a comparison on the type name would treat as leaves.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    The tree has the shape ``{feature_label: {feature_value: child}}``.
    Only the first key of ``myTree`` is used. A branch ending in a leaf
    contributes depth 1; a branch leading to a subtree (a dict or dict
    subclass) contributes 1 plus that subtree's depth.

    Args:
        myTree: Non-empty dict describing the tree.

    Returns:
        The depth of the deepest branch, or 0 if the root has no branches.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    secondDict = list(myTree.values())[0]
    maxDepth = 0
    for child in secondDict.values():
        # isinstance also recognises dict subclasses such as OrderedDict.
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
需要注意的是这里有 Python版本的问题
Python 2 中 firstStr=myTree.keys()[0]
Python 3 中 firstList=list(myTree.keys()),然后 firstStr=firstList[0]。这样写的目的是读取字典的第一个键。
# make a tree data
def retrieveTree(i):
    """Return one of two hard-coded sample decision trees.

    Args:
        i: Index into the list of sample trees (0 or 1; negative indices
            follow normal list indexing).

    Returns:
        A nested dict of the form ``{feature_label: {feature_value: child}}``.
        A fresh structure is built on every call.

    Raises:
        IndexError: If ``i`` is out of range.
    """
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                           1: 'no'}}}},
    ]
    return listOfTrees[i]
# Interactive check; the dict on the next line is presumably the echoed
# result of this call (it matches the first sample tree).
retrieveTree(0)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# Run the function to check the result; the number below is presumably
# the echoed leaf count for the second sample tree.
getNumLeafs(retrieveTree(1))
4
# Plots text between child and parent
def plotMidText(cntrPt, parentPt, txtString):
    """Write a label halfway along the line between a node and its parent.

    Args:
        cntrPt: (x, y) of the child node.
        parentPt: (x, y) of the parent node.
        txtString: Text to draw at the midpoint.

    The text is drawn on ``createPlot.ax1``, so ``createPlot`` must have
    created the axes before this is called.
    """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw a nested-dict decision tree.

    Args:
        myTree: Tree of the form ``{feature_label: {feature_value: child}}``.
            The first key is the feature this node splits on.
        parentPt: (x, y) of the parent node in axes-fraction coordinates.
        nodeTxt: Label for the edge coming from the parent.

    Layout state lives in function attributes that the caller must set
    beforehand: ``plotTree.totalW`` (total leaf count), ``plotTree.totalD``
    (total depth), and the running cursor ``plotTree.xOff`` /
    ``plotTree.yOff``. The cursor is updated as nodes are drawn.
    """
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    firstStr = list(myTree)[0]      # the text label for this node
    # Centre this node above the horizontal span its leaves will occupy.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # A dict is a subtree: recurse with this node as the parent.
            plotTree(child, cntrPt, str(key))
        else:
            # Anything else is a leaf: advance the x cursor and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Step back up so siblings of this node are drawn at the right level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# Display the figure
def createPlot(inTree):
    """Draw a whole nested-dict decision tree and show the figure.

    Args:
        inTree: Tree of the form ``{feature_label: {feature_value: child}}``.

    Sets up a tick-free, frameless axes on figure 1, initialises the layout
    state that ``plotTree`` reads from its own function attributes, draws
    the tree from the top centre, and then calls ``plt.show()``.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide both axes' ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf-width left of the axes; plotTree advances xOff by one
    # leaf-width before drawing each leaf, so the first leaf is at 0.5/totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
# Call the function to draw the tree (first sample tree)
createPlot(retrieveTree(0))

plotTree

# Draw the second sample tree
createPlot(retrieveTree(1))

png
otherTree

3.Put our decision tree code to use on some real data

# classification function for an existing decision tree
def classify(inputTree, featLabels, testVec):
    """Classify one sample by walking a nested-dict decision tree.

    Args:
        inputTree: Tree of the form ``{feature_label: {feature_value: child}}``.
        featLabels: List of feature labels; the position of a label gives
            the index of that feature in ``testVec``.
        testVec: Feature values of the sample to classify.

    Returns:
        The leaf value reached by following the sample's feature values.

    Raises:
        ValueError: If the tree's feature label is not in ``featLabels``.
        KeyError: If the sample's value for the feature has no branch.
        IndexError: If ``inputTree`` is empty.
    """
    firstStr = list(inputTree)[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    # Fail with a clear message instead of an unbound result variable when
    # the sample carries a value the tree never saw.
    if featValue not in secondDict:
        raise KeyError(
            "no branch for value {!r} of feature {!r}".format(featValue, firstStr)
        )
    branch = secondDict[featValue]
    if isinstance(branch, dict):
        # A dict is a subtree: keep descending.
        return classify(branch, featLabels, testVec)
    return branch
利用 pickle 来进行序列化: serializing objects allows us to store them for later use.

def storeTree(inputTree, filename):
    """Serialize a decision tree to a file with pickle.

    Args:
        inputTree: The object to store (typically a nested-dict tree).
        filename: Path of the file to write; it is overwritten if it exists.
    """
    import pickle

    # pickle writes bytes, so the file must be opened in binary mode; the
    # with-block closes it even if dump raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load a pickled decision tree from a file.

    Args:
        filename: Path of a file previously written with pickle.

    Returns:
        The unpickled object.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    import pickle

    # pickle reads bytes, so the file must be opened in binary mode; the
    # with-block makes sure the handle is closed.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

4.persisting the decision tree

# methods for persisting the decision tree with pickle
def storeTree(inPutTree, filename):
    """Serialize a decision tree to a file with pickle.

    Args:
        inPutTree: The object to store (typically a nested-dict tree).
        filename: Path of the file to write; it is overwritten if it exists.
    """
    import pickle

    # pickle writes bytes, so the file must be opened in binary mode; the
    # with-block closes it even if dump raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inPutTree, fw)


def grabTree(filename):
    """Load a pickled decision tree from a file.

    Args:
        filename: Path of a file previously written with pickle.

    Returns:
        The unpickled object.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    import pickle

    # pickle reads bytes, so the file must be opened in binary mode.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

Summary

最主要的还是掌握C4.5 和CART 算法的过程,详细见西瓜书,周志华.还有就是剪枝处理,连续值与缺失值的处理.
代码的实现过程,只是一个将理论转换成实际的过程,我觉得代码可以用的可以直接 import.
  相关解决方案