java实现决策树算法
程序员文章站
2022-03-22 17:40:46
...
决策树
package decisiontree; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.FileReader; import java.io.FileWriter; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; public class DecisionTree { public static Map<String, Item> train(String[][] trainData) { Map<String, Item> model = new HashMap<String, Item>(); List<String[]> trainingDataList = new ArrayList<String[]>(); for (int i = 0; i < trainData.length; i++) { trainingDataList.add(trainData[i]); } Set<Integer> handledSet = new HashSet<Integer>(); train(model, Item.ROOT_EKY, handledSet, trainingDataList); return model; } private static void train(Map<String, Item> model, String currentKey, Set<Integer> handledSet, List<String[]> trainingDataList) { double entropyValue = getEntropyValue(trainingDataList); if (Math.abs(entropyValue) < Double.MIN_VALUE) { // all are the same value Item modelItem = new Item(currentKey, trainingDataList.get(0)[trainingDataList.get(0).length - 1]); model.put(modelItem.key, modelItem); } else { // not the only value double minEntropyValue = Double.MAX_VALUE; Map<String, List<String[]>> minEntropySplitDataMap = null; int minEntropyAttrIndex = -1; for (int i = 0; i <= trainingDataList.get(0).length - 2; i++) { if (!handledSet.contains(i)) { Map<String, List<String[]>> splitData = getSplitData(trainingDataList, i); entropyValue = getTotalEntropyValue(splitData); if (entropyValue < minEntropyValue) { minEntropySplitDataMap = splitData; minEntropyAttrIndex = i; minEntropyValue = entropyValue; } } } handledSet.add(minEntropyAttrIndex); if (minEntropySplitDataMap.size() == 1) { // there is only value in result list, skip this attribute; train(model, currentKey, handledSet, trainingDataList); } else { // there are more than one attribute value Item modelItem = new Item(currentKey, null); modelItem.currentIndex = 
minEntropyAttrIndex; model.put(modelItem.key, modelItem); for (String attrKey : minEntropySplitDataMap.keySet()) { String subKey = getKey(currentKey, minEntropyAttrIndex, attrKey); train(model, subKey, handledSet, minEntropySplitDataMap.get(attrKey)); } } handledSet.remove(minEntropyAttrIndex); } } private static String getKey(String parentKey, int attrIndex, String value) { String key = ""; if (parentKey == null || parentKey.trim().length() == 0) { key = String.valueOf(attrIndex) + "-" + value; } else { key = parentKey + "-" + String.valueOf(attrIndex) + "-" + value; } return key; } private static double getTotalEntropyValue(Map<String, List<String[]>> splitData) { double rtn = 0; for (List<String[]> itemList : splitData.values()) { rtn += getEntropyValue(itemList); } return rtn; } private static double getEntropyValue(List<String[]> splitData) { double rtn = 0; Map<String, AtomicInteger> countMap = new HashMap<String, AtomicInteger>(); for (String[] itemData : splitData) { String value = itemData[itemData.length - 1]; if (!countMap.containsKey(value)) { countMap.put(value, new AtomicInteger(0)); } countMap.get(value).getAndIncrement(); } for (AtomicInteger count : countMap.values()) { double probability = 1.0d * count.get() / splitData.size(); rtn -= probability * Math.log(probability) / Math.log(2.0); } return rtn; } private static Map<String, List<String[]>> getSplitData(List<String[]> data, int i) { Map<String, List<String[]>> rtn = new HashMap<String, List<String[]>>(); for (String[] itemData : data) { String value = itemData[i]; List<String[]> itemDataList = rtn.get(value); if (itemDataList == null) { itemDataList = new ArrayList<String[]>(); rtn.put(value, itemDataList); } itemDataList.add(itemData); } return rtn; } public static void saveModel(String fileName, Map<String, Item> model) { try (BufferedWriter writer = new BufferedWriter(new FileWriter(fileName));) { for (Item item : model.values()) { writer.write(item.toStr()); writer.write("\n"); } } catch 
(Exception e) { System.out.println("save Model error"); } } public static Map<String, Item> loadModel(String fileName) { Map<String, Item> model = new HashMap<String, Item>(); try (BufferedReader reader = new BufferedReader(new FileReader(fileName));) { String lineStr = reader.readLine(); while (lineStr != null) { if (lineStr.trim().length() > 0) { String[] itemStr = (lineStr + ",^^").split(","); if (itemStr.length == 4) { Item itme = new Item(itemStr[1], itemStr[2]); if (itemStr[0] != null && itemStr[0].trim().length() > 0) { itme.currentIndex = Integer.valueOf(itemStr[0]); } model.put(itme.key, itme); } else { System.out.println("Error model line:" + lineStr); } } lineStr = reader.readLine(); } } catch (Exception e) { System.out.println("load model error"); } return model; } public static String getValue(Map<String, Item> model, String[] fieldValues) { String rtn = null; if (model != null && model.size() > 0) { rtn = getValueFromModel(model, Item.ROOT_EKY, fieldValues); } return rtn; } private static String getValueFromModel(Map<String, Item> model, String key, String[] fieldValues) { String rtn = null; Item item = model.get(key); if (item != null) { if (item.value != null && item.value.trim().length() > 0) { return item.value; } else { String fieldValue = fieldValues[item.currentIndex]; String fieldIndex = String.valueOf(item.currentIndex); String currentKey = null; if (key != null && key.trim().length() > 0) { currentKey = key + Item.KEY_SEPARATOR + fieldIndex + Item.KEY_SEPARATOR + fieldValue; } else { currentKey = fieldIndex + Item.KEY_SEPARATOR + fieldValue; } rtn = getValueFromModel(model, currentKey, fieldValues); } } return rtn; } }
item
/**
 * One node of a decision-tree model, stored in a flat map keyed by {@link #key}.
 *
 * <p>A split node has {@link #currentIndex} set to the attribute column it splits on and no
 * {@link #value}; a leaf keeps {@code currentIndex == -1} and holds the predicted label in
 * {@link #value}.
 */
