欢迎您访问程序员文章站！本站旨在为大家分享程序员计算机编程知识！
您现在的位置是: 首页

java实现决策树算法

程序员文章站 2022-03-22 17:40:46
...

决策树

package decisiontree;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

public class DecisionTree {

	/** Natural log of 2, used to convert natural-log entropy terms into bits. */
	private static final double LOG_2 = Math.log(2.0);

	/**
	 * Builds an ID3-style decision tree from categorical training rows.
	 *
	 * <p>Each row holds the attribute values followed by the target label in its last column.
	 * The tree is returned flattened into a map keyed by the path from the root. The root uses
	 * {@link Item#ROOT_EKY}, and each child key appends {@code attrIndex-attrValue} (joined with
	 * {@link Item#KEY_SEPARATOR}). Internal nodes store the attribute column they split on in
	 * {@code currentIndex}; leaves store the predicted label in {@code value}.
	 *
	 * <p>NOTE(review): attribute values containing {@link Item#KEY_SEPARATOR} can make two
	 * different paths produce the same key; values must not contain it.
	 *
	 * @param trainData training rows; all rows are assumed to have the same length as the first
	 *     one. May be {@code null} or empty.
	 * @return the flattened model; empty when there is no training data
	 */
	public static Map<String, Item> train(String[][] trainData) {
		Map<String, Item> model = new HashMap<String, Item>();
		if (trainData == null || trainData.length == 0) {
			// Nothing to learn from; an empty model makes getValue return null.
			return model;
		}

		List<String[]> trainingDataList = new ArrayList<String[]>(trainData.length);
		for (String[] row : trainData) {
			trainingDataList.add(row);
		}

		train(model, Item.ROOT_EKY, new HashSet<Integer>(), trainingDataList);
		return model;
	}

	/**
	 * Recursively grows the subtree whose root has key {@code currentKey}.
	 *
	 * @param model            map that created nodes are written into
	 * @param currentKey       model key of the node being built
	 * @param handledSet       attribute columns already used on the path from the root; they are
	 *                         not split on again. Restored to its entry state before returning.
	 * @param trainingDataList rows that reach this node
	 */
	private static void train(Map<String, Item> model, String currentKey, Set<Integer> handledSet,
			List<String[]> trainingDataList) {
		if (trainingDataList.isEmpty()) {
			return;
		}

		if (countLabels(trainingDataList).size() == 1) {
			// All rows share one label: pure leaf.
			putLeaf(model, currentKey, trainingDataList);
			return;
		}

		// Choose the unused attribute with the lowest weighted entropy after the split, i.e. the
		// highest information gain. An attribute with a single value on these rows cannot
		// separate them, so it is skipped.
		int attributeCount = trainingDataList.get(0).length - 1;
		double minEntropyValue = Double.MAX_VALUE;
		Map<String, List<String[]>> minEntropySplitDataMap = null;
		int minEntropyAttrIndex = -1;
		for (int i = 0; i < attributeCount; i++) {
			if (handledSet.contains(i)) {
				continue;
			}
			Map<String, List<String[]>> splitData = getSplitData(trainingDataList, i);
			if (splitData.size() < 2) {
				continue;
			}
			double entropyValue = getTotalEntropyValue(splitData);
			if (entropyValue < minEntropyValue) {
				minEntropySplitDataMap = splitData;
				minEntropyAttrIndex = i;
				minEntropyValue = entropyValue;
			}
		}

		if (minEntropySplitDataMap == null) {
			// Labels still disagree, but no remaining attribute distinguishes the rows
			// (conflicting or duplicate training rows): predict the majority label.
			putLeaf(model, currentKey, trainingDataList);
			return;
		}

		Item modelItem = new Item(currentKey, null);
		modelItem.currentIndex = minEntropyAttrIndex;
		model.put(modelItem.key, modelItem);

		handledSet.add(minEntropyAttrIndex);
		for (Map.Entry<String, List<String[]>> branch : minEntropySplitDataMap.entrySet()) {
			String subKey = getKey(currentKey, minEntropyAttrIndex, branch.getKey());
			train(model, subKey, handledSet, branch.getValue());
		}
		handledSet.remove(minEntropyAttrIndex);
	}

	/** Stores a leaf under {@code key} that predicts the most frequent label of {@code rows}. */
	private static void putLeaf(Map<String, Item> model, String key, List<String[]> rows) {
		Item leaf = new Item(key, getMajorityLabel(rows));
		model.put(leaf.key, leaf);
	}

