通常,在网上找到的mahout的naive bayes的例子跟官网的例子,都是针对20 newsgroup. 而且通常是命令行版本。虽然能得出预测、分类结果,但是对于Bayes具体是如何工作,以及如何处理自己的数据会比较茫然。
在努力了差不多一个星期之后,终于有点成果。
这个例子就是使用mahout 0.9 对kddcup 1999 的数据进行分析。
第一步: 下载数据。
地址: http://kdd.ics.uci.edu/databases/kddcup99/
关于数据的一些简单的预处理,我们会在第二步进行。细心的你可能发现,有些数据是2007年上传的!这是因为有一些数据原来的标记有错误,后来进行了更正。
第二步: 将原始文件转换成Hadoop使用的sequence 文件。
我们从官网知道,Bayes在mahout之中只有基于map-reduce的实现。 参考: https://mahout.apache.org/users/basics/algorithms.html 所以我们必须要将csv文件转换成hadoop使用的sequence文件
先贴一下代码:(注意:这里列的代码,仅仅用于说明流程,并没有注意性能方面的考虑。处理过大的文件的时候,需要有针对性的自行进行调整~)
package experiment.kdd99_bayes;import java.io.FileReader;
import java.io.IOException;
import java.util.List;
import java.util.Map;import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;import au.com.bytecode.opencsv.CSVReader;import com.google.common.collect.Lists;
import com.google.common.collect.Maps;public class Kdd99CsvToSeqFile {private String csvPath;private Path seqPath;private SequenceFile.Writer writer;private Configuration conf = new Configuration();private Map<String, Long> word2LongMap = Maps.newHashMap();private List<String> strLabelList = Lists.newArrayList();private FileSystem fs = null;public Kdd99CsvToSeqFile(String csvFilePath, String seqPath) {this.csvPath = csvFilePath;this.seqPath = new Path(seqPath);}public Map<String, Long> getWordMap() {return word2LongMap;}public List<String> getLabelList() {return strLabelList;}/*** Show out the already sequenced file content*/public void dump() {try {fs = FileSystem.get(conf);SequenceFile.Reader reader = new SequenceFile.Reader(fs, this.seqPath, conf);Text key = new Text();VectorWritable value = new VectorWritable();while (reader.next(key, value)) {System.out.println( "reading key:" + key.toString() +" with value " +value.toString());}reader.close();} catch (IOException e) {e.printStackTrace();} finally {try {fs.close();fs = null;} catch (IOException e) {e.printStackTrace();}}}/*** Sequence target csv file.* @param labelIndex* @param hasHeader*/public void parse(int labelIndex, boolean hasHeader) {CSVReader reader = null;try {fs = FileSystem.getLocal(conf);if(fs.exists(this.seqPath))fs.delete(this.seqPath, true);writer = SequenceFile.createWriter(fs, conf, this.seqPath, Text.class, VectorWritable.class);reader = new CSVReader(new FileReader(this.csvPath));String[] header = null;if(hasHeader) header = reader.readNext();String[] line = null;Long l = 0L;while((line = reader.readNext()) != null) {if(labelIndex > line.length) break;l++;List<String> tmpList = Lists.newArrayList(line);String label = tmpList.get(labelIndex);if(!strLabelList.contains(label)) strLabelList.add(label);
// Text key = new Text("/" + label + "/" + l);Text key = new Text("/" + label + "/");tmpList.remove(labelIndex);VectorWritable vectorWritable = new VectorWritable();Vector vector = new RandomAccessSparseVector(tmpList.size(), tmpList.size());//???for(int i = 0; i < tmpList.size(); i++) {String tmpStr = tmpList.get(i);if(StringUtils.isNumeric(tmpStr))vector.set(i, Double.parseDouble(tmpStr));else vector.set(i, parseStrCell(tmpStr)); }vectorWritable.set(vector);writer.append(key, vectorWritable);}} catch (IOException e) {e.printStackTrace();} finally {try {fs.close();fs = null;writer.close();reader.close();} catch (IOException e) {e.printStackTrace();}}}private Long parseStrCell(String str) {Long id = word2LongMap.get(str); if( id == null) {id = (long) (word2LongMap.size() + 1);word2LongMap.put(str, id);} return id;}
}
说明一下这个代码的工作流程:
1. 初始化hadoop,比如Configuration 、 FileSystem。
2. 通过Hadoop的 Sequence.Writer进行sequence文件的写入。其中的key/value 分别是Text 跟VectorWritable类型。
3. 通过CSVReader读入CSV文件,然后逐行遍历。如果是带标题的,则先略过第一行。
4. 对于每一行,将Array转成List方便操作。将label列从list之中删除~
5. 对于sequencefile, key为label + row number, 并且,需要以"/"作为开头,否则在实际运行的时候会提示找不到key!
6. 对于sequencefile的value,使用一个Vector进行数据承载。在此使用的是RandomAccessSparseVector,可以试着使用DenseVector进行测试,看看是否在性能上会有所改善。
在用Bayes试过了好几种数据之后,感觉对于Bayes,最关键的一步其实是在这里,因为选择那些feature、原始数据如何预处理就在这里进行了,剩下的都是模板一样的代码~ 即使命令行也一样。
第三步: 训练Bayes
在这里仅仅先贴出训练部分的代码,整体的代码最后上传
public static void train() throws Throwable {System.out.println("~~~ begin to train ~~~");Configuration conf = new Configuration();FileSystem fs = FileSystem.getLocal(conf);TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();trainNaiveBayes.setConf(conf);String outputDirectory = "/home/hadoop/DataSet/kdd99/bayes/output";String tempDirectory = "/home/hadoop/DataSet/kdd99/bayes/temp";fs.delete(new Path(outputDirectory),true);fs.delete(new Path(tempDirectory),true);// cmd sample: mahout trainnb -i train-vectors -el -li labelindex -o model -ow -ctrainNaiveBayes.run(new String[] { "--input", trainSeqFile, "--output", outputDirectory,"-el", "--labelIndex", "labelIndex","--overwrite", "--tempDir", tempDirectory });// Train the classifiernaiveBayesModel = NaiveBayesModel.materialize(new Path(outputDirectory), conf);System.out.println("features: " + naiveBayesModel.numFeatures());System.out.println("labels: " + naiveBayesModel.numLabels());}
从上面的代码可以看到,熟悉命令行之后,在实际java代码编写的时候,传入进去的也是一些命令行参数。
(可能有其他方法,只是目前我还不了解~)
命令行:
// cmd sample: mahout trainnb -i train-vectors -el -li labelindex -o model -ow -c
Java代码:
trainNaiveBayes.run
最后一步: 使用测试数据进行性能验证。
public static void test() throws IOException {System.out.println("~~~ begin to test ~~~");AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel);CSVReader csv = new CSVReader(new FileReader(testFile));csv.readNext(); // skip headerString[] line = null;double totalSampleCount = 0.;double correctClsCount = 0.;while((line = csv.readNext()) != null) {totalSampleCount ++;Vector vector = new RandomAccessSparseVector(40,40);//???for(int i = 0; i < 40; i++) {if(StringUtils.isNumeric(line[i])) {vector.set(i, Double.parseDouble(line[i]));} else {Long id = strOptionMap.get(line[i]);if(id != null)vector.set(i, id);else {System.out.println(StringUtils.join(line, ","));continue;}}}Vector resultVector = classifier.classifyFull(vector);int classifyResult = resultVector.maxValueIndex();if(StringUtils.equals(line[41], strLabelList.get(classifyResult))) {correctClsCount++;} else {System.out.println("Correct=" + line[41] + "\tClassify=" + strLabelList.get(classifyResult) );}}System.out.println("Correct Ratio:" + (correctClsCount / totalSampleCount)); }
可以看到上面的加粗部分,用的是ComplementaryNaiveBayesClassifier,另外一个贝叶斯分类器就是
StandardNaiveBayesClassifier
最后运算的结果不太好,仅有约63%的正确率~
大家可以参考下面使用Bayes对Tweet进行分类的例子,正确率能有98%这样!当然,需要各位有过功夫网的本领了~
PS: 全部java代码已经在附件之中,感兴趣的还请自取~