為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎
對於二元分類,分類器輸出一個實值分數,然後透過對該值進行閾值的區分產生二元的對應。例如,邏輯迴歸輸出一個機率(一個介於0.0和1.0之間的值);得分等於或高於0.5的觀察結果產生正輸出(許多其他模型預設使用0.5閾值)。
但是使用預設的0.5閾值是不理想的。在本文中,我將展示如何從二元分類器中選擇最佳閾值。本文將使用Ploomber並行執行我們的實驗,並使用sklearn-evaluation產生圖。
這裡以訓練邏輯迴歸為例。假設我們正在開發一個內容審核系統,模型標記包含有害內容的貼文(圖片、影片等);然後,人工會查看並決定內容是否被刪除。
建立簡單的二元分類器
下面的程式碼片段訓練我們的分類器:
import matplotlib.pyplot as plt import matplotlib as mpl from sklearn import datasets from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn_evaluation.plot import ConfusionMatrix # matplotlib settings mpl.rcParams['figure.figsize'] = (4, 4) mpl.rcParams['figure.dpi'] = 150 # create sample dataset X, y = datasets.make_classification(1000, 10, n_informative=5, class_sep=0.4) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) # fit model clf = LogisticRegression() _ = clf.fit(X_train, y_train)
現在讓我們對測試集進行預測,並透過混淆矩陣評估性能:
# predict on the test set y_pred = clf.predict(X_test) # plot confusion matrix cm_dot_five = ConfusionMatrix(y_test, y_pred) cm_dot_five
混淆矩陣總結了模型在四個區域的性能:
y_score = clf.predict_proba(X_test)
cm_dot_four = ConfusionMatrix(y_score[:, 1] >= 0.4, y_pred)
cm_dot_five + cm_dot_four
- 兩個模型對相同數量的觀測結果都預測為0(這是一個巧合)。 0.5閾值:(90 56 = 146)。 0.4閾值:(78 68 = 146)降低閾值會導致更多的假陰性(從56例降至68例)降低閾值將大大增加真陽性(從92例增加154例)
curl -O https://raw.githubusercontent.com/ploomber/posts/master/threshold/fit.ipynb?utm_source=medium&utm_medium=blog&utm_campaign=threshold
ploomber cloud nb fit.ipynb
ploomber cloud status @latest --summary status count -------- ------- finished 20 Pipeline finished. Check outputs: $ ploomber cloud products
ploomber cloud download 'threshold-selection/*.csv' --summary
from glob import glob import pandas as pd import numpy as np paths = glob('threshold-selection/**/*.csv') metrics = [pd.read_csv(path) for path in paths] for idx, df in enumerate(metrics): plt.plot(df.threshold, df.precision, color='blue', alpha=0.2, label='precision' if idx == 0 else None) plt.plot(df.threshold, df.recall, color='green', alpha=0.2, label='recall' if idx == 0 else None) plt.plot(df.threshold, df.f1, color='orange', alpha=0.2, label='f1' if idx == 0 else None) plt.grid() plt.legend() plt.xlabel('Threshold') plt.ylabel('Metric value') for handle in plt.legend().legendHandles: handle.set_alpha(1) ax = plt.twinx() for idx, df in enumerate(metrics): ax.plot(df.threshold, df.n_flagged, label='flagged' if idx == 0 else None, color='red', alpha=0.2) plt.ylabel('Flagged') ax.legend(loc=0) ax.legend().legendHandles[0].set_alpha(1)
#
左边的刻度(从0到1)是我们的三个指标:精度、召回率和F1。F1分为精度与查全率的调和平均值,F1分的最佳值为1.0,最差值为0.0;F1对精度和召回率都是相同对待的,所以你可以看到它在两者之间保持平衡。如果你正在处理一个精确度和召回率都很重要的用例,那么最大化F1是一种可以帮助你优化分类器阈值的方法。
这里还包括一条红色曲线(右侧的比例),显示我们的模型标记为有害内容的案例数量。
在这个的内容审核示例中,可能有X个的工作人员来人工审核模型标记的有害帖子,但是他们人数是有限的,因此考虑标记帖子的总数可以帮助我们更好地选择阈值:例如每天只能检查5000个帖子,那么模型找到10,000帖并不会带来任何的提高。如果我人工每天可以处理10000贴,但是模型只标记了100贴,那么显然也是浪费的。
当设置较低的阈值时,有较高的召回率(我们检索了大部分实际上有害的帖子),但精度较低(包含了许多无害的帖子)。如果我们提高阈值,情况就会反转:召回率下降(错过了许多有害的帖子),但精确度很高(大多数标记的帖子都是有害的)。
所以在为我们的二元分类器选择阈值时,我们必须在精度或召回率上妥协,因为没有一个分类器是完美的。我们来讨论一下如何推理选择合适的阈值。
选择最佳阈值
右边的数据会产生噪声(较大的阈值)。需要稍微清理一下,我们将重新创建这个图,我们将绘制2.5%、50%和97.5%的百分位数,而不是绘制所有值。
shape = (df.shape[0], len(metrics)) precision = np.zeros(shape) recall = np.zeros(shape) f1 = np.zeros(shape) n_flagged = np.zeros(shape) for i, df in enumerate(metrics): precision[:, i] = df.precision.values recall[:, i] = df.recall.values f1[:, i] = df.f1.values n_flagged[:, i] = df.n_flagged.values precision_ = np.quantile(precision, q=0.5, axis=1) recall_ = np.quantile(recall, q=0.5, axis=1) f1_ = np.quantile(f1, q=0.5, axis=1) n_flagged_ = np.quantile(n_flagged, q=0.5, axis=1) plt.plot(df.threshold, precision_, color='blue', label='precision') plt.plot(df.threshold, recall_, color='green', label='recall') plt.plot(df.threshold, f1_, color='orange', label='f1') plt.fill_between(df.threshold, precision_interval[0], precision_interval[1], color='blue', alpha=0.2) plt.fill_between(df.threshold, recall_interval[0], recall_interval[1], color='green', alpha=0.2) plt.fill_between(df.threshold, f1_interval[0], f1_interval[1], color='orange', alpha=0.2) plt.xlabel('Threshold') plt.ylabel('Metric value') plt.legend() ax = plt.twinx() ax.plot(df.threshold, n_flagged_, color='red', label='flagged') ax.fill_between(df.threshold, n_flagged_interval[0], n_flagged_interval[1], color='red', alpha=0.2) ax.legend(loc=3) plt.ylabel('Flagged') plt.grid()
我们可以根据自己的需求选择阈值,例如检索尽可能多的有害帖子(高召回率)是否更重要?还是要有更高的确定性,我们标记的必须是有害的(高精度)?
如果两者都同等重要,那么在这些条件下优化的常用方法就是最大化F-1分数:
idx = np.argmax(f1_) prec_lower, prec_upper = precision_interval[0][idx], precision_interval[1][idx] rec_lower, rec_upper = recall_interval[0][idx], recall_interval[1][idx] threshold = df.threshold[idx] print(f'Max F1 score: {f1_[idx]:.2f}') print('Metrics when maximizing F1 score:') print(f' - Threshold: {threshold:.2f}') print(f' - Precision range: ({prec_lower:.2f}, {prec_upper:.2f})') print(f' - Recall range: ({rec_lower:.2f}, {rec_upper:.2f})') #结果 Max F1 score: 0.71 Metrics when maximizing F1 score: - Threshold: 0.26 - Precision range: (0.58, 0.61) - Recall range: (0.86, 0.90)
在很多情况下很难决定这个折中,所以加入一些约束条件会有一些帮助。
假设我们有10个人审查有害的帖子,他们可以一起检查5000个。那么让我们看看指标,如果我们修改了阈值,让它标记了大约5000个帖子:
idx = np.argmax(n_flagged_ <= 5000) prec_lower, prec_upper = precision_interval[0][idx], precision_interval[1][idx] rec_lower, rec_upper = recall_interval[0][idx], recall_interval[1][idx] threshold = df.threshold[idx] print('Metrics when limiting to a maximum of 5,000 flagged events:') print(f' - Threshold: {threshold:.2f}') print(f' - Precision range: ({prec_lower:.2f}, {prec_upper:.2f})') print(f' - Recall range: ({rec_lower:.2f}, {rec_upper:.2f})') # 结果 Metrics when limiting to a maximum of 5,000 flagged events: - Threshold: 0.82 - Precision range: (0.77, 0.81) - Recall range: (0.25, 0.36)
如果需要进行汇报,我们可以在在展示结果时展示一些替代方案:比如在当前约束条件下(5000个帖子)的模型性能,以及如果我们增加团队(比如通过增加一倍的规模),我们可以做得更好。
总结
二元分类器的最佳阈值是针对业务结果进行优化并考虑到流程限制的阈值。通过本文中描述的过程,你可以更好地为用例决定最佳阈值。
另外,Ploomber Cloud!提供一些免费的算力!如果你需要一些免费的服务可以试试它。
以上是為機器學習模型設定最佳閾值:0.5是二元分類的最佳閾值嗎的詳細內容。更多資訊請關注PHP中文網其他相關文章!

