当前位置: 代码迷 >> 综合 >> [Kaggle实战] Titanic 逃生预测 (4) - 决策树建模
  详细解决方案

[Kaggle实战] Titanic 逃生预测 (4) - 决策树建模

热度:58   发布时间:2023-12-08 20:51:12.0

之前的文章已经解决了数据预处理的问题。从这里开始,就要开始创建决策树了。

首先可以使用之前用Java实现的ID3算法进行修改。 之前的算法是基于Weka自带的数据进行的,跟这里的格式不太兼容。基本上需要把String改成Double就好了~

 

现在先尝试手动的创建模型,保证待会我们写出来的代码确实是正确的。

关于决策树模型以及ID3算法,具体的概念以及思路就不在这里重复写了,可以参考《数据挖掘导论》相关章节。

 

之前已经处理好的dataMatrix可以下载附件之中的train-matrix.csv. 然后直接使用Excel完成最简单的统计功能。

比如,第一步我需要统计Suvived之中1跟0的个数:



 

即: Survived=0 有549条记录, Survived=1 有342条记录。

可以使用如下代码计算熵:[代码来源:http://commons.apache.org/proper/commons-math/jacoco/org.apache.commons.math3.stat.inference/GTest.java.html 根据里面的entropy进行修改]

public static double entropy(final int[] k) {double h = 0d;double sum_k = 0d;for (int i = 0; i < k.length; i++) {sum_k += (double) k[i];}for (int i = 0; i < k.length; i++) {if (k[i] != 0) {final double p_i = (double) k[i] / sum_k;h += p_i * FastMath.log(p_i);}}return -h;
}

对于Survived, Entropy=0.9607

 

接下来就应该逐个的计算各个属性对应的熵以及对应的信息增益(Info Gain) 了

以PClass为例:



 
0.9607-(80+136)/891. * 0.9509 - (97+87)/891. * 0.9978 - (372+119)/891.*0.7989因此,Pclass属性的信息增益为

=0.0229

 

一次类推,计算Sex,Age,SibSp,Embarked对应的信息增益,结果如下:

Pclass:0.0838310452960116

Sex:0.2176601066606142

Age:0.010620040421108423

SibSp:0.022557964533659103

Embarked:0.024047090707960517

 

最终,选择Sex作为根节点。

我们看看Sex的数据情况吧,



 

 不得不说一句:女性的存活几率要比男性大得多啊!

接下来,计算第二层。 我们先计算Sex=1(male) 的情况

此时的Entropy=entropy(468,109)=0.6992 Sum = 468+109=577

详细情况:



因此,Pclass对应的InfoGain = 

0.6992 - 0.9567 * (45 + 77) / 577.0 - 0.628 * (91+17) / 577.0 - 0.5722 * (300 + 47) / 577.0

= 0.0352

后面的就不手动进行了~~~

[不得不再吐槽一下,Pclass=1的时候,生存的几率真的是非常非常大啊!不知道是不是当时的有钱人离救生艇比较近?]

具体的ID3分类器,可以参看我写的代码:https://gitcafe.com/rangerwolf/Kaggle-Titanic/blob/master/src/main/java/classifier/ID3Classifier.java。

 

运行test.MyID3.java即可得到结果。

将整个树json的格式输出出来:

{"attribute": "Sex","options": {"2.0": {"attribute": "Pclass","options": {"3.0": {"attribute": "SibSp","options": {"0.0": {"attribute": "Age","options": {},"subLeafs": {}},"1.0": {"attribute": "Age","options": {},"subLeafs": {"3.0": {"count": 5,"outputValue": 0.0,"option": 3.0}}}},"subLeafs": {}},"2.0": {"attribute": "Age","options": {"2.0": {"attribute": "SibSp","options": {},"subLeafs": {"2.0": {"count": 3,"outputValue": 1.0,"option": 2.0}}}},"subLeafs": {"1.0": {"count": 10,"outputValue": 1.0,"option": 1.0}}},"1.0": {"attribute": "Age","options": {"3.0": {"attribute": "SibSp","options": {},"subLeafs": {"1.0": {"count": 12,"outputValue": 1.0,"option": 1.0}}},"2.0": {"attribute": "SibSp","options": {},"subLeafs": {"0.0": {"count": 34,"outputValue": 1.0,"option": 0.0},"2.0": {"count": 4,"outputValue": 1.0,"option": 2.0}}}},"subLeafs": {}}},"subLeafs": {}},"1.0": {"attribute": "Pclass","options": {"3.0": {"attribute": "Age","options": {"3.0": {"attribute": "SibSp","options": {},"subLeafs": {"1.0": {"count": 2,"outputValue": 0.0,"option": 1.0}}},"2.0": {"attribute": "SibSp","options": {},"subLeafs": {}},"1.0": {"attribute": "SibSp","options": {},"subLeafs": {"1.0": {"count": 5,"outputValue": 1.0,"option": 1.0}}}},"subLeafs": {}},"2.0": {"attribute": "Age","options": {"2.0": {"attribute": "SibSp","options": {},"subLeafs": {"2.0": {"count": 4,"outputValue": 0.0,"option": 2.0}}}},"subLeafs": {"1.0": {"count": 9,"outputValue": 1.0,"option": 1.0}}},"1.0": {"attribute": "Age","options": {"3.0": {"attribute": "SibSp","options": {},"subLeafs": {}},"2.0": {"attribute": "SibSp","options": {},"subLeafs": {}}},"subLeafs": {"1.0": {"count": 3,"outputValue": 1.0,"option": 1.0}}}},"subLeafs": {}}},"subLeafs": {}
}

 

 用GUI的方式来显示json,部分结果如下:



 可以看到,大致已经有了雏形。而且可以验证的就是,至少我们的根节点是正确的。

下面是老外的成果图:(是基于Python做出来的,不过没太看懂里面的结果,感觉只有一条边有label说明~)



 

 下一篇文章,如果不出意外,将介绍一下Dot Language的应用。

后面的树状图,将会使用Dot Language以及相应的软件来进行展示。

PS:明天又要开始上班了~ 哎,可以用来学习的时间要少得多了...