首页 > 后端开发 > Python教程 > Python中的支持向量机算法实例

Python中的支持向量机算法实例

王林
发布: 2023-06-10 16:42:14
原创
1437 人浏览过

支持向量机,英文全称为Support Vecto Machines,简称SVM。它是一种非常优秀的分类模型,特别在小样本、非线性以及高维模式识别中有很好的表现。SVM是由Vapnik团队在1992年提出,最初被用来解决二分类问题,后来逐渐发展成为可以处理多分类问题的算法。

Python是一种简洁而强大的编程语言,它实现了众多机器学习算法的包,其中包括SVM。本文将介绍通过Python实现支持向量机算法的步骤。

一、准备数据

我们来构造一组简单的训练数据。创建一个示例数据集,其中x1表示身高,x2表示体重,y为类别标签(0或1)。

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(7)
X_train = np.array([[167, 75], [182, 80], [176, 85], [156, 50], [173, 70], [183, 90], [178, 75], [156, 45],
                    [162, 55], [163, 50], [159, 45], [180, 85]])
y_train = np.array([0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1])
plt.scatter(X_train[y_train == 0][:, 0], X_train[y_train == 0][:, 1], c='r', s=40, label='Male')
plt.scatter(X_train[y_train == 1][:, 0], X_train[y_train == 1][:, 1], c='b', s=40, label='Female')
plt.legend()
plt.xlabel('Height')
plt.ylabel('Weight')
plt.show()
登录后复制

在这个数据集中,我们将人群分类为男性或女性。

二、选择分类器

接下来,我们要选择适用于这个问题的分类器,即SVM。SVM有许多变种,但是在这里,我们使用的是线性SVM。

我们来构造一个SVM模型:

from sklearn.svm import SVC

svm = SVC(kernel='linear')
svm.fit(X_train, y_train)
登录后复制

在这里,我们使用的是SVC类,指定kernel参数为linear,表明我们使用线性核。

三、绘制决策边界

我们想要知道模型的性能如何,因此我们可以在绘制出分类器的决策边界:

def plot_decision_boundary(model, ax=None):
    if ax is None:
        ax = plt.gca()
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                         np.linspace(y_min, y_max, 100))
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    ax.contourf(xx, yy, Z, alpha=0.2)
    ax.contour(xx, yy, Z, colors='black', linewidths=0.5)
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_min, y_max])
    
plt.scatter(X_train[y_train == 0][:, 0], X_train[y_train == 0][:, 1], c='r', s=40, label='Male')
plt.scatter(X_train[y_train == 1][:, 0], X_train[y_train == 1][:, 1], c='b', s=40, label='Female')
plot_decision_boundary(svm)
plt.legend()
plt.xlabel('Height')
plt.ylabel('Weight')
plt.show()
登录后复制

运行结束后,可以看到绘制出了分类器的决策边界。

四、预测新数据

我们可以用训练好的模型对新的数据进行预测。

X_test = np.array([[166, 70], [185, 90], [170, 75]])
y_test = svm.predict(X_test)
print(y_test)
登录后复制

在这里,我们使用predict函数对三个新数据样本进行预测。它将返回它们的类别。

结论

在这篇文章中,我们介绍了如何使用Python中的支持向量机算法。我们通过创建一个简单的训练数据集,并使用线性SVM构建了一个分类器。我们还绘制了分类器的决策边界,并使用模型来预测了新的数据样本。SVM在许多场合也是非常受欢迎的算法,可以在很多领域获得好的表现。如果你想在处理数据时,掌握更多机器学习的算法,那么SVM也是非常值得学习的。

以上是Python中的支持向量机算法实例的详细内容。更多信息请关注PHP中文网其他相关文章!

来源:php.cn
本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板