决策树归纳(ID3属性选择度量)Java实现

一般的决策树归纳框架见之前的博文:http://blog.csdn.net/zhyoulun/article/details/41978381


ID3属性选择度量原理

ID3使用信息增益作为属性选择度量。该度量基于香农在研究消息的值或”信息内容“的信息论方面的先驱工作。该结点N代表或存放分区D的元组。选择具有最高信息增益的属性作为结点N的分裂属性。该属性使结果分区中对元祖分类所需要的信息量最小,并反映这些分区中的最小随机性或”不纯性“。这种方法使得对一个对象分类所需要的期望测试数目最小,并确保找到一颗简单的(但不必是最简单的)树。

对D中的元组分类所需要的期望信息由下式给出,

技术分享

其中pi是D忠任意元组属于类Ci的非零概率。使用以2为底的对数函数是因为信息用二进制编码。Info(D)是识别D中元组的类标号所需要的平均信息量。注意,此时我们所有的信息只是每个类的元组所占的百分比。

现在假设我们要按照某属性A划分D中的元组,其中属性A根据训练数据的观测具有v个不同的值{a1,a2,...av}。可以用属性A将D划分为v个分区或子集{D1,D2,...,Dv},其中Dj包含D中的元组,它们的A值为aj。这些分区对应于从节点N生长出来的分支。理想情况下,我们希望该划分产生元组的准确分类。即希望每个分区都是纯的(实际情况多半是不纯的,如分区可能包含来自不同类的元组)。在此划分之后,为了得到准确的分类,我们还需要多少信息?这个量由下式度量:

技术分享

其中|Dj|/|D|充当第j个分区的权重。Info_A(D)是基于按A划分对D的元组分类所需要的期望值信息需要的期望信息越小,分区的纯度越高

信息增益定义为原来的信息需求(仅基于类比例)与新的信息需求(对A划分后)之前的差。即

技术分享

换言之,Gain(A)告诉我们通过A上的划分我们得到了多少。它是知道A的值而导致的信息需求的期望减少。选择具有最高信息增益Gain(A)的属性A作为结点N的分裂属性。


以下为例子。


数据

data.txt

youth,high,no,fair,no
youth,high,no,excellent,no
middle_aged,high,no,fair,yes
senior,medium,no,fair,yes
senior,low,yes,fair,yes
senior,low,yes,excellent,no
middle_aged,low,yes,excellent,yes
youth,medium,no,fair,no
youth,low,yes,fair,yes
senior,medium,yes,fair,yes
youth,medium,yes,excellent,yes
middle_aged,medium,no,excellent,yes
middle_aged,high,yes,fair,yes
senior,medium,no,excellent,no


attr.txt

age,income,student,credit_rating,buys_computer


运算结果

age(1:youth; 2:middle_aged; 3:senior; )
	credit_rating(1:fair; 2:excellent; )
		leaf:no()
		leaf:yes()
	leaf:yes()
	student(1:no; 2:yes; )
		leaf:no()
		leaf:yes()


技术分享


最后附上java代码

DecisionTree.java

package com.zhyoulun.decision;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Map;

/**
 * 负责数据的读入和写出,以及生成决策树
 * 
 * @author zhyoulun
 *
 */
public class DecisionTree
{
	private ArrayList<ArrayList<String>> allDatas;
	private ArrayList<String> allAttributes;
	