	/**
	 * Returns the most frequent label (last column) among {@code rows}.
	 *
	 * <p>On a tie, the label that first reached the winning count while scanning the rows is
	 * returned, so the result is deterministic. Returns {@code null} for an empty list.
	 */
	private static String getMajorityLabel(List<String[]> rows) {
		Map<String, Integer> counts = new HashMap<String, Integer>();
		String best = null;
		int bestCount = 0;
		for (String[] row : rows) {
			String label = row[row.length - 1];
			Integer previous = counts.get(label);
			int count = previous == null ? 1 : previous + 1;
			counts.put(label, count);
			if (count > bestCount) {
				best = label;
				bestCount = count;
			}
		}
		return best;
	}

	/**
	 * Builds the model key of a child node: {@code parentKey-attrIndex-value}. The parent part is
	 * omitted when {@code parentKey} is null or blank (i.e. the root).
	 */
	private static String getKey(String parentKey, int attrIndex, String value) {
		String childPart = String.valueOf(attrIndex) + Item.KEY_SEPARATOR + value;
		if (parentKey == null || parentKey.trim().isEmpty()) {
			return childPart;
		}
		return parentKey + Item.KEY_SEPARATOR + childPart;
	}

	/**
	 * Returns the conditional entropy of a split: the entropy of each subset, weighted by that
	 * subset's share of all rows in the split. Returns 0 for an empty split.
	 */
	private static double getTotalEntropyValue(Map<String, List<String[]>> splitData) {
		int total = 0;
		for (List<String[]> subset : splitData.values()) {
			total += subset.size();
		}
		if (total == 0) {
			return 0;
		}

		double rtn = 0;
		for (List<String[]> subset : splitData.values()) {
			rtn += (double) subset.size() / total * getEntropyValue(subset);
		}
		return rtn;
	}

	/** Returns the Shannon entropy, in bits, of the labels (last column) of {@code rows}. */
	private static double getEntropyValue(List<String[]> rows) {
		if (rows.isEmpty()) {
			return 0;
		}

		double rtn = 0;
		for (AtomicInteger count : countLabels(rows).values()) {
			double probability = (double) count.get() / rows.size();
			rtn -= probability * Math.log(probability) / LOG_2;
		}
		return rtn;
	}

	/** Counts how many rows carry each label (last column). */
	private static Map<String, AtomicInteger> countLabels(List<String[]> rows) {
		Map<String, AtomicInteger> countMap = new HashMap<String, AtomicInteger>();
		for (String[] row : rows) {
			String label = row[row.length - 1];
			AtomicInteger count = countMap.get(label);
			if (count == null) {
				count = new AtomicInteger();
				countMap.put(label, count);
			}
			count.incrementAndGet();
		}
		return countMap;
	}

	/** Groups {@code data} by the value found in column {@code i}. */
	private static Map<String, List<String[]>> getSplitData(List<String[]> data, int i) {
		Map<String, List<String[]>> rtn = new HashMap<String, List<String[]>>();

		for (String[] itemData : data) {
			String value = itemData[i];
			List<String[]> itemDataList = rtn.get(value);
			if (itemDataList == null) {
				itemDataList = new ArrayList<String[]>();
				rtn.put(value, itemDataList);
			}
			itemDataList.add(itemData);
		}

		return rtn;
	}

	/**
	 * Writes one {@link Item#toStr()} line per model node to {@code fileName}.
	 *
	 * <p>This is best effort: I/O failures are reported on standard output and are not rethrown.
	 * NOTE(review): keys and values containing {@link Item#STR_SEPARATOR} cannot be read back.
	 *
	 * @param fileName target file; created or truncated
	 * @param model    model to save
	 */
	public static void saveModel(String fileName, Map<String, Item> model) {
		if (model == null) {
			System.out.println("save Model error: model is null");
			return;
		}

		try (BufferedWriter writer = new BufferedWriter(new FileWriter(fileName))) {
			for (Item item : model.values()) {
				writer.write(item.toStr());
				writer.write("\n");
			}
		} catch (IOException e) {
			System.out.println("save Model error: " + e);
		}
	}

	/**
	 * Reads a model previously written by {@link #saveModel(String, Map)}.
	 *
	 * <p>Blank lines are ignored. Malformed lines are reported and skipped, and the remaining
	 * lines are still loaded. If the file cannot be read, the nodes loaded so far are returned.
	 *
	 * @param fileName model file to read
	 * @return the loaded model (possibly empty); never {@code null}
	 */
	public static Map<String, Item> loadModel(String fileName) {
		Map<String, Item> model = new HashMap<String, Item>();
		try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {
			String lineStr;
			while ((lineStr = reader.readLine()) != null) {
				if (lineStr.trim().isEmpty()) {
					continue;
				}
				Item item = parseModelLine(lineStr);
				if (item == null) {
					System.out.println("Error model line:" + lineStr);
				} else {
					model.put(item.key, item);
				}
			}
		} catch (IOException e) {
			System.out.println("load model error: " + e);
		}

		return model;
	}

