首頁 > 科技週邊 > 人工智慧 > 使用Scikit-Learn,快速掌握機器學習預測方法

使用Scikit-Learn,快速掌握機器學習預測方法

王林
發布: 2023-05-27 14:26:03
轉載
1477 人瀏覽過

在本文中,我們將討論預測函數的差異和它們的用途。

在機器學習中,predict和predict_proba、predict_log_proba和decision_function方法都是用來根據訓練好的模型來預測的。

predict方法

使用predict方法可進行二元分類或多元分類預測,輸出預測標籤。例如,如果你已經訓練了一個邏輯迴歸模型來預測一個客戶是否會購買產品,則可以使用predict方法來預測一個新客戶是否會購買產品。

我們將使用來自scikit-learn的乳癌資料集。這個資料集包含了腫瘤觀察結果和腫瘤是惡性還是良性的相應標籤。

import numpy as npfrom sklearn.svm import SVCfrom sklearn.preprocessing import StandardScalerfrom sklearn.pipeline import make_pipelineimport matplotlib.pyplot as pltfrom sklearn.datasets import load_breast_cancer# 加载数据集dataset = load_breast_cancer(as_frame=True)# 创建特征和目标X = dataset['data']y = dataset['target']# 将数据集分割成训练集和测试集from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y , test_size=0.25, random_state=0)# 我们创建一个简单的管道来规范数据并使用`SVC`分类器训练模型svc_clf = make_pipeline(StandardScaler(),SVC(max_iter=1000, probability=True))svc_clf.fit(X_train, y_train)
登入後複製
# 我们正在预测X_test的第一个条目print(svc_clf.predict(X_test[:1]))
登入後複製
# 预测X_test的第一个条目属于哪一类[0]
登入後複製

predict_proba方法

使用predict_proba函數可以對每個類別進行機率預測,並傳回所可能的每個類別標籤的機率估計。在二元或多元分類問題中,通常採用這種方法以確定每種可能結果的機率。例如,如果你已經訓練了一個模型,將動物的圖像分為貓、狗和馬,你可以使用predict_proba方法來獲得每個類別標籤的機率估計。

print(svc_clf.predict_proba(X_test[:1]))
登入後複製
[[0.99848307 0.00151693]]
登入後複製

predict_log_proba方法

predict_log_proba方法與predict_proba類似,但它會傳回機率估計值的對數,而不是原始機率。這對處理極小或極大的機率值是十分實用的,因為可以避免數值下溢或溢出的問題。

print(svc_clf.predict_log_proba(X_test[:1]))
登入後複製
[[-1.51808474e-03 -6.49106473e+00]]
登入後複製

decision_function方法

Linear binary classification models can utilize the decision_function method.。它會針對每個輸入資料點產生一個分數,這個分數可用來推測其對應的類別標籤。可以根據應用或領域知識來設定將資料點分類為正或負的閾值。

print(svc_clf.decision_function(X_test[:1]))
登入後複製
[-1.70756057]
登入後複製

總結

  • 當你想要得​​到輸入資料的預測類別標籤時,對二元或多元分類問題使用predict。
  • 當你想要獲得每個可能的類別標籤的機率估計值時,請使用predict_proba處理二元或多元分類問題。
  • 當你需要處理非常小或非常大的機率值時,或者當你想要避免數字下溢或溢位問題時,請使用predict_log_proba。
  • 當你想要取得每個輸入資料點的分數時,使用decision_function處理線性模型的二元分類問題。

注意:某些分類器的預測方法可能不完整或需要額外參數才能存取函數。例如:SVC需要將機率參數設為True,才能使用機率預測。

以上是使用Scikit-Learn,快速掌握機器學習預測方法的詳細內容。更多資訊請關注PHP中文網其他相關文章!

相關標籤:
來源:51cto.com
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板