	/**
	 * 从文件中读取所有相关数据
	 * @param dataFilePath
	 * @param attrFilePath
	 */
	public DecisionTree(String dataFilePath,String attrFilePath)
	{
		super();
		
		try
		{
			this.allDatas = new ArrayList<>();
			this.allAttributes = new ArrayList<>();
			
			InputStreamReader inputStreamReader = new InputStreamReader(new FileInputStream(new File(dataFilePath)));
			BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
			String line = null;
			while((line=bufferedReader.readLine())!=null)
			{
				String[] strings = line.split(",");
				ArrayList<String> data = new ArrayList<>();
				for(int i=0;i<strings.length;i++)
					data.add(strings[i]);
				this.allDatas.add(data);
			}
			
			
			inputStreamReader = new InputStreamReader(new FileInputStream(new File(attrFilePath)));
			bufferedReader = new BufferedReader(inputStreamReader);
			while((line=bufferedReader.readLine())!=null)
			{
				String[] strings = line.split(",");
				for(int i=0;i<strings.length;i++)
					this.allAttributes.add(strings[i]);
			}
			
			inputStreamReader.close();
			bufferedReader.close();
			
		} catch (FileNotFoundException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		
		
//		for(int i=0;i<this.allAttributes.size();i++)
//		{
//			System.out.print(this.allAttributes.get(i)+" ");
//		}
//		System.out.println();
//		
//		for(int i=0;i<this.allDatas.size();i++)
//		{
//			for(int j=0;j<this.allDatas.get(i).size();j++)
//			{
//				System.out.print(this.allDatas.get(i).get(j)+" ");
//			}
//			System.out.println();
//		}
		
	}
	
	/**
	 * @param allDatas
	 * @param allAttributes
	 */
	public DecisionTree(ArrayList<ArrayList<String>> allDatas,
			ArrayList<String> allAttributes)
	{
		super();
		this.allDatas = allDatas;
		this.allAttributes = allAttributes;
	}
	
	public ArrayList<ArrayList<String>> getAllDatas()
	{
		return allDatas;
	}

	public void setAllDatas(ArrayList<ArrayList<String>> allDatas)
	{
		this.allDatas = allDatas;
	}

	public ArrayList<String> getAllAttributes()
	{
		return allAttributes;
	}

	public void setAllAttributes(ArrayList<String> allAttributes)
	{
		this.allAttributes = allAttributes;
	}

	
	
	/**
	 * 递归生成决策数
	 * @return
	 */
	public static TreeNode generateDecisionTree(ArrayList<ArrayList<String>> datas, ArrayList<String> attrs)
	{
		TreeNode treeNode = new TreeNode();
		
		//如果D中的元素都在同一类C中,then
		if(isInTheSameClass(datas))
		{
			treeNode.setName(datas.get(0).get(datas.get(0).size()-1));
//			rootNode.setName();
			return treeNode;
		}
		//如果attrs为空,then(这种情况一般不会出现,我们应该是要对所有的候选属性集合构建决策树)
		if(attrs.size()==0)
			return treeNode;
		
		CriterionID3 criterionID3 = new CriterionID3(datas, attrs);
		int splitingCriterionIndex = criterionID3.attributeSelectionMethod();
		
		treeNode.setName(attrs.get(splitingCriterionIndex));
		treeNode.setRules(getValueSet(datas, splitingCriterionIndex));
		
		attrs.remove(splitingCriterionIndex);
		
		Map<String, ArrayList<ArrayList<String>>> subDatasMap = criterionID3.getSubDatasMap(splitingCriterionIndex);
//		for(String key:subDatasMap.keySet())
//		{
//			System.out.println("===========");
//			System.out.println(key);
//			for(int i=0;i<subDatasMap.get(key).size();i++)
//			{
//				for(int j=0;j<subDatasMap.get(key).get(i).size();j++)
//				{
//					System.out.print(subDatasMap.get(key).get(i).get(j)+" ");
//				}
//				System.out.println();
//			}
//		}
		
		for(String key:subDatasMap.keySet())
		{
			ArrayList<TreeNode> treeNodes = treeNode.getChildren();
			treeNodes.add(generateDecisionTree(subDatasMap.get(key), attrs));
			treeNode.setChildren(treeNodes);
		}
		
		return treeNode;
	}
	
	
	
	
	/**
	 * 获取datas中index列的值域
	 * @param data
	 * @param index
	 * @return
	 */
	public static ArrayList<String> getValueSet(ArrayList<ArrayList<String>> datas,int index)
	{
		ArrayList<String> values = new ArrayList<String>();
		String r = "";
		for (int i = 0; i < datas.size(); i++) {
			r = datas.get(i).get(index);
			if (!values.contains(r)) {
				values.add(r);
			}
		}
		return values;
	}
	
	/**
	 * 最后一列是类标号,判断最后一列是否相同
	 * @param datas
	 * @return
	 */
	private static boolean isInTheSameClass(ArrayList<ArrayList<String>> datas)
	{
		String flag = datas.get(0).get(datas.get(0).size()-1);//第0行,最后一列赋初值
		for(int i=0;i<datas.size();i++)
		{
			if(!datas.get(i).get(datas.get(i).size()-1).equals(flag))
				return false;
		}
		return true;
	}


	public static void main(String[] args)
	{
		String dataPath = "files/data.txt";
		String attrPath = "files/attr.txt";
		
		//初始化原始数据
		DecisionTree decisionTree = new DecisionTree(dataPath,attrPath);
		
		//生成决策树
		TreeNode treeNode = generateDecisionTree(decisionTree.getAllDatas(), decisionTree.getAllAttributes());
		
		print(treeNode,0);
	}
	
	private static void print(TreeNode treeNode,int level)
	{
		for(int i=0;i<level;i++)
			System.out.print("\t");
		System.out.print(treeNode.getName());
		System.out.print("(");
		for(int i=0;i<treeNode.getRules().size();i++)
			System.out.print((i+1)+":"+treeNode.getRules().get(i)+"; ");
		System.out.println(")");
		
		ArrayList<TreeNode> treeNodes = treeNode.getChildren();
		for(int i=0;i<treeNodes.size();i++)
		{
			print(treeNodes.get(i),level+1);
		}
	}
	
	
}



CriterionID3.java

package com.zhyoulun.decision;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * ID3,选择分裂准则
 * 
 * @author zhyoulun
 *
 */
public class CriterionID3
{
	private ArrayList<ArrayList<String>> datas;
	private ArrayList<String> attributes;
	