	/**
	 * Parses one serialized line of the form {@code index,key,value} (index and value may be
	 * empty). Returns {@code null} if the line does not have exactly three fields or the index is
	 * not an integer.
	 */
	private static Item parseModelLine(String line) {
		// A limit of -1 keeps trailing empty fields such as in "0,,". The separator is
		// interpreted as a regex, which is safe for ",".
		String[] fields = line.split(Item.STR_SEPARATOR, -1);
		if (fields.length != 3) {
			return null;
		}

		Item item = new Item(fields[1], fields[2]);
		String indexText = fields[0].trim();
		if (!indexText.isEmpty()) {
			try {
				item.currentIndex = Integer.parseInt(indexText);
			} catch (NumberFormatException e) {
				return null;
			}
		}
		return item;
	}

	/**
	 * Predicts the label for one sample by walking the model from the root.
	 *
	 * @param model       model produced by {@link #train(String[][])} or {@link #loadModel(String)}
	 * @param fieldValues attribute values of the sample, indexed by column
	 * @return the predicted label, or {@code null} if the model is null/empty, the sample is null,
	 *     or the sample reaches an attribute value that was never seen during training
	 */
	public static String getValue(Map<String, Item> model, String[] fieldValues) {
		if (model == null || model.isEmpty() || fieldValues == null) {
			return null;
		}
		return getValueFromModel(model, Item.ROOT_EKY, fieldValues);
	}

	/**
	 * Walks down from the node stored under {@code key} until a leaf (non-blank value) is found.
	 * Returns {@code null} when a node is missing or its split column is outside
	 * {@code fieldValues}.
	 */
	private static String getValueFromModel(Map<String, Item> model, String key, String[] fieldValues) {
		String currentKey = key;
		Item item = model.get(currentKey);
		// Each step appends a segment to the key, so the walk ends after at most (tree depth) steps.
		while (item != null) {
			if (item.value != null && item.value.trim().length() > 0) {
				return item.value;
			}

			int index = item.currentIndex;
			if (index < 0 || index >= fieldValues.length) {
				return null;
			}
			currentKey = getKey(currentKey, index, fieldValues[index]);
			item = model.get(currentKey);
		}

		return null;
	}

}

 

    Item 类（Item.java；需要同样声明 package decisiontree;，与 DecisionTree 放在同一个包中，否则 DecisionTree 无法引用它）

/**
 * A single node of a flattened decision-tree model.
 *
 * <p>{@code key} identifies the node. {@code value} holds a label, or is null/empty when the
 * node is not a leaf. {@code currentIndex} is a column index, or -1 when unset.
 * {@link #toStr()} renders the three fields as one separator-joined line.
 */
public class Item {
	/** Separator placed between path segments inside a key. */
	public static final String KEY_SEPARATOR = "-";
	/** Separator placed between the fields written by {@link #toStr()}. */
	public static final String STR_SEPARATOR = ",";
	/** Key of the root node (presumably a typo for ROOT_KEY; kept because callers use this name). */
	public static final String ROOT_EKY = "";

	/** Not read or written anywhere in this class. */
	public String parentKey = null;

	/** Identifier of this node. */
	public String key = "";

	/** Label stored on this node; null when none was given. */
	public String value = null;

	/** Column index stored on this node; -1 means "not set". */
	public int currentIndex = -1;

	/**
	 * Creates a node with the given key and value; {@code currentIndex} starts at -1.
	 *
	 * @param key   node identifier (stored as-is, may be null)
	 * @param value node label (stored as-is, may be null)
	 */
	public Item(String key, String value) {
		this.value = value;
		this.key = key;
	}

	/**
	 * Renders this node as {@code index,key,value}. A negative index and null key/value are
	 * written as empty fields.
	 *
	 * @return the serialized line, without a line terminator
	 */
	public String toStr() {
		String indexField = currentIndex < 0 ? "" : Integer.toString(currentIndex);
		String keyField = key == null ? "" : key;
		String valueField = value == null ? "" : value;
		return indexField + STR_SEPARATOR + keyField + STR_SEPARATOR + valueField;
	}

}

 

简单的训练数据(最后一列为目标属性)

A,C,A

A,D,A

A,A,A

B,C,B

C,C,C

 

序列化的模型（每行格式：下一个划分属性的列序号,Key,目标属性。Key 是从根节点出发的路径，形如 列序号1-属性值1-列序号2-属性值2；根节点的 Key 为空，叶子节点的列序号为空、目标属性非空）

0,,

,0-A,A

,0-B,B

,0-C,C

 

相关标签: 机器学习