博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
KNN算法的實現
阅读量:6757 次
发布时间:2019-06-26

本文共 3463 字,大约阅读时间需要 11 分钟。

转自 

Knn.h

#pragma once

class Knn

{
private:
 double** trainingDataset;
 double* arithmeticMean;
 double* standardDeviation;
 int m, n;

 void RescaleDistance(double* row);

 void RescaleTrainingDataset();
 void ComputeArithmeticMean();
 void ComputeStandardDeviation();

 double Distance(double* x, double* y);

public:
 Knn(double** trainingDataset, int m, int n);
 ~Knn();
 double Vote(double* test, int k);
};

 

Knn.cpp

 

#include "Knn.h"

#include <cmath>
#include <map>

using namespace std;

Knn::Knn(double** trainingDataset, int m, int n)

{
 this->trainingDataset = trainingDataset;
 this->m = m;
 this->n = n;
 ComputeArithmeticMean();
 ComputeStandardDeviation();
 RescaleTrainingDataset();
}

void Knn::ComputeArithmeticMean()

{
 arithmeticMean = new double[n - 1];

 double sum;

 for(int i = 0; i < n - 1; i++)

 {
  sum = 0;
  for(int j = 0; j < m; j++)
  {
   sum += trainingDataset[j][i];
  }

  arithmeticMean[i] = sum / n;

 }
}

void Knn::ComputeStandardDeviation()

{
 standardDeviation = new double[n - 1];

 double sum, temp;

 for(int i = 0; i < n - 1; i++)

 {
  sum = 0;
  for(int j = 0; j < m; j++)
  {
   temp = trainingDataset[j][i] - arithmeticMean[i];
   sum += temp * temp;
  }

  standardDeviation[i] = sqrt(sum / n);

 }
}

void Knn::RescaleDistance(double* row)

{
 for(int i = 0; i < n - 1; i++)
 {
  row[i] = (row[i] - arithmeticMean[i]) / standardDeviation[i];
 }
}

void Knn::RescaleTrainingDataset()

{
 for(int i = 0; i < m; i++)
 {
  RescaleDistance(trainingDataset[i]);
 }
}

Knn::~Knn()

{
 delete[] arithmeticMean;
 delete[] standardDeviation;
}

double Knn::Distance(double* x, double* y)

{
 double sum = 0, temp;
 for(int i = 0; i < n - 1; i++)
 {
  temp = (x[i] - y[i]);
  sum += temp * temp;
 }

 return sqrt(sum);

}

double Knn::Vote(double* test, int k)

{
 RescaleDistance(test);

 double distance;

 map<int, double>::iterator max;

 map<int, double> mins;

 for(int i = 0; i < m; i++)

 {
  distance = Distance(test, trainingDataset[i]);
  if(mins.size() < k)
   mins.insert(map<int, double>::value_type(i, distance));
  else
  {
   max = mins.begin();
   for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)
   {
    if(it->second > max->second)
     max = it;
   }
   if(distance < max->second)
   {
    mins.erase(max);
    mins.insert(map<int, double>::value_type(i, distance));
   }
  }
 }

 map<double, int> votes;

 double temp;

 for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)

 {
  temp = trainingDataset[it->first][n-1];
  map<double, int>::iterator voteIt = votes.find(temp);
  if(voteIt != votes.end())
   voteIt->second ++;
  else
   votes.insert(map<double, int>::value_type(temp, 1));
 }

 map<double, int>::iterator maxVote = votes.begin();

 for(map<double, int>::iterator it = votes.begin(); it != votes.end(); it++)

 {
  if(it->second > maxVote->second)
   maxVote = it;
 }

 test[n-1] = maxVote->first;

 return maxVote->first;

}

 

main.cpp

 

#include <iostream>

#include "Knn.h"

using namespace std;

int main(const int& argc, const char* argv[])

{
 double** train = new double* [14];
 for(int i = 0; i < 14; i ++)
  train[i] = new double[5];
 double trainArray[14][5] = 
 {
  {0, 0, 0, 0, 0},
  {0, 0, 0, 1, 0},
  {1, 0, 0, 0, 1},
  {2, 1, 0, 0, 1},
  {2, 2, 1, 0, 1},
  {2, 2, 1, 1, 0},
  {1, 2, 1, 1, 1},
  {0, 1, 0, 0, 0},
  {0, 2, 1, 0, 1},
  {2, 1, 1, 0, 1},
  {0, 1, 1, 1, 1},
  {1, 1, 0, 1, 1},
  {1, 0, 1, 0, 1},
  {2, 1, 0, 1, 0}
 };

 for(int i = 0; i < 14; i ++)

  for(int j = 0; j < 5; j ++)
   train[i][j] = trainArray[i][j];

 Knn knn(train, 14, 5);

 double test[5] = {2, 2, 0, 1, 0};

 cout<<knn.Vote(test, 3)<<endl;

 for(int i = 0; i < 14; i ++)

  delete[] train[i];

 delete[] train;

 return 0;

}

转载地址:http://bogho.baihongyu.com/

你可能感兴趣的文章
刘剑锋:友云采助力企业数字化采购的新发展
查看>>
Rainbond 5.0.4 发布,做最好用的云应用操作系统
查看>>
亚马逊宣布与西云数据达成合作,旨在进一步扩大中国业务
查看>>
java nio的基础--缓冲区
查看>>
负载均衡沙龙活动第二期现场问答汇集
查看>>
GBDT原理及利用GBDT构造新的特征-Python实现
查看>>
Android帧缓冲区(Frame Buffer)硬件抽象层(HAL)模块Gralloc的实现原理分析(10)...
查看>>
【Xamarin.Forms】在XAML中传递参数
查看>>
关于数据仓库 — 总体工具介绍
查看>>
最大的错误是不敢犯错
查看>>
跟我学交换机配置(七)
查看>>
makefile 中 $@ $^ % 2015-04-11 18:02:36
查看>>
C#强化系列文章三:实验分析C#中三种计时器使用异同点
查看>>
Linux 进程间通信(一)
查看>>
通用对象池ObjectPool的一种简易设计和实现方案
查看>>
HTTP压缩仍让加密连接处于风险之中
查看>>
乐视阿里达成百亿元销售框架
查看>>
戴尔通过提升大数据分析能力巩固“全数据”战略 帮助企业在现代数据经济中蓬勃发展...
查看>>
⑤Windows Server 8 RemoteFX体验
查看>>
《企业云桌面实施》-小技巧-03-vSAN6.5中SAS和SSD的使用建议
查看>>