Weka Development [21]: IBk (KNN) Source Code Analysis
If you have not read the previous article on IB1, please read it first; I will not repeat here the material already covered there.
Let's go straight to buildClassifier. Only the code that does not also appear in IB1 is listed here:
try {
  m_NumClasses = instances.numClasses();
  m_ClassType = instances.classAttribute().type();
} catch (Exception ex) {
  throw new Error("This should never be reached");
}

// Throw away initial instances until within the specified window size
if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
  m_Train = new Instances(m_Train,
      m_Train.numInstances() - m_WindowSize, m_WindowSize);
}

// Compute the number of attributes that contribute
// to each prediction
m_NumAttributesUsed = 0.0;
for (int i = 0; i < m_Train.numAttributes(); i++) {
  if ((i != m_Train.classIndex())
      && (m_Train.attribute(i).isNominal()
          || m_Train.attribute(i).isNumeric())) {
    m_NumAttributesUsed += 1.0;
  }
}

// Invalidate any currently cross-validation selected k
m_kNNValid = false;
IB1 does not care about m_NumClasses because it finds just one neighbour, so there is only ever a single value. m_WindowSize specifies how many instances are kept for classification; the selection is not random: the initial instances are thrown away and the most recent m_WindowSize instances are kept. The last loop in the code above counts how many attributes contribute to each prediction.
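As a small illustration of the windowing (a made-up sketch, not IBk's own code; data is an assumed Instances variable), the Instances(source, first, toCopy) constructor copies a contiguous block, so the oldest instances are dropped:

int windowSize = 3;
if (windowSize > 0 && data.numInstances() > windowSize) {
  // copies the range [numInstances - 3, numInstances), i.e. the newest instances
  data = new Instances(data, data.numInstances() - windowSize, windowSize);
}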
KNN is also a classifier that can learn incrementally. Here is its updateClassifier code:
public void updateClassifier(Instance instance) throws Exception {
  if (m_Train.equalHeaders(instance.dataset()) == false) {
    throw new Exception("Incompatible instance types");
  }
  if (instance.classIsMissing()) {
    return;
  }
  if (!m_DontNormalize) {
    updateMinMax(instance);
  }
  m_Train.add(instance);
  m_kNNValid = false;
  if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
    while (m_Train.numInstances() > m_WindowSize) {
      m_Train.delete(0);
    }
  }
}
Again very simple: updateMinMax refreshes the attribute ranges with the new instance, and if the window size is exceeded, the loop deletes the first (oldest) instance until the training set fits within the window.
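Because IBk implements the UpdateableClassifier interface, it can be trained one instance at a time. A minimal usage sketch (the data file name and class index are assumptions):

import weka.classifiers.lazy.IBk;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class IBkIncrementalDemo {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("iris.arff");  // assumed data file
    data.setClassIndex(data.numAttributes() - 1);   // assumed class attribute

    IBk ibk = new IBk(3);                           // k = 3
    ibk.buildClassifier(new Instances(data, 0));    // build on an empty header
    for (int i = 0; i < data.numInstances(); i++) {
      ibk.updateClassifier(data.instance(i));       // feed instances one by one
    }
  }
}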
Note that IBk does not implement classifyInstance; it only implements distributionForInstance:
public double[] distributionForInstance(Instance instance) throws Exception {
  if (m_Train.numInstances() == 0) {
    throw new Exception("No training instances!");
  }
  if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
    m_kNNValid = false;
    boolean deletedInstance = false;
    while (m_Train.numInstances() > m_WindowSize) {
      m_Train.delete(0);
      deletedInstance = true;
    }
    // rebuild datastructure KDTree currently can't delete
    if (deletedInstance == true)
      m_NNSearch.setInstances(m_Train);
  }

  // Select k by cross validation
  if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
    crossValidate();
  }

  m_NNSearch.addInstanceInfo(instance);
  Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
  double[] distances = m_NNSearch.getDistances();
  double[] distribution = makeDistribution(neighbours, distances);
  return distribution;
}
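Returning only a distribution is enough, because classifyInstance is inherited from the Classifier base class, which derives the final prediction from distributionForInstance roughly like this (a simplified sketch, not the verbatim source):

public double classifyInstance(Instance instance) throws Exception {
  double[] dist = distributionForInstance(instance);
  if (instance.classAttribute().isNumeric()) {
    return dist[0];              // the weighted average computed by IBk
  }
  return Utils.maxIndex(dist);   // index of the most probable class
}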
I will not go over the first two checks again; crossValidate() is covered in a moment, and finding the K nearest neighbours was already discussed in article [18]. Now let's look at the makeDistribution function.
protected double[] makeDistribution(Instances neighbours, double[] distances)
    throws Exception {
  double total = 0, weight;
  double[] distribution = new double[m_NumClasses];

  // Set up a correction to the estimator
  if (m_ClassType == Attribute.NOMINAL) {
    for (int i = 0; i < m_NumClasses; i++) {
      distribution[i] = 1.0 / Math.max(1, m_Train.numInstances());
    }
    total = (double) m_NumClasses / Math.max(1, m_Train.numInstances());
  }

  for (int i = 0; i < neighbours.numInstances(); i++) {
    // Collect class counts
    Instance current = neighbours.instance(i);
    distances[i] = distances[i] * distances[i];
    distances[i] = Math.sqrt(distances[i] / m_NumAttributesUsed);
    switch (m_DistanceWeighting) {
    case WEIGHT_INVERSE:
      weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero
      break;
    case WEIGHT_SIMILARITY:
      weight = 1.0 - distances[i];
      break;
    default: // WEIGHT_NONE:
      weight = 1.0;
      break;
    }
    weight *= current.weight();
    try {
      switch (m_ClassType) {
      case Attribute.NOMINAL:
        distribution[(int) current.classValue()] += weight;
        break;
      case Attribute.NUMERIC:
        distribution[0] += current.classValue() * weight;
        break;
      }
    } catch (Exception ex) {
      throw new Error("Data has no class attribute!");
    }
    total += weight;
  }

  // Normalise distribution
  if (total > 0) {
    Utils.normalize(distribution, total);
  }
  return distribution;
}
As for the first comment, 'Set up a correction to the estimator': I don't feel it is really necessary; this is not Bayes, where a division-by-zero error would be possible, so there is nothing that needs correcting.
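To see what the correction actually does, here is a made-up numeric walk-through (all numbers are illustrative):

// Illustration with m_NumClasses = 3 and 100 training instances:
double[] distribution = new double[3];
for (int i = 0; i < 3; i++) {
  distribution[i] = 1.0 / Math.max(1, 100); // each class starts at 0.01
}
double total = 3.0 / Math.max(1, 100);      // 0.03
// a single neighbour of class 1 with weight 1.0:
distribution[1] += 1.0;                     // {0.01, 1.01, 0.01}
total += 1.0;                               // 1.03
// Utils.normalize(distribution, total) then yields roughly
// {0.0097, 0.9806, 0.0097} instead of {0.0, 1.0, 0.0}.

In effect it behaves like a small smoothing prior that keeps every class probability non-zero.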
Here you can see that three distance-weighting schemes are implemented: the inverse of the distance, one minus the distance (similarity), and a fixed weight of 1. Then, if the class is nominal, the weight is added to the entry for the corresponding class value; if it is numeric, the current instance's class value multiplied by the weight is added to distribution[0].
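These schemes can be selected through IBk's options; a minimal configuration sketch (the class name is arbitrary):

import weka.classifiers.lazy.IBk;
import weka.core.SelectedTag;

public class WeightingDemo {
  public static void main(String[] args) {
    IBk ibk = new IBk(5);
    // select 1/distance weighting; the alternatives are IBk.WEIGHT_NONE
    // (fixed weight 1) and IBk.WEIGHT_SIMILARITY (1 - distance)
    ibk.setDistanceWeighting(
        new SelectedTag(IBk.WEIGHT_INVERSE, IBk.TAGS_WEIGHTING));
  }
}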
Put simply, crossValidate uses brute force to find out how many neighbours work best: it loops over the instances in m_Train, finds the neighbours of each one, and collects statistics to see which number of neighbours performs best.
protected void crossValidate() {
  try {
    double[] performanceStats = new double[m_kNNUpper];
    double[] performanceStatsSq = new double[m_kNNUpper];

    for (int i = 0; i < m_kNNUpper; i++) {
      performanceStats[i] = 0;
      performanceStatsSq[i] = 0;
    }

    m_kNN = m_kNNUpper;
    Instance instance;
    Instances neighbours;
    double[] origDistances, convertedDistances;
    for (int i = 0; i < m_Train.numInstances(); i++) {
      instance = m_Train.instance(i);
      neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
      origDistances = m_NNSearch.getDistances();

      for (int j = m_kNNUpper - 1; j >= 0; j--) {
        // Update the performance stats
        convertedDistances = new double[origDistances.length];
        System.arraycopy(origDistances, 0, convertedDistances, 0,
            origDistances.length);
        double[] distribution = makeDistribution(neighbours,
            convertedDistances);
        double thisPrediction = Utils.maxIndex(distribution);
        if (m_Train.classAttribute().isNumeric()) {
          thisPrediction = distribution[0];
          double err = thisPrediction - instance.classValue();
          performanceStatsSq[j] += err * err;    // Squared error
          performanceStats[j] += Math.abs(err);  // Absolute error
        } else {
          if (thisPrediction != instance.classValue()) {
            performanceStats[j]++;               // Classification error
          }
        }
        if (j >= 1) {
          neighbours = pruneToK(neighbours, convertedDistances, j);
        }
      }
    }

    // Check through the performance stats and select the best
    // k value (or the lowest k if more than one best)
    double[] searchStats = performanceStats;
    if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
      searchStats = performanceStatsSq;
    }

    double bestPerformance = Double.NaN;
    int bestK = 1;
    for (int i = 0; i < m_kNNUpper; i++) {
      if (Double.isNaN(bestPerformance)
          || (bestPerformance > searchStats[i])) {
        bestPerformance = searchStats[i];
        bestK = i + 1;
      }
    }
    m_kNN = bestK;
    m_kNNValid = true;
  } catch (Exception ex) {
    throw new Error("Couldn't optimize by cross-validation: "
        + ex.getMessage());
  }
}
m_kNNUpper is another parameter: the upper bound on the number of neighbours to try. The code enumerates every instance, finding its neighbours and the corresponding distances (origDistances). Then, for every neighbourhood size from m_kNNUpper down to 1, it accumulates the squared error (performanceStatsSq) and the absolute error (performanceStats) onto the totals collected so far; for nominal classes it simply counts misclassifications. pruneToK trims the neighbour set down to j instances (keeping more only when the (j+1)-th distance ties with the j-th). The rest is easy to follow: for numeric classes, m_MeanSquared chooses between squared and absolute error, and the final loop over the m_kNNUpper candidate values picks the number of neighbours with the smallest error as the best k.
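pruneToK itself is not listed in this article; roughly, its logic looks like the following simplified sketch (not the verbatim Weka source):

// Trim the neighbour set down to k, but keep any extra neighbours
// whose distance ties with the k-th smallest distance.
protected Instances pruneToK(Instances neighbours, double[] distances, int k) {
  if (neighbours == null || k < 1 || k >= neighbours.numInstances()) {
    return neighbours; // nothing to prune
  }
  int currentK = k;
  double kthDist = distances[currentK - 1];
  while (currentK < distances.length && distances[currentK] == kthDist) {
    currentK++; // a tie: the (k+1)-th neighbour is just as close
  }
  while (neighbours.numInstances() > currentK) {
    neighbours.delete(neighbours.numInstances() - 1);
  }
  return neighbours;
}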