public class Item {
    /** Joins the "columnIndex" and "value" segments that make up a node key. */
    public static final String KEY_SEPARATOR = "-";
    /** Joins the three fields produced by {@link #toStr()}. */
    public static final String STR_SEPARATOR = ",";
    /** Key of the root node: the empty string (name kept as-is; likely a typo of ROOT_KEY). */
    public static final String ROOT_EKY = "";

    /** Not read or written by this class or by the visible DecisionTree code. */
    public String parentKey = null;
    /** Node key: "columnIndex-value" segments joined by KEY_SEPARATOR; empty for the root. */
    public String key = "";
    /** Predicted label for a leaf; null or empty for a split node. */
    public String value = null;
    /** Attribute column this node splits on, or -1 when unset (leaf). */
    public int currentIndex = -1;

    /**
     * Creates a node.
     *
     * @param key   node key (see {@link #key})
     * @param value leaf label, or null for a split node
     */
    public Item(String key, String value) {
        super();
        this.value = value;
        this.key = key;
    }

    /**
     * Serializes this node as {@code currentIndex,key,value}, writing an empty field for a
     * negative index or a null key/value (e.g. {@code "0,,"} or {@code ",0-A,A"}).
     */
    public String toStr() {
        String indexField = (currentIndex < 0) ? "" : Integer.toString(currentIndex);
        String keyField = (key == null) ? "" : key;
        String valueField = (value == null) ? "" : value;
        return indexField + STR_SEPARATOR + keyField + STR_SEPARATOR + valueField;
    }
}
简单的训练数据(最后一列为目标属性)
A,C,A
A,D,A
A,A,A
B,C,B
C,C,C
序列化的模型（每行格式：分裂属性列序号,Key,目标属性值。Key 形如“列序号1-属性值1-列序号2-属性值2”，根节点的 Key 为空；内部节点的目标属性值为空，叶子节点的分裂属性列序号为空）
0,,
,0-A,A
,0-B,B
,0-C,C
上一篇: 如何实现小程序的登录与授权
下一篇: java实现朴素贝叶斯算法