上篇文章虽然结合 Hadoop 用 MapReduce 计算了属性的增益率，但整个程序其实并没有真正做到并行化处理。后来又看了一些网上的资料，自己重新思考了一下，又重新实现了一遍决策树，大体思路如下：
1、将一个大数据集文件拆分成 N 个小数据集文件，对数据做好预处理（填充缺失的属性值）后，上传到 HDFS；
2、针对 HDFS 上的每个小数据集文件，各启动一个 Job 并行计算其最佳分割属性与分割点（以基尼指数为衡量标准）；
3、汇总 N 个小数据集文件各自计算出的候选划分，从中选出全局最佳划分（实现中取基尼指数最小的划分）；
4、各小数据集所在的节点按照最终选出的最佳划分，切分自己节点上的数据并上传到 HDFS，然后跳回第二步递归处理，直到选出的划分结果为类别（叶子节点）为止。
?
下面是具体的实现代码：
public class DecisionTreeSprintBJob extends AbstractJob { private Map<String, Map<Object, Integer>> attributeValueStatistics = null; private Map<String, Set<String>> attributeNameToValues = null; private Set<String> allAttributes = null; /** 数据拆分,大数据文件拆分为小数据文件,便于分配到各个节点开启Job*/ private List<String> split(String input, String splitNum) { String output = HDFSUtils.HDFS_TEMP_INPUT_URL + IdentityUtils.generateUUID(); String[] args = new String[]{input, output, splitNum}; DataFileSplitMR.main(args); List<String> inputs = new ArrayList<String>(); Path outputPath = new Path(output); try { FileSystem fs = outputPath.getFileSystem(conf); Path[] paths = HDFSUtils.getPathFiles(fs, outputPath); for(Path path : paths) { System.out.println("split input path: " + path); InputStream in = fs.open(path); BufferedReader reader = new BufferedReader(new InputStreamReader(in)); String line = reader.readLine(); while (null != line && !"".equals(line)) { inputs.add(line); line = reader.readLine(); } IOUtils.closeQuietly(in); IOUtils.closeQuietly(reader); } } catch (IOException e) { e.printStackTrace(); } System.out.println("inputs size: " + inputs.size()); return inputs; } /** 初始化工作,主要是获取特征属性集以及属性值的统计,主要是为了填充默认值*/ private void initialize(String input) { System.out.println("initialize start."); allAttributes = new HashSet<String>(); attributeNameToValues = new HashMap<String, Set<String>>(); attributeValueStatistics = new HashMap<String, Map<Object, Integer>>(); String output = HDFSUtils.HDFS_TEMP_INPUT_URL + IdentityUtils.generateUUID(); String[] args = new String[]{input, output}; AttributeStatisticsMR.main(args); Path outputPath = new Path(output); SequenceFile.Reader reader = null; try { FileSystem fs = outputPath.getFileSystem(conf); Path[] paths = HDFSUtils.getPathFiles(fs, outputPath); for(Path path : paths) { reader = new SequenceFile.Reader(fs, path, conf); AttributeKVWritable key = (AttributeKVWritable) ReflectionUtils.newInstance(reader.getKeyClass(), conf); IntWritable value = 
new IntWritable(); while (reader.next(key, value)) { String attributeName = key.getAttributeName(); allAttributes.add(attributeName); Set<String> values = attributeNameToValues.get(attributeName); if (null == values) { values = new HashSet<String>(); attributeNameToValues.put(attributeName, values); } String attributeValue = key.getAttributeValue(); values.add(attributeValue); Map<Object, Integer> valueStatistics = attributeValueStatistics.get(attributeName); if (null == valueStatistics) { valueStatistics = new HashMap<Object, Integer>(); attributeValueStatistics.put(attributeName, valueStatistics); } valueStatistics.put(attributeValue, value.get()); value = new IntWritable(); } } } catch (IOException e) { e.printStackTrace(); } finally { IOUtils.closeQuietly(reader); } System.out.println("initialize end."); } /** 预处理,主要是将分割后的小文件填充好默认值后在上传到HDFS上面*/ private List<String> preHandle(List<String> inputs) throws IOException { List<String> fillInputs = new ArrayList<String>(); for (String input : inputs) { Data data =null; try { Path inputPath = new Path(input); FileSystem fs = inputPath.getFileSystem(conf); FSDataInputStream fsInputStream = fs.open(inputPath); data = DataLoader.load(fsInputStream, true); } catch (IOException e) { e.printStackTrace(); } DataHandler.computeFill(data.getInstances(), allAttributes.toArray(new String[0]), attributeValueStatistics, 1.0); OutputStream out = null; BufferedWriter writer = null; String outputDir = HDFSUtils.HDFS_TEMP_INPUT_URL + IdentityUtils.generateUUID(); fillInputs.add(outputDir); String output = outputDir + File.separator + IdentityUtils.generateUUID(); try { Path outputPath = new Path(output); FileSystem fs = outputPath.getFileSystem(conf); out = fs.create(outputPath); writer = new BufferedWriter(new OutputStreamWriter(out)); StringBuilder sb = null; for (Instance instance : data.getInstances()) { sb = new StringBuilder(); sb.append(instance.getId()).append("\t"); sb.append(instance.getCategory()).append("\t"); Map<String, 
Object> attrs = instance.getAttributes(); for (Map.Entry<String, Object> entry : attrs.entrySet()) { sb.append(entry.getKey()).append(":"); sb.append(entry.getValue()).append("\t"); } writer.write(sb.toString()); writer.newLine(); } writer.flush(); } catch (Exception e) { e.printStackTrace(); } finally { IOUtils.closeQuietly(out); IOUtils.closeQuietly(writer); } } return fillInputs; } /** 创建JOB*/ private Job createJob(String jobName, String input, String output) { Configuration conf = new Configuration(); conf.set("mapred.job.queue.name", "q_hudong"); Job job = null; try { job = new Job(conf, jobName); FileInputFormat.addInputPath(job, new Path(input)); FileOutputFormat.setOutputPath(job, new Path(output)); job.setJarByClass(DecisionTreeSprintBJob.class); job.setMapperClass(CalculateGiniMapper.class); job.setMapOutputKeyClass(Text.class); job.setMapOutputValueClass(AttributeWritable.class); job.setReducerClass(CalculateGiniReducer.class); job.setOutputKeyClass(Text.class); job.setOutputValueClass(AttributeGiniWritable.class); job.setInputFormatClass(TextInputFormat.class); job.setOutputFormatClass(SequenceFileOutputFormat.class); } catch (IOException e) { e.printStackTrace(); } return job; } /** 根据HDFS上的输出路径选择最佳属性*/ private AttributeGiniWritable chooseBestAttribute(String... 
outputs) { AttributeGiniWritable minSplitAttribute = null; double minSplitPointGini = 1.0; try { for (String output : outputs) { System.out.println("choose output: " + output); Path outputPath = new Path(output); FileSystem fs = outputPath.getFileSystem(conf); Path[] paths = HDFSUtils.getPathFiles(fs, outputPath); ShowUtils.print(paths); SequenceFile.Reader reader = null; for (Path path : paths) { reader = new SequenceFile.Reader(fs, path, conf); Text key = (Text) ReflectionUtils.newInstance( reader.getKeyClass(), conf); AttributeGiniWritable value = new AttributeGiniWritable(); while (reader.next(key, value)) { double gini = value.getGini(); System.out.println(value.getAttribute() + " : " + gini); if (gini <= minSplitPointGini) { minSplitPointGini = gini; minSplitAttribute = value; } value = new AttributeGiniWritable(); } IOUtils.closeQuietly(reader); } System.out.println("delete hdfs file start: " + outputPath.toString()); HDFSUtils.delete(conf, outputPath); System.out.println("delete hdfs file end: " + outputPath.toString()); } } catch (IOException e) { e.printStackTrace(); } if (null == minSplitAttribute) { System.out.println("minSplitAttribute is null"); } return minSplitAttribute; } private Data obtainData(String input) { Data data = null; Path inputPath = new Path(input); try { FileSystem fs = inputPath.getFileSystem(conf); Path[] hdfsPaths = HDFSUtils.getPathFiles(fs, inputPath); FSDataInputStream fsInputStream = fs.open(hdfsPaths[0]); data = DataLoader.load(fsInputStream, true); } catch (IOException e) { e.printStackTrace(); } return data; } /** 构建决策树*/ private Object build(List<String> inputs) throws IOException { List<String> outputs = new ArrayList<String>(); JobControl jobControl = new JobControl("CalculateGini"); for (String input : inputs) { System.out.println("split path: " + input); String output = HDFSUtils.HDFS_TEMP_OUTPUT_URL + IdentityUtils.generateUUID(); outputs.add(output); Configuration conf = new Configuration(); ControlledJob 
controlledJob = new ControlledJob(conf); controlledJob.setJob(createJob(input, input, output)); jobControl.addJob(controlledJob); } Thread jcThread = new Thread(jobControl); jcThread.start(); while(true){ if(jobControl.allFinished()){ // System.out.println(jobControl.getSuccessfulJobList()); jobControl.stop(); AttributeGiniWritable bestAttr = chooseBestAttribute( outputs.toArray(new String[0])); String attribute = bestAttr.getAttribute(); System.out.println("best attribute: " + attribute); System.out.println("isCategory: " + bestAttr.isCategory()); if (bestAttr.isCategory()) { return attribute; } TreeNode treeNode = new TreeNode(attribute); Map<String, List<String>> splitToInputs = new HashMap<String, List<String>>(); for (String input : inputs) { Data data = obtainData(input); String splitPoint = bestAttr.getSplitPoint();// Map<String, Set<String>> attrName2Values = // DataHandler.attributeValueStatistics(data.getInstances()); Set<String> attributeValues = attributeNameToValues.get(attribute); System.out.println("attributeValues:"); ShowUtils.print(attributeValues); if (attributeNameToValues.size() == 0 || null == attributeValues) { continue; } attributeValues.remove(splitPoint); StringBuilder sb = new StringBuilder(); for (String attributeValue : attributeValues) { sb.append(attributeValue).append(","); } if (sb.length() > 0) sb.deleteCharAt(sb.length() - 1); String[] names = new String[]{splitPoint, sb.toString()}; DataSplit dataSplit = DataHandler.split(new Data( data.getInstances(), attribute, names)); for (DataSplitItem item : dataSplit.getItems()) { if (item.getInstances().size() == 0) continue; String path = item.getPath(); String name = path.substring(path.lastIndexOf(File.separator) + 1); String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name; HDFSUtils.copyFromLocalFile(conf, path, hdfsPath); String split = item.getSplitPoint(); List<String> nextInputs = splitToInputs.get(split); if (null == nextInputs) { nextInputs = new ArrayList<String>(); 
splitToInputs.put(split, nextInputs); } nextInputs.add(hdfsPath); } } for (Map.Entry<String, List<String>> entry : splitToInputs.entrySet()) { treeNode.setChild(entry.getKey(), build(entry.getValue())); } return treeNode; } if(jobControl.getFailedJobList().size() > 0){ // System.out.println(jobControl.getFailedJobList()); jobControl.stop(); } } } /** 分类样本集*/ private void classify(TreeNode treeNode, String testSet, String output) { OutputStream out = null; BufferedWriter writer = null; try { Path testSetPath = new Path(testSet); FileSystem testFS = testSetPath.getFileSystem(conf); Path[] testHdfsPaths = HDFSUtils.getPathFiles(testFS, testSetPath); FSDataInputStream fsInputStream = testFS.open(testHdfsPaths[0]); Data testData = DataLoader.load(fsInputStream, true); DataHandler.computeFill(testData.getInstances(), allAttributes.toArray(new String[0]), attributeValueStatistics, 1.0); Object[] results = (Object[]) treeNode.classifySprint(testData); ShowUtils.print(results); DataError dataError = new DataError(testData.getCategories(), results); dataError.report(); String path = FileUtils.obtainRandomTxtPath(); out = new FileOutputStream(new File(path)); writer = new BufferedWriter(new OutputStreamWriter(out)); StringBuilder sb = null; for (int i = 0, len = results.length; i < len; i++) { sb = new StringBuilder(); sb.append(i+1).append("\t").append(results[i]); writer.write(sb.toString()); writer.newLine(); } writer.flush(); Path outputPath = new Path(output); FileSystem fs = outputPath.getFileSystem(conf); if (!fs.exists(outputPath)) { fs.mkdirs(outputPath); } String name = path.substring(path.lastIndexOf(File.separator) + 1); HDFSUtils.copyFromLocalFile(conf, path, output + File.separator + name); } catch (IOException e) { e.printStackTrace(); } finally { IOUtils.closeQuietly(out); IOUtils.closeQuietly(writer); } } public void run(String[] args) { try { if (null == conf) conf = new Configuration(); String[] inputArgs = new GenericOptionsParser( conf, 
args).getRemainingArgs(); if (inputArgs.length != 4) { System.out.println("error, please input three path."); System.out.println("1. trainset path."); System.out.println("2. testset path."); System.out.println("3. result output path."); System.out.println("4. data split number."); System.exit(2); } List<String> splitInputs = split(inputArgs[0], inputArgs[3]); initialize(inputArgs[0]); List<String> inputs = preHandle(splitInputs); TreeNode treeNode = (TreeNode) build(inputs); TreeNodeHelper.print(treeNode, 0, null); classify(treeNode, inputArgs[1], inputArgs[2]); } catch (Exception e) { e.printStackTrace(); } } public static void main(String[] args) { DecisionTreeSprintBJob job = new DecisionTreeSprintBJob(); long startTime = System.currentTimeMillis(); job.run(args); long endTime = System.currentTimeMillis(); System.out.println("spend time: " + (endTime - startTime)); }}
?
?
?