决策树程序实验Word文档格式.docx
《决策树程序实验Word文档格式.docx》由会员分享,可在线阅读,更多相关《决策树程序实验Word文档格式.docx(30页珍藏版)》请在冰点文库上搜索。
![决策树程序实验Word文档格式.docx](https://file1.bingdoc.com/fileroot1/2023-5/8/0e0ae677-c46a-411a-b9a5-511cc6fb2026/0e0ae677-c46a-411a-b9a5-511cc6fb20261.gif)
(b)没有剩余属性可以用来进一步划分样本(步骤4).在此情况下,采用多数表决(步骤5)。
这涉及将给定的结点转换成树叶,并用samples中的多数所在类别标记它。
换一种方式,可以存放结点样本的类分布。
(c)分支test_attribute=ai没有样本。
在这种情况下,以samples中的多数类创建一个树叶(步骤12)。
算法Decision_Tree(samples,attribute_list)
输入由离散值属性描述的训练样本集samples;
候选属性集合attribute_list.
输出一棵决策树。
(1)创建节点N;
(2)Ifsamples都在同一类C中then
(3)返回N作为叶节点,以类C标记;
(4)Ifattribute_list为空then
(5)返回N作为叶节点,以samples中最普遍的类标记;
//多数表决
(6)选择attribute_list中具有最高信息增益的属性test_attribute;
(7)以test_attribute标记节点N;
(8)Foreachtest_attribute的已知值v//划分samples
(9)由节点N分出一个对应test_attribute=v的分支;
(10)令Sv为samples中test_attribute=v的样本集合;
//一个划分块
(11)IfSv为空then
(12)加上一个叶节点,以samples中最普遍的类标记;
(13)Else加入一个由Decision_Tree(Sv,attribute_list—test_attribute)返回节点值
E(S)=(—9\15)log2(9\15)-(6\15)log2(6\15)=0.971
Values(收入范围)={20—30K,30—40k,40—50K,50-60K}
E(S(20-30K))=(-2\4)log2(2\4)-(2\4)log2(2\4)=1
E(S(30-40K))=(-4\5)log2(4\5)-(1\5)log2(1\5)=0。
7219
E(S(40—50K))=(—1\4)log2(1\4)-(3\4)log2(3\4)=0.8113
E(S(50—60K))=(—2\2)log2(2\2)—(0\2)log2(0\2)=0
所以
E(S,收入范围)=(4/15)E(S(20-30K))+(5/15)E(S(30—40K))+(4/15)E(S(40—50K))+(2/15)E(S(50-60K))=0。
7236
Gain(S,收入范围)=0。
971-0。
7236=0.2474
同理:
计算“保险"
,“性别”,“年龄”的信息增益为:
E(S)=(—9\15)log2(9\15)-(6\15)log2(6\15)=0。
971
Insurance(保险)={yes,no}
E(S(yes))=(—3\3)log2(3\3)-(0\3)log2(0\3)=0
E(S(no))=(-6\12)log2(6\12)-(6\12)log2(6\12)=1
E(S,保险)=(3/15)E(S(yes))+(12/15)E(S(no))=0.8
Gain(S,保险)=0。
971—0。
8=0.171
E(S)=(-9\15)log2(9\15)—(6\15)log2(6\15)=0。
sex(性别)={male,female}
E(S(male))=(—3\7)log2(3\7)—(4\7)log2(4\7)=0。
9852
E(S(female))=(—6\8)log2(6\8)-(2\8)log2(2\8)=0。
8113
E(S,性别)=(7/15)E(S(male))+(8/15)E(S(female))=0。
8925
Gain(S,性别)=0。
971-0.8925=0。
0785
E(S)=(-9\15)log2(9\15)-(6\15)log2(6\15)=0。
age(年龄)={15~40,41~60}
E(S(15~40))=(-6\7)log2(6\7)—(1\7)log2(1\7)=0。
5917
E(S(41~60))=(-3\8)log2(3\8)—(5\8)log2(5\8)=0。
9544
E(S,年龄)=(7/15)E(S(15~40))+(8/15)E(S(41~60))=0.7851
Gain(S,年龄)=0.971—0。
7851=0.1859
代码
packageDecisionTree;
importjava.util。
ArrayList;
/**
*决策树结点类
*/
publicclassTreeNode{
privateStringname;
//节点名(分裂属性的名称)
privateArrayList〈String〉rule;
//结点的分裂规则
ArrayList〈TreeNode>
child;
//子结点集合
privateArrayList〈ArrayList<
String>
>
datas;
//划分到该结点的训练元组
privateArrayList<
candAttr;
//划分到该结点的候选属性
publicTreeNode(){
this。
name="
"
;
this.rule=newArrayList〈String>
();
child=newArrayList<
TreeNode〉();
this.datas=null;
candAttr=null;
}
publicArrayList〈TreeNode>
getChild(){
returnchild;
}
publicvoidsetChild(ArrayList〈TreeNode〉child){
this.child=child;
publicArrayList〈String>
getRule(){
returnrule;
publicvoidsetRule(ArrayList<
String〉rule){
rule=rule;
publicStringgetName(){
returnname;
publicvoidsetName(Stringname){
this.name=name;
publicArrayList〈ArrayList〈String〉>
getDatas(){
returndatas;
publicvoidsetDatas(ArrayList〈ArrayList<
String〉>
datas){
datas=datas;
publicArrayList〈String〉getCandAttr(){
returncandAttr;
publicvoidsetCandAttr(ArrayList〈String>
candAttr){
candAttr=candAttr;
}
importjava。
io。
BufferedReader;
importjava.io.IOException;
InputStreamReader;
util.ArrayList;
StringTokenizer;
/**
*决策树算法测试类
publicclassTestDecisionTree{
/**
*读取候选属性
*@return候选属性集合
*@throwsIOException
publicArrayList<
String〉readCandAttr()throwsIOException{
ArrayList<
String〉candAttr=newArrayList〈String〉();
BufferedReaderreader=newBufferedReader(newInputStreamReader(System。
in));
Stringstr=””;
while(!
(str=reader。
readLine())。
equals(”"
)){
StringTokenizertokenizer=newStringTokenizer(str);
while(tokenizer。
hasMoreTokens()){
candAttr.add(tokenizer。
nextToken());
}
}
returncandAttr;
/**
*读取训练元组
*@return训练元组集合
*@throwsIOException
*/
ArrayList<
String〉〉readData()throwsIOException{
datas=newArrayList〈ArrayList<
();
in));
Stringstr="
;
while(!
(str=reader.readLine()).equals(”"
)){
ArrayList〈String>
s=newArrayList<
String〉();
s。
add(tokenizer。
}
datas.add(s);
/**
*递归打印树结构
*@paramroot当前待输出信息的结点
publicvoidprintTree(TreeNoderoot){
System。
out.println(”name:
”+root.getName());
ArrayList〈String〉rules=root。
getRule();
out。
print("
noderules:
{"
);
for(inti=0;
i<
rules.size();
i++){
System。
out.print(rules。
get(i)+"
"
);
print(”}"
out.println("
”);
TreeNode>
children=root。
getChild();
intsize=children。
size();
if(size==0){
println(”——〉leafnode!
<
—-”);
}else{
sizeofchildren:
”+children。
size());
for(inti=0;
children。
System。
out.print(”child”+(i+1)+"
ofnode”+root。
getName()+”:
printTree(children。
get(i));
*主函数,程序入口
*@paramargs
publicstaticvoidmain(String[]args){
TestDecisionTreetdt=newTestDecisionTree();
candAttr=null;
ArrayList〈ArrayList<
datas=null;
try{
System.out.println("
请输入候选属性”);
candAttr=tdt。
readCandAttr();
请输入训练数据"
datas=tdt.readData();
}catch(IOExceptione){
e。
printStackTrace();
DecisionTreetree=newDecisionTree();
TreeNoderoot=tree.buildTree(datas,candAttr);
tdt。
printTree(root);
importjava.util.ArrayList;
util.HashMap;
Iterator;
importjava.util.Map;
/**
*选择最佳分裂属性
publicclassGain{
String〉〉D=null;
//训练元组
privateArrayList〈String〉attrList=null;
//候选属性集
publicGain(ArrayList〈ArrayList<
datas,ArrayList〈String>
attrList){
this.D=datas;
attrList=attrList;
/**
*获取最佳侯选属性列上的值域(假定所有属性列上的值都是有限的名词或分类类型的)
*@paramattrIndex指定的属性列的索引
*@return值域集合
String〉getValues(ArrayList〈ArrayList〈String〉>
datas,intattrIndex){
ArrayList〈String>
values=newArrayList<
Stringr=”"
datas.size();
i++){
r=datas。
get(i)。
get(attrIndex);
if(!
values。
contains(r)){
values。
add(r);
}
returnvalues;
*获取指定数据集中指定属性列索引的域值及其计数
*@paramd指定的数据集
*@paramattrIndex指定的属性列索引
*@return类别及其计数的map
publicMap〈String,Integer>
valueCounts(ArrayList〈ArrayList<
String〉〉datas,intattrIndex){
Map〈String,Integer>
valueCount=newHashMap〈String,Integer〉();
Stringc=””;
ArrayList〈String〉tuple=null;
for(inti=0;
i〈datas.size();
tuple=datas.get(i);
c=tuple。
get(attrIndex);
if(valueCount。
containsKey(c)){
valueCount.put(c,valueCount。
get(c)+1);
}else{
valueCount。
put(c,1);
returnvalueCount;
*求对datas中元组分类所需的期望信息,即datas的熵
*@paramdatas训练元组
*@returndatas的熵值
publicdoubleinfoD(ArrayList〈ArrayList〈String>
〉datas){
doubleinfo=0.000;
inttotal=datas。
size();
classes=valueCounts(datas,attrList.size());
Iteratoriter=classes.entrySet().iterator();
Integer[]counts=newInteger[classes。
size()];
for(inti=0;
iter。
hasNext();
i++)
{
Map.Entryentry=(Map.Entry)iter.next();
Integerval=(Integer)entry.getValue();
counts[i]=val;
i〈counts。
length;
doublebase=DecimalCalculate.div(counts[i],total,3);
info+=(-1)*base*Math。
log(base);
returninfo;
*获取指定属性列上指定值域的所有元组
*@paramattrIndex指定属性列索引
*@paramvalue指定属性列的值域
*@return指定属性列上指定值域的所有元组
ArrayList〈String〉〉datasOfValue(intattrIndex,Stringvalue){
String〉〉Di=newArrayList<
String〉〉();
ArrayList〈String〉t=null;
i〈D.size();
t=D。
get(i);
if(t。
get(attrIndex)。
equals(value)){
Di。
add(t);
returnDi;
*基于按指定属性划分对D的元组分类所需要的期望信息
*@paramattrIndex指定属性的索引
*@return按指定属性划分的期望信息值
publicdoubleinfoAttr(intattrIndex){
doubleinfo=0.000;
ArrayList〈String〉values=getValues(D,attrIndex);
for(inti