Home > Database > Mysql Tutorial > body text

利用SVM解决2维空间向量的3级分类问题

WBOY
Release: 2016-06-07 15:43:03
Original
1624 people have browsed it

【原文:http://blog.csdn.net/firefight/article/details/6400060】 为了学习OPENCV SVM分类器, 参考网上的 利用SVM解决2维空间向量的分类问题 实现并改为C代码,仅供参考 环境:OPENCV2.2 VS2008 步骤: 1,生成随机的点,并按一定的空间分布将其归类 2,

【原文:http://blog.csdn.net/firefight/article/details/6400060】

为了学习OPENCV SVM分类器, 参考网上的"利用SVM解决2维空间向量的分类问题"实现并改为C++代码,仅供参考

 

环境:OPENCV2.2 + VS2008

步骤:
1,生成随机的点,并按一定的空间分布将其归类
2,创建SVM并利用随机点样本进行训练
3,将整个空间按SVM分类结果进行划分,并显示支持向量

 

[cpp] view plaincopy

  1. #include "stdafx.h"  
  2. #include   
  3.   
  4. void drawCross(Mat &img, Point center, Scalar color)  
  5. {  
  6.     int col = center.x > 2 ? center.x : 2;  
  7.     int row = center.y> 2 ? center.y : 2;  
  8.   
  9.     line(img, Point(col -2, row - 2), Point(col + 2, row + 2), color);    
  10.     line(img, Point(col + 2, row - 2), Point(col - 2, row + 2), color);    
  11. }  
  12.   
  13. int newSvmTest(int rows, int cols, int testCount)  
  14. {  
  15.     if(testCount > rows * cols)  
  16.         return 0;  
  17.   
  18.     Mat img = Mat::zeros(rows, cols, CV_8UC3);  
  19.     Mat testPoint = Mat::zeros(rows, cols, CV_8UC1);  
  20.     Mat data = Mat::zeros(testCount, 2, CV_32FC1);  
  21.     Mat res = Mat::zeros(testCount, 1, CV_32SC1);  
  22.   
  23.     //Create random test points  
  24.     for (int i= 0; i
  25.     {   
  26.         int row = rand() % rows;  
  27.         int col = rand() % cols;  
  28.   
  29.         if(testPoint.atchar>(row, col) == 0)  
  30.         {  
  31.             testPoint.atchar>(row, col) = 1;  
  32.             data.atfloat>(i, 0) = float (col) / cols;   
  33.             data.atfloat>(i, 1) = float (row) / rows;   
  34.         }  
  35.         else  
  36.         {  
  37.             i--;  
  38.             continue;  
  39.         }  
  40.   
  41.         if (row > ( 50 * cos(col * CV_PI/ 100) + 200) )  
  42.         {   
  43.             drawCross(img, Point(col, row), CV_RGB(255, 0, 0));  
  44.             res.atint>(i, 0) = 1;   
  45.         }   
  46.         else   
  47.         {   
  48.             if (col > 200)   
  49.             {   
  50.                 drawCross(img, Point(col, row), CV_RGB(0, 255, 0));  
  51.                 res.atint>(i, 0) = 2;   
  52.             }   
  53.             else   
  54.             {   
  55.                 drawCross(img, Point(col, row), CV_RGB(0, 0, 255));  
  56.                 res.atint>(i, 0) = 3;   
  57.             }   
  58.         }   
  59.   
  60.     }  
  61.   
  62.     //Show test points  
  63.     imshow("dst", img);  
  64.     waitKey(0);  
  65.   
  66.     /////////////START SVM TRAINNING//////////////////  
  67.     CvSVM svm = CvSVM();   
  68.     CvSVMParams param;   
  69.     CvTermCriteria criteria;  
  70.   
  71.     criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);  
  72. /* SVM种类:CvSVM::C_SVC 
    Kernel的种类:CvSVM::RBF
    degree:10.0(此次不使用) 
    gamma:8.0 
    coef0:1.0(此次不使用)
    C:10.0 
    nu:0.5(此次不使用) 
    p:0.1(此次不使用) 
    然后对训练数据正规化处理,并放在CvMat型的数组里。*/

  73.     param= CvSVMParams (CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);   
  74.     svm.train(data, res, Mat(), Mat(), param);  
  75.   
  76.     for (int i= 0; i
  77.     {   
  78.         for (int j= 0; j
  79.         {   
  80.             Mat m = Mat::zeros(1, 2, CV_32FC1);  
  81.             m.atfloat>(0,0) = float (j) / cols;  
  82.             m.atfloat>(0,1) = float (i) / rows;  
  83.   
  84.             float ret = 0.0;   
  85.             ret = svm.predict(m);   
  86.             Scalar rcolor;   
  87.   
  88.             switch ((int) ret)   
  89.             {   
  90.                 case 1: rcolor= CV_RGB(100, 0, 0); break;   
  91.                 case 2: rcolor= CV_RGB(0, 100, 0); break;   
  92.                 case 3: rcolor= CV_RGB(0, 0, 100); break;   
  93.             }   
  94.   
  95.             line(img, Point(j,i), Point(j,i), rcolor);  
  96.         }   
  97.     }  
  98.   
  99.     imshow("dst", img);  
  100.     waitKey(0);  
  101.   
  102.     //Show support vectors  
  103.     int sv_num= svm.get_support_vector_count();   
  104.     for (int i= 0; i
  105.     {   
  106.         const float* support = svm.get_support_vector(i);   
  107.         circle(img, Point((int) (support[0] * cols), (int) (support[1] * rows)), 5, CV_RGB(200, 200, 200));   
  108.     }  
  109.   
  110.     imshow("dst", img);  
  111.     waitKey(0);  
  112.   
  113.     return 0;  
  114. }  
  115.   
  116. int main(int argc, char** argv)  
  117. {  
  118.     return newSvmTest(400, 600, 100);  
  119. }  

 

学习样本:

利用SVM解决2维空间向量的3级分类问题

 

分类:

利用SVM解决2维空间向量的3级分类问题

 

支持向量:

利用SVM解决2维空间向量的3级分类问题

source:php.cn
Statement of this Website
The content of this article is voluntarily contributed by netizens, and the copyright belongs to the original author. This site does not assume corresponding legal responsibility. If you find any content suspected of plagiarism or infringement, please contact admin@php.cn
Popular Tutorials
More>
Latest Downloads
More>
Web Effects
Website Source Code
Website Materials
Front End Template