	private Map<String, ArrayList<ArrayList<String>>> subDatasMap;
	
	/**
	 * 计算所有的信息增益,获取最大的一项作为分裂属性
	 * @return
	 */
	public int attributeSelectionMethod()
	{
		double gain = -1.0;
		int maxIndex = 0;
		for(int i=0;i<this.attributes.size()-1;i++)
		{
			double tempGain = this.calcGain(i);
			if(tempGain>gain)
			{
				gain = tempGain;
				maxIndex = i;
			}
		}
		
		return maxIndex;
	}
	
	/**
	 * 计算 Gain(age)=Info(D)-Info_age(D) 等
	 * @param index
	 * @return
	 */
	/**
	 * @param index
	 * @param isCalcSubDatasMap
	 * @return
	 */
	private double calcGain(int index)
	{
		double result = 0;
		
		//计算Info(D)
		int lastIndex = datas.get(0).size()-1;
		ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas,lastIndex);
		for(String value:valueSet)
		{
			int count = 0;
			for(int i=0;i<datas.size();i++)
			{
				if(datas.get(i).get(lastIndex).equals(value))
					count++;
			}
			
			result += -(1.0*count/datas.size())*Math.log(1.0*count/datas.size())/Math.log(2);
//			System.out.println(result);
		}
//		System.out.println("==========");
		
		//计算Info_a(D)
		valueSet = DecisionTree.getValueSet(this.datas,index);
		
//		for(String temp:valueSet)
//			System.out.println(temp);
//		System.out.println("==========");
		
		for(String value:valueSet)
		{	
			ArrayList<ArrayList<String>> subDatas = new ArrayList<>();
			for(int i=0;i<datas.size();i++)
			{
				if(datas.get(i).get(index).equals(value))
					subDatas.add(datas.get(i));
			}
			
			
			
			
//			for(ArrayList<String> temp:subDatas)
//			{
//				for(String temp2:temp)
//					System.out.print(temp2+" ");
//				System.out.println();
//			}
			
			ArrayList<String> subValueSet = DecisionTree.getValueSet(subDatas, lastIndex);
			
			
//			System.out.print("subValueSet:");
//			for(String temp:subValueSet)
//				System.out.print(temp+" ");
//			System.out.println();
			
			
			for(String subValue:subValueSet)
			{
//				System.out.println("+++++++++++++++");
//				System.out.println(subValue);
				int count = 0;
				for(int i=0;i<subDatas.size();i++)
				{
					if(subDatas.get(i).get(lastIndex).equals(subValue))
						count++;
				}
//				System.out.println(count);
				result += -1.0*subDatas.size()/datas.size()*(-(1.0*count/subDatas.size())*Math.log(1.0*count/subDatas.size())/Math.log(2));
//				System.out.println(result);
			}
			
		}
		
