weka开发21ibkknn源代码分析.docx

上传人:b****4 文档编号:13902810 上传时间:2023-06-19 格式:DOCX 页数:9 大小:16.86KB
下载 相关 举报
weka开发21ibkknn源代码分析.docx_第1页
第1页 / 共9页
weka开发21ibkknn源代码分析.docx_第2页
第2页 / 共9页
weka开发21ibkknn源代码分析.docx_第3页
第3页 / 共9页
weka开发21ibkknn源代码分析.docx_第4页
第4页 / 共9页
weka开发21ibkknn源代码分析.docx_第5页
第5页 / 共9页
weka开发21ibkknn源代码分析.docx_第6页
第6页 / 共9页
weka开发21ibkknn源代码分析.docx_第7页
第7页 / 共9页
weka开发21ibkknn源代码分析.docx_第8页
第8页 / 共9页
weka开发21ibkknn源代码分析.docx_第9页
第9页 / 共9页
亲,该文档总共9页,全部预览完了,如果喜欢就下载吧!
下载资源
资源描述

weka开发21ibkknn源代码分析.docx

《weka开发21ibkknn源代码分析.docx》由会员分享,可在线阅读,更多相关《weka开发21ibkknn源代码分析.docx(9页珍藏版)》请在冰点文库上搜索。

weka开发21ibkknn源代码分析.docx

weka开发21ibkknn源代码分析

Weka开发[21]——IBk(KNN)源代码分析

如果你没有看上一篇IB1,请先看一下,因为重复的内容我在这里不会介绍了。

直接看buildClassifier,这里只列出在IB1中也没有出现的代码:

try{

m_NumClasses=instances.numClasses();

m_ClassType=instances.classAttribute().type();

}catch(Exceptionex){

thrownewError("Thisshouldneverbereached");

}

//Throwawayinitialinstancesuntilwithinthespecifiedwindowsize

if((m_WindowSize>0)&&(instances.numInstances()>m_WindowSize)){

m_Train=newInstances(m_Train,m_Train.numInstances()

-m_WindowSize,m_WindowSize);

}

//Computethenumberofattributesthatcontribute

//toeachprediction

m_NumAttributesUsed=0.0;

for(inti=0;i<m_Train.numAttributes();i){

if((i!

=m_Train.classIndex())

&&(m_Train.attribute(i).isNominal()||m_Train

.attribute(i).isNumeric())){

m_NumAttributesUsed=1.0;

}

}

//Invalidateanycurrentlycross-validationselectedk

m_kNNValid=false;

IB1中不关心m_NumClasses是因为它就找一个邻居,当然就一个值了。

m_WindowSize是指用多少样本用于分类,这里不是随机选择而是直接选前m_WindowSize个。

这里下面是看有多少属性参与预测。

KNN也是一个可以增量学习的分器量,下面看一下它的updateClassifier代码:

publicvoidupdateClassifier(Instanceinstance)throwsException{

if(m_Train.equalHeaders(instance.dataset())==false){

thrownewException("Incompatibleinstancetypes");

}

if(instance.classIsMissing()){

return;

}

if(!

m_DontNormalize){

updateMinMax(instance);

}

m_Train.add(instance);

m_kNNValid=false;

if((m_WindowSize>0)&&(m_Train.numInstances()>m_WindowSize)){

while(m_Train.numInstances()>m_WindowSize){

m_Train.delete(0);

}

}

}

同样很简单,updateMinMax,如果超出窗口大小,循环删除超过窗口大小的第一个样本。

这里注意IBk没有实现classifyInstance,它只实现了distributionForInstances:

publicdouble[]distributionForInstance(Instanceinstance)throwsException{

if(m_Train.numInstances()==0){

thrownewException("Notraininginstances!

");

}

if((m_WindowSize>0)&&(m_Train.numInstances()>m_WindowSize)){

m_kNNValid=false;

booleandeletedInstance=false;

while(m_Train.numInstances()>m_WindowSize){

m_Train.delete(0);

}

//rebuilddatastructureKDTreecurrentlycan'tdelete

if(deletedInstance==true)

m_NNSearch.setInstances(m_Train);

}

//Selectkbycrossvalidation

if(!

m_kNNValid&&(m_CrossValidate)&&(m_kNNUpper>=1)){

crossValidate();

}

m_NNSearch.addInstanceInfo(instance);

Instancesneighbours=m_NNSearch.kNearestNeighbours(instance,

m_kNN);

double[]distances=m_NNSearch.getDistances();

double[]distribution=makeDistribution(neighbours,distances);

returndistribution;

}

前面两个判断不讲了,crossValidate()马上讲,寻找K个邻居在我第[18]篇里已经讲过了,现在我们看一下makeDistribution函数。

protecteddouble[]makeDistribution(Instancesneighbours,

double[]distances)throwsException{

doubletotal=0,weight;

double[]distribution=newdouble[m_NumClasses];

//Setupacorrectiontotheestimator

if(m_ClassType==Attribute.NOMINAL){

for(inti=0;i<m_NumClasses;i){

distribution[i]=1.0/Math.max(1,m_Train.numInstances());

}

total=(double)m_NumClasses/Math.max(1,

m_Train.numInstances());

}

for(inti=0;i<neighbours.numInstances();i){

//Collectclasscounts

Instancecurrent=neighbours.instance(i);

distances[i]=distances[i]*distances[i];

distances[i]=Math.sqrt(distances[i]/m_NumAttributesUsed);

switch(m_DistanceWeighting){

caseWEIGHT_INVERSE:

weight=1.0/(distances[i]0.001);//toavoiddivbyzero

break;

caseWEIGHT_SIMILARITY:

weight=1.0-distances[i];

break;

default:

//WEIGHT_NONE:

weight=1.0;

break;

}

weight*=current.weight();

try{

switch(m_ClassType){

caseAttribute.NOMINAL:

distribution[(int)current.classValue()]=weight;

break;

caseAttribute.NUMERIC:

distribution[0]=current.classValue()*weight;

break;

}

}catch(Exceptionex){

thrownewError("Datahasnoclassattribute!

");

}

total=weight;

}

//Normalisedistribution

if(total>0){

Utils.normalize(distribution,total);

}

returndistribution;

}

第一行注释Setupacorrection,我感觉没什么必要,又不是Bayes还有除0错误,没什么可修正的。

这里可以看见它实现了三种距离权重计算方法,倒数,与1的差,另外就是固定权重1。

然后如果类别是离散值把对应的类值加上权重,如果是连续值,就加上当前类别值剩权重。

crossValidate简单地说就是用蛮力找在到底用多少个邻居好,它对m_Train中的样本进行循环,对每个样本找邻居,然后统计看寻找多少个邻居时最好。

protectedvoidcrossValidate(){

double[]performanceStats=newdouble[m_kNNUpper];

double[]performanceStatsSq=newdouble[m_kNNUpper];

for(inti=0;i<m_kNNUpper;i){

performanceStats[i]=0;

performanceStatsSq[i]=0;

}

m_kNN=m_kNNUpper;

Instanceinstance;

Instancesneighbours;

double[]origDistances,convertedDistances;

for(inti=0;i<m_Train.numInstances();i){

instance=m_Train.instance(i);

neighbours=m_NNSearch.kNearestNeighbours(instance,m_kNN);

origDistances=m_NNSearch.getDistances();

for(intj=m_kNNUpper-1;j>=0;j--){

//Updatetheperformancestats

convertedDistances=newdouble[origDistances.length];

System.arraycopy(origDistances,0,convertedDistances,0,

origDistances.length);

double[]distribution=makeDistribution(neighbours,

convertedDistances);

doublethisPrediction=Utils.maxIndex(distribution);

if(m_Train.classAttribute().isNumeric()){

thisPrediction=distribution[0];

doubleerr=thisPrediction-instance.classValue();

performanceStatsSq[j]=err*err;//Squarederror

performanceStats[j]=Math.abs(err);//Absoluteerror

}else{

if(thisPrediction!

=instance.classValue()){

performanceStats[j];//Classificationerror

}

}

if(j>=1){

neighbours=pruneToK(neighbours,

convertedDistances,j);

}

}

}

//Checkthroughtheperformancestatsandselectthebest

//kvalue(orthelowestkifmorethanonebest)

double[]searchStats=performanceStats;

if(m_Train.classAttribute().isNumeric()&&m_MeanSquared){

searchStats=performanceStatsSq;

}

doublebestPerformance=Double.NaN;

intbestK=1;

for(inti=0;i<m_kNNUpper;i){

if(Double.isNaN(bestPerformance)

||(bestPerformance>searchStats[i])){

bestPerformance=searchStats[i];

bestK=i1;

}

}

m_kNN=bestK;

m_kNNValid=true;

}

m_kNNUpper是另一个设置最多有多少样本的参数,枚举每一个样本(instance),找它的邻居(neighbors),和距离(origDistances)。

接下来就是把从0到m_kNNUpper个邻居的得到的方差(performanceStatsSq)和标准差(performanceStats)与以前得到的值累加。

pruneToK就是得到j个样本(如果j1的距离不等于第j个),后面就比较好理解了,m_MeanSquared对连续类别是选择用方差还是标准差进行选择,然后最出m_kNNUpper看在多少邻居的时候,分类误差最小,就认为是最好的邻居数。

展开阅读全文
相关资源
猜你喜欢
相关搜索
资源标签

当前位置:首页 > 经管营销 > 经济市场

copyright@ 2008-2023 冰点文库 网站版权所有

经营许可证编号:鄂ICP备19020893号-2