ID3分类决策树算法
程序员文章站
2024-02-15 09:40:16
...
简述
对已知D中元组分类所需要的平均信息公式:
1. 平均信息的解释:
对信息所使用的进制的解释:
由此对于公式(1),我们以2为底,表示我们以2进制对信息进行编码,Info(D)表示我们对D中元组进行全部分类时,以2进制为编码表示这些信息所需要的位数。
属性的选择度量
- 根据不同的度量方式我们可以选择不同的度量方法,如使用信息增益作为属性选择度量的ID3算法,使用增益率作为属性选择度量的C4.5算法,使用基尼指数作为属性选择度量的CART算法。这几个算法都是使用不同的属性度量的决策树算法。
ID3算法
- ID3算法使用信息增益作为属性的选择度量
- 信息增益
Info(D)=−∑mi=1Pilog2(Pi) (1)
按属性A进行划分后的新的信息需求为:InfoA(D)=∑vj=1(|Dj|/|D|)∗Info(Dj) (2)
信息增益:Gain(A)=Info(D)−InfoA(D)
总结:
信息增益告诉我们通过A上的划分我们得到了多少信息。
决策树
sex | color | size |
---|---|---|
m | red | s |
m | blue | m |
f | blue | m |
f | yellow | b |
算法思想
先用GetDecisionTreeDFS函数利用训练数据训练出决策树,在对测试数据代进决策树进行测试,从而对他们进行分类。
DataSet类:
package ID3;
import java.util.ArrayList;
public class DataSet {
public ArrayList<String> attrSet;
public ArrayList<ArrayList<String>> dataRows;
protected String targetAttribute;
public DataSet()
{
this.attrSet = null;
this.targetAttribute = null;
this.dataRows = new ArrayList<ArrayList<String>>();
}
public DataSet(ArrayList<String> attrSet, String targetAttribute)
{
this.attrSet = new ArrayList<String>();
this.attrSet = attrSet;
this.targetAttribute = targetAttribute;
this.dataRows = new ArrayList<ArrayList<String>>();
}
public void AddRow(ArrayList<String> row)
{
dataRows.add(row);
}
}
Node类:
package ID3;
import java.util.ArrayList;
public class Node {
public String attrName;//属性名
public ArrayList<String> rules;//属性规则
public ArrayList<Node> children;//子节点集合
public String targetValue;//目标属性值,只有叶子结点才有的
public Node(String attrName, ArrayList<String> rules)//树枝节点
{
this.attrName = attrName;
this.rules = rules;
this.children = new ArrayList<Node>();
this.targetValue = targetValue;
}
public Node(String attrName, String targetValue)//构建叶子节点
{
this.attrName = attrName;
this.rules = rules;
this.children = new ArrayList<Node>();
this.targetValue = targetValue;
}
/**
* 递归遍历打印树结构
* @param root 根节点
* @param spaceCount 缩进的空格数
* @param rules 父节点规则,即树枝
*/
public void PrintTree(Node root, int spaceCount, String rules)
{
if(root == null)
{
return;
}
for(int i = 0; i < spaceCount; i++)
{
System.out.println(" ");
}
if(root.targetValue != null)
{
System.out.println((rules != null ? rules+":":"") + root.targetValue + "(leaf)");
}
else
{
System.out.println((rules != null ? rules+":":"") + root.attrName);
}
if(root.children != null && root.children.size() > 0)
{
for(int i = 0; i< root.children.size(); i++)
{
PrintTree(root.children.get(i), spaceCount+2, root.rules.get(i));
}
}
}
public String Test(String... datas)
{
if(datas.length != ID3.originalDataSet.attrSet.size())
{
System.out.println("数据有误,不完整");
return "";
}
Node node = this;
while(node != null)
{
if(node.targetValue != null)
return node.targetValue;
String attrName = this.attrName;
int columnIndex = ID3.originalDataSet.attrSet.indexOf(attrName);
boolean testRight = false;
for(String rule : node.rules)
{
if(rule.equals(datas[columnIndex]))
{
node = node.children.get(node.rules.indexOf(rule));
testRight = true;
break;
}
}
if(!testRight)
break;
}
return null;
}
}
ID3类:
package ID3;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
public class ID3 {
public static DataSet originalDataSet;//最初的数据集
public ID3(DataSet dataset)
{
originalDataSet = dataset;
}
/**
* 利用数据集训练出来一个决策树
* @param dataSet 数据集
* @return 决策树根节点
*/
public Node GetDecisionTreeDFS(DataSet dataSet)
{
if (dataSet.dataRows == null)
return null;
//剩下的都是同一类的,则返回一个叶子节点
/*for(ArrayList<String> row : dataSet.dataRows)
{
System.out.println(row);
}*/
if(TargetAttrIsAllSame(dataSet))
return new Node(dataSet.targetAttribute, dataSet.dataRows.get(0).get(originalDataSet.attrSet.size()));
//如果属性集为空,则将其归类为数据集中目标属性值最多的哪一个目标属性值
if(dataSet.attrSet.size() <= 0)
{
return new Node(dataSet.targetAttribute, this.GetMajorTargetValue(dataSet));
}
//寻找最大的Gain值属性
String maxGainAttrName = null;
double maxGain = -1;
ArrayList<String> rules = new ArrayList<String>();
for(String attrName : dataSet.attrSet)
{
ArrayList<String> tempRules = this.GetAttrRules(dataSet, attrName);
double gain = GetGain(dataSet, attrName, tempRules);
if(maxGain < gain)
{
maxGain = gain;
maxGainAttrName = attrName;
rules.clear();
rules.addAll(tempRules);
}
}
Node node = new Node(maxGainAttrName, rules);//生成一个新节点
for(int i = 0; i < node.rules.size(); i++)
{
ArrayList<String> newAttrSet = new ArrayList<String>();
for(String attr : dataSet.attrSet)
{
if(attr != maxGainAttrName)
{
newAttrSet.add(attr);
}
}
//获取新的数据集
DataSet newDataSet = FindSpecificDT(dataSet, maxGainAttrName, node.rules.get(i));
newDataSet.attrSet = newAttrSet;
//递归再继续分类
node.children.add(GetDecisionTreeDFS(newDataSet));
}
return node;
}
/**
* 获取属性为attr,而且属性值为value对应的数据集
* @param dataSet 新生成的数据集
* @param attr 特定属性
* @param value 特定的属性值
* @return 返回数据集
*/
public DataSet FindSpecificDT(DataSet dataSet, String attr, String value)
{
DataSet resultSet = new DataSet(null, originalDataSet.targetAttribute);
int columIndex = originalDataSet.attrSet.indexOf(attr);
for(ArrayList<String> row : dataSet.dataRows)
{
if(value.equals(row.get(columIndex)))
{
resultSet.AddRow(row);
}
}
return resultSet;
}
/**
* 找出分裂的相应属性rules
* @param dataSet 数据集
* @param attrName 属性名
* @return
*/
public ArrayList<String> GetAttrRules(DataSet dataSet, String attrName)
{
ArrayList<String> result = new ArrayList<String>();
int columIndex = dataSet.attrSet.indexOf(attrName);
for(ArrayList<String> row : dataSet.dataRows)
{
String value = row.get(columIndex);
if(!result.contains(value))
result.add(value);
}
return result;
}
/**
* 返回信息增益
* @param dataSet 数据集
* @param attrName 属性名
* @param rules
* @return
*/
public double GetGain(DataSet dataSet, String attrName, ArrayList<String> rules)
{
return getEntropy(dataSet, null, null) - getEntropy(dataSet, attrName, rules);
}
/**
* 计算熵
* @param dataSet 数据集
* @param attrName 属性名
* @param rules 分裂准则
* @return
*/
public double getEntropy(DataSet dataSet, String attrName, ArrayList<String> rules)
{
if(attrName == null)
{
Map<String, Integer> map = GetEachTargetValue(dataSet);
return CalculateEntropy(map);
}
else{
double result = 0.0;
for(int i = 0; i < rules.size(); i++)
{
Map<String, Integer> map = GetEachTargetValue(dataSet, attrName, rules.get(i));
double entroy = CalculateEntropy(map);
double sum = 0.0;
for(Entry<String, Integer> entry : map.entrySet())
{
sum += entry.getValue();
}
double dtSize = dataSet.dataRows.size();
result += (double)(sum/dtSize)*entroy;
}
return result;
}
}
/**
* 计算制定属性的属性值的数量
* @param dataSet 数据集
* @param attrName 属性名
* @param rules 属性值
* @return
*/
public Map<String, Integer> GetEachTargetValue(DataSet dataSet, String attrName, String value)
{
Map<String, Integer> map = new HashMap<String, Integer>();
int columIndex = dataSet.attrSet.indexOf(attrName);
for(int i = 0; i < dataSet.dataRows.size(); i++)
{
String targetValue = dataSet.dataRows.get(i).get(originalDataSet.attrSet.size());
if(value.equals(dataSet.dataRows.get(i).get(columIndex)))
{
if(map.containsKey(targetValue))
{
map.put(targetValue, map.get(targetValue)+1);
}
else
{
map.put(targetValue, 1);
}
}
}
return map;
}
/**
* 计算熵
* @param map <目标属性值,个数>
* @return 熵
*/
public double CalculateEntropy(Map<String, Integer> map)
{
double sum = 0.0;
for(Entry<String, Integer> entry : map.entrySet())
{
sum += entry.getValue();
}
double result = 0.0;
for(Entry<String, Integer> entry : map.entrySet())
{
int value = entry.getValue();
if(value == 0)
continue;
result += -((double)value/sum)*(Math.log((double)value/sum)/Math.log(2.0));
}
return result;
}
/**
* 判断数据集里的属性是否是同一个类
* @param dataSet 数据集
* @return 返回结果
*/
public boolean TargetAttrIsAllSame(DataSet dataSet)
{
String tempValue = null;
for(ArrayList<String> row : dataSet.dataRows)
{
String value = row.get(originalDataSet.attrSet.size());
if(tempValue == null)
{
tempValue = value;
continue;
}
if(!tempValue.equals(value))
{
return false;
}
}
return true;
}
/**
* 根据
* @param dataSet 数据集
* @return 返回
*/
public Map<String, Integer> GetEachTargetValue(DataSet dataSet)
{
Map<String, Integer> map = new HashMap<String, Integer>();
for(int i = 0; i < dataSet.dataRows.size(); i++)
{
String name = dataSet.dataRows.get(i).get(dataSet.attrSet.size());
if(map.containsKey(name))
{
map.put(name, map.get(name)+1);
}
else
{
map.put(name, 1);
}
}
return map;
}
/**
* 找出目标属性值数最多的属性值
* @param dataSet 数据集
* @return
*/
public String GetMajorTargetValue(DataSet dataSet)
{
String maxTargetValue = null;
int maxCount = -1;
Map<String, Integer> map = this.GetEachTargetValue(dataSet);
for(Entry<String, Integer> entry : map.entrySet())
{
if(entry.getValue() > maxCount)
maxTargetValue = entry.getKey();
}
return maxTargetValue;
}
}
Main类:
package ID3;
import java.util.ArrayList;
import java.util.Map;
import java.util.Map.Entry;
public class Main {
public static void main(String[] args) {
// TODO Auto-generated method stub
ArrayList<String> l1 = new ArrayList<String>();
l1.add("m");
l1.add("red");
l1.add("s");
ArrayList<String> l2 = new ArrayList<String>();
l2.add("m");
l2.add("blue");
l2.add("m");
ArrayList<String> l3 = new ArrayList<String>();
l3.add("f");
l3.add("blue");
l3.add("m");
ArrayList<String> l4 = new ArrayList<String>();
l4.add("f");
l4.add("yellow");
l4.add("b");
ArrayList<String> l5 = new ArrayList<String>();
l5.add("m");
l5.add("blue");
l5.add("s");
ArrayList<ArrayList<String>> l = new ArrayList<ArrayList<String>>();
l.add(l1);
l.add(l2);
l.add(l3);
l.add(l4);
l.add(l5);
ArrayList<String> attrSet = new ArrayList<String>();
attrSet.add("sex");
attrSet.add("color");
DataSet dataSet = new DataSet(attrSet, "size");
dataSet.AddRow(l1);
dataSet.AddRow(l2);
dataSet.AddRow(l3);
dataSet.AddRow(l4);
dataSet.AddRow(l5);
ID3 ID = new ID3(dataSet);//生成决策树
Node node = ID.GetDecisionTreeDFS(dataSet);//返回决策树根节点
node.PrintTree(node, 0, null);
String[] datas = {"m", "blue"};//测试数据
System.out.println("reult: " + node.Test(datas));
DataSet dataset = ID.FindSpecificDT(dataSet, "color", "blue");
}
}
上一篇: PriorityQueue
下一篇: python3 || 决策树 ID3算法