熱AI工具

Undresser.AI Undress
人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover
用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool
免費脫衣圖片

Clothoff.io
AI脫衣器

Video Face Swap
使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱門文章

熱工具

記事本++7.3.1
好用且免費的程式碼編輯器

SublimeText3漢化版
中文版,非常好用

禪工作室 13.0.1
強大的PHP整合開發環境

Dreamweaver CS6
視覺化網頁開發工具

SublimeText3 Mac版
神級程式碼編輯軟體(SublimeText3)

圖像標註是將標籤或描述性資訊與圖像相關聯的過程,以賦予圖像內容更深層的含義和解釋。這個過程對於機器學習至關重要,它有助於訓練視覺模型以更準確地識別圖像中的各個元素。透過為圖像添加標註,使得電腦能夠理解圖像背後的語義和上下文,從而提高對圖像內容的理解和分析能力。影像標註的應用範圍廣泛,涵蓋了許多領域,如電腦視覺、自然語言處理和圖視覺模型具有廣泛的應用領域,例如,輔助車輛識別道路上的障礙物,幫助疾病的檢測和診斷透過醫學影像識別。本文主要推薦一些較好的開源免費的圖片標註工具。 1.Makesens

在機器學習和資料科學領域,模型的可解釋性一直是研究者和實踐者關注的焦點。隨著深度學習和整合方法等複雜模型的廣泛應用,理解模型的決策過程變得尤為重要。可解釋人工智慧(ExplainableAI|XAI)透過提高模型的透明度,幫助建立對機器學習模型的信任和信心。提高模型的透明度可以透過多種複雜模型的廣泛應用等方法來實現,以及用於解釋模型的決策過程。這些方法包括特徵重要性分析、模型預測區間估計、局部可解釋性演算法等。特徵重要性分析可以透過評估模型對輸入特徵的影響程度來解釋模型的決策過程。模型預測區間估計