		return result;
		
	}



	public CriterionID3(ArrayList<ArrayList<String>> datas,
			ArrayList<String> attributes)
	{
		super();
		this.datas = datas;
		this.attributes = attributes;
	}



	public ArrayList<ArrayList<String>> getDatas()
	{
		return datas;
	}



	public void setDatas(ArrayList<ArrayList<String>> datas)
	{
		this.datas = datas;
	}



	public ArrayList<String> getAttributes()
	{
		return attributes;
	}



	public void setAttributes(ArrayList<String> attributes)
	{
		this.attributes = attributes;
	}

	public Map<String, ArrayList<ArrayList<String>>> getSubDatasMap(int index)
	{
		ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas, index);
		this.subDatasMap = new HashMap<String, ArrayList<ArrayList<String>>>();
		
		for(String value:valueSet)
		{
			ArrayList<ArrayList<String>> subDatas = new ArrayList<>();
			for(int i=0;i<this.datas.size();i++)
			{
				if(this.datas.get(i).get(index).equals(value))
					subDatas.add(this.datas.get(i));
			}
			for(int i=0;i<subDatas.size();i++)
			{
				subDatas.get(i).remove(index);
			}
			this.subDatasMap.put(value, subDatas);
		}
		
		return subDatasMap;
	}

	public void setSubDatasMap(Map<String, ArrayList<ArrayList<String>>> subDatasMap)
	{
		this.subDatasMap = subDatasMap;
	}
	
	
}



TreeNode.java

package com.zhyoulun.decision;

import java.util.ArrayList;

public class TreeNode
{
	private String name; 								// 该结点的名称(分裂属性)
	private ArrayList<String> rules; 				// 结点的分裂规则(假设均为离散值)
//	private ArrayList<ArrayList<String>> datas; 	// 划分到该结点的训练元组(datas.get(i)表示一个训练元组)
//	private ArrayList<String> candidateAttributes; // 划分到该结点的候选属性(与训练元组的个数一致)
	private ArrayList<TreeNode> children; 			// 子结点

	public TreeNode()
	{
		this.name = "";
		this.rules = new ArrayList<String>();
		this.children = new ArrayList<TreeNode>();
//		this.datas = null;
//		this.candidateAttributes = null;
	}

	public String getName()
	{
		return name;
	}

	public void setName(String name)
	{
		this.name = name;
	}

	public ArrayList<String> getRules()
	{
		return rules;
	}

	public void setRules(ArrayList<String> rules)
	{
		this.rules = rules;
	}

	public ArrayList<TreeNode> getChildren()
	{
		return children;
	}

	public void setChildren(ArrayList<TreeNode> children)
	{
		this.children = children;
	}

//	public ArrayList<ArrayList<String>> getDatas()
//	{
//		return datas;
//	}
//
//	public void setDatas(ArrayList<ArrayList<String>> datas)
//	{
//		this.datas = datas;
//	}
//
//	public ArrayList<String> getCandidateAttributes()
//	{
//		return candidateAttributes;
//	}
//
//	public void setCandidateAttributes(ArrayList<String> candidateAttributes)
//	{
//		this.candidateAttributes = candidateAttributes;
//	}
	
	
}



参考:《数据挖掘概念与技术(第3版)》

转载请注明出处:

郑重声明:本站内容如果来自互联网及其他传播媒体,其版权均属原媒体及文章作者所有。转载目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,也不构成任何其他建议。