本文將介紹如何透過學習曲線來有效辨識機器學習模型中的過度擬合和欠擬合。欠擬合和過擬合1、過擬合如果一個模型對資料進行了過度訓練,以至於它從中學習了噪聲,那麼這個模型就被稱為過擬合。過度擬合模型非常完美地學習了每一個例子,所以它會錯誤地分類一個看不見的/新的例子。對於一個過度擬合的模型,我們會得到一個完美/接近完美的訓練集分數和一個糟糕的驗證集/測試分數。略有修改:"過擬合的原因:用一個複雜的模型來解決一個簡單的問題,從資料中提取雜訊。因為小資料集作為訓練集可能無法代表所有資料的正確表示。"2、欠擬合如

1950年代,人工智慧(AI)誕生。當時研究人員發現機器可以執行類似人類的任務,例如思考。後來,在1960年代,美國國防部資助了人工智慧,並建立了實驗室進行進一步開發。研究人員發現人工智慧在許多領域都有用武之地,例如太空探索和極端環境中的生存。太空探索是對宇宙的研究,宇宙涵蓋了地球以外的整個宇宙空間。太空被歸類為極端環境,因為它的條件與地球不同。要在太空中生存,必須考慮許多因素,並採取預防措施。科學家和研究人員認為,探索太空並了解一切事物的現狀有助於理解宇宙的運作方式,並為潛在的環境危機

通俗來說,機器學習模型是一種數學函數,它能夠將輸入資料映射到預測輸出。更具體地說,機器學習模型是一種透過學習訓練數據,來調整模型參數,以最小化預測輸出與真實標籤之間的誤差的數學函數。在機器學習中存在多種模型,例如邏輯迴歸模型、決策樹模型、支援向量機模型等,每種模型都有其適用的資料類型和問題類型。同時,不同模型之間存在著許多共通性,或者說有一條隱藏的模型演化的路徑。將聯結主義的感知機為例,透過增加感知機的隱藏層數量,我們可以將其轉化為深度神經網路。而對感知機加入核函數的話就可以轉換為SVM。這一

C++中機器學習演算法面臨的常見挑戰包括記憶體管理、多執行緒、效能最佳化和可維護性。解決方案包括使用智慧指標、現代線程庫、SIMD指令和第三方庫,並遵循程式碼風格指南和使用自動化工具。實作案例展示如何利用Eigen函式庫實現線性迴歸演算法,有效地管理記憶體和使用高效能矩陣操作。

機器學習是人工智慧的重要分支,它賦予電腦從數據中學習的能力,並能夠在無需明確編程的情況下改進自身能力。機器學習在各個領域都有廣泛的應用,從影像辨識和自然語言處理到推薦系統和詐欺偵測,它正在改變我們的生活方式。機器學習領域存在著多種不同的方法和理論,其中最具影響力的五種方法被稱為「機器學習五大派」。這五大派分別為符號派、聯結派、進化派、貝葉斯派和類推學派。 1.符號學派符號學(Symbolism),又稱符號主義,強調利用符號進行邏輯推理和表達知識。該學派認為學習是一種逆向演繹的過程,透過現有的

MetaFAIR聯合哈佛優化大規模機器學習時所產生的資料偏差,提供了新的研究架構。據所周知,大語言模型的訓練常常需要數月的時間,使用數百甚至上千個GPU。以LLaMA270B模型為例,其訓練總共需要1,720,320個GPU小時。由於這些工作負載的規模和複雜性,導致訓練大模型存在著獨特的系統性挑戰。最近,許多機構在訓練SOTA生成式AI模型時報告了訓練過程中的不穩定情況,它們通常以損失尖峰的形式出現,例如Google的PaLM模型訓練過程中出現了多達20次的損失尖峰。數值偏差是造成這種訓練不準確性的根因,
