首頁 > 科技週邊 > 人工智慧 > 透過學習曲線辨識過擬合和欠擬合

透過學習曲線辨識過擬合和欠擬合

王林
發布: 2024-04-29 18:50:15
轉載
1469 人瀏覽過

本文將介紹如何透過學習曲線來有效辨識機器學習模型中的過度擬合和欠擬合。

透過學習曲線辨識過擬合和欠擬合

欠擬合與過擬合

#1、過擬合

如果一個模型對資料進行了過度訓練,以至於它從中學習了噪聲,那麼這個模型就被稱為過度擬合。過度擬合模型非常完美地學習了每一個例子,所以它會錯誤地分類一個看不見的/新的例子。對於一個過度擬合的模型,我們會得到一個完美/接近完美的訓練集分數和一個糟糕的驗證集/測試分數。

略有修改:「過度擬合的原因:用一個複雜的模型來解決一個簡單的問題,從資料中提取雜訊。因為小資料集作為訓練集可能無法代表所有資料的正確表示。就說它是欠擬合的。欠擬合模型並不能完全學習資料集中的每一個例子。在這種情況下,我們看到訓練集和驗證集的誤差都很低。這可能是因為模型太簡單,沒有足夠的參數來適應數據。我們可以嘗試增加模型的複雜度,增加層數或神經元的數量,來解決欠擬合問題。但要注意的是,增加模型複雜度也會增加過度擬合的風險。

不足合適的原因: 使用一個簡單的模型來解決一個複雜的問題,這個模型不能學習資料中的所有模式,或是模型錯誤的學習了底層資料的模式。 在資料分析和機器學習中,模型的選擇是非常重要的。選擇適合問題的模型可以提高預測的準確性和可靠性。對於複雜的問題,可能需要使用更複雜的模型來捕捉資料中的所有模式。另外,還需要考慮

學習曲線

學習曲線透過增量增加新的訓練樣本來繪製訓練樣本樣本本身的訓練和驗證損失。可以幫助我們確定是否需要添加額外的訓練範例來提高驗證分數(在未見過的數據上得分)。如果模型是過度擬合的,那麼添加額外的訓練範例可能會提高模型在未見過的資料上的表現。同理,如果一個模型是欠擬合的,那麼添加訓練範例也許沒有什麼用。 'learning_curve'方法可以從Scikit-Learn的'model_selection'模組導入。

from sklearn.model_selection import learning_curve
登入後複製

我們將使用邏輯迴歸和Iris資料進行示範。建立一個名為「learn_curve」的函數,它將擬合邏輯迴歸模型,並傳回交叉驗證分數、訓練分數和學習曲線資料。

#The function below builds the model and returns cross validation scores, train score and learning curve data def learn_curve(X,y,c): ''' param X: Matrix of input featuresparam y: Vector of Target/Labelc: Inverse Regularization variable to control overfitting (high value causes overfitting, low value causes underfitting)''' '''We aren't splitting the data into train and test because we will use StratifiedKFoldCV.KFold CV is a preferred method compared to hold out CV, since the model is tested on all the examples.Hold out CV is preferred when the model takes too long to train and we have a huge test set that truly represents the universe'''  le = LabelEncoder() # Label encoding the target sc = StandardScaler() # Scaling the input features y = le.fit_transform(y)#Label Encoding the target log_reg = LogisticRegression(max_iter=200,random_state=11,C=c) # LogisticRegression model # Pipeline with scaling and classification as steps, must use a pipelne since we are using KFoldCV lr = Pipeline(steps=(['scaler',sc],['classifier',log_reg]))   cv = StratifiedKFold(n_splits=5,random_state=11,shuffle=True) # Creating a StratifiedKFold object with 5 folds cv_scores = cross_val_score(lr,X,y,scoring="accuracy",cv=cv) # Storing the CV scores (accuracy) of each fold   lr.fit(X,y) # Fitting the model  train_score = lr.score(X,y) # Scoring the model on train set  #Building the learning curve train_size,train_scores,test_scores =learning_curve(estimator=lr,X=X,y=y,cv=cv,scoring="accuracy",random_state=11) train_scores = 1-np.mean(train_scores,axis=1)#converting the accuracy score to misclassification rate test_scores = 1-np.mean(test_scores,axis=1)#converting the accuracy score to misclassification rate lc =pd.DataFrame({"Training_size":train_size,"Training_loss":train_scores,"Validation_loss":test_scores}).melt(id_vars="Training_size") return {"cv_scores":cv_scores,"train_score":train_score,"learning_curve":lc}
登入後複製
上面程式碼很簡單,就是我們日常的訓練過程,下面我們開始介紹學習曲線的用處

1、擬合模型的學習曲線

我們將使用'learn_curve'函數透過將反正則化變數/參數'c'設為1來獲得一個良好的擬合模型(即我們不執行任何正則化)。

lc = learn_curve(X,y,1) print(f'Cross Validation Accuracies:\n{"-"*25}\n{list(lc["cv_scores"])}\n\n\ Mean Cross Validation Accuracy:\n{"-"*25}\n{np.mean(lc["cv_scores"])}\n\n\ Standard Deviation of Deep HUB Cross Validation Accuracy:\n{"-"*25}\n{np.std(lc["cv_scores"])}\n\n\ Training Accuracy:\n{"-"*15}\n{lc["train_score"]}\n\n') sns.lineplot(data=lc["learning_curve"],x="Training_size",y="value",hue="variable") plt.title("Learning Curve of Good Fit Model") plt.ylabel("Misclassification Rate/Loss");
登入後複製

在上面的結果中,交叉驗證準確率與訓練準確率接近。

透過學習曲線辨識過擬合和欠擬合

訓練的損失(藍色):一個好的擬合模型的學習曲線會隨著訓練範例的增加而逐漸減少並逐漸趨於平坦,說明增加更多的訓練範例並不能提升模型在訓練資料上的表現。

驗證的損失(黃色):一個好的擬合模型的學習曲線在開始時具有較高的驗證損失,隨著訓練範例的增加逐漸減少並逐漸趨於平坦,說明樣本越多,就能夠學習到更多的模式,這些模式對於”看不到“的數據會有幫助透過學習曲線辨識過擬合和欠擬合

最後還可以看到,在增加合理數量的訓練範例後,訓練損失和驗證損失彼此接近。

2、過度擬合模型的學習曲線

#我們將使用' learn_curve '函數透過將反正則化變數/參數' c '設定為10000來獲得過擬合模型(' c '的高值導致過擬合)。

lc = learn_curve(X,y,10000) print(f'Cross Validation Accuracies:\n{"-"*25}\n{list(lc["cv_scores"])}\n\n\ Mean Cross Validation Deep HUB Accuracy:\n{"-"*25}\n{np.mean(lc["cv_scores"])}\n\n\ Standard Deviation of Cross Validation Accuracy:\n{"-"*25}\n{np.std(lc["cv_scores"])} (High Variance)\n\n\ Training Accuracy:\n{"-"*15}\n{lc["train_score"]}\n\n') sns.lineplot(data=lc["learning_curve"],x="Training_size",y="value",hue="variable") plt.title("Learning Curve of an Overfit Model") plt.ylabel("Misclassification Rate/Loss");
登入後複製
#

透過學習曲線辨識過擬合和欠擬合

与拟合模型相比,交叉验证精度的标准差较高。

透過學習曲線辨識過擬合和欠擬合

过拟合模型的学习曲线一开始的训练损失很低,随着训练样例的增加,学习曲线逐渐增加,但不会变平。过拟合模型的学习曲线在开始时具有较高的验证损失,随着训练样例的增加逐渐减小并且不趋于平坦,说明增加更多的训练样例可以提高模型在未知数据上的性能。同时还可以看到,训练损失和验证损失彼此相差很远,在增加额外的训练数据时,它们可能会彼此接近。

3、欠拟合模型的学习曲线

将反正则化变量/参数' c '设置为1/10000来获得欠拟合模型(' c '的低值导致欠拟合)。

lc = learn_curve(X,y,1/10000) print(f'Cross Validation Accuracies:\n{"-"*25}\n{list(lc["cv_scores"])}\n\n\ Mean Cross Validation Accuracy:\n{"-"*25}\n{np.mean(lc["cv_scores"])}\n\n\ Standard Deviation of Cross Validation Accuracy:\n{"-"*25}\n{np.std(lc["cv_scores"])} (Low variance)\n\n\ Training Deep HUB Accuracy:\n{"-"*15}\n{lc["train_score"]}\n\n') sns.lineplot(data=lc["learning_curve"],x="Training_size",y="value",hue="variable") plt.title("Learning Curve of an Underfit Model") plt.ylabel("Misclassification Rate/Loss");
登入後複製

透過學習曲線辨識過擬合和欠擬合

与过拟合和良好拟合模型相比,交叉验证精度的标准差较低。

透過學習曲線辨識過擬合和欠擬合

欠拟合模型的学习曲线在开始时具有较低的训练损失,随着训练样例的增加逐渐增加,并在最后突然下降到任意最小点(最小并不意味着零损失)。这种最后的突然下跌可能并不总是会发生。这表明增加更多的训练样例并不能提高模型在未知数据上的性能。

总结

在机器学习和统计建模中,过拟合(Overfitting)和欠拟合(Underfitting)是两种常见的问题,它们描述了模型与训练数据的拟合程度如何影响模型在新数据上的表现。

分析生成的学习曲线时,可以关注以下几个方面:

  • 欠拟合:如果学习曲线显示训练集和验证集的性能都比较低,或者两者都随着训练样本数量的增加而缓慢提升,这通常表明模型欠拟合。这种情况下,模型可能太简单,无法捕捉数据中的基本模式。
  • 过拟合:如果训练集的性能随着样本数量的增加而提高,而验证集的性能在一定点后开始下降或停滞不前,这通常表示模型过拟合。在这种情况下,模型可能太复杂,过度适应了训练数据中的噪声而非潜在的数据模式。

根据学习曲线的分析,你可以采取以下策略进行调整:

  • 对于欠拟合
  • 增加模型复杂度,例如使用更多的特征、更深的网络或更多的参数。
  • 改善特征工程,尝试不同的特征组合或转换。
  • 增加迭代次数或调整学习率。
  • 对于过拟合
  • 使用正则化技术(如L1、L2正则化)。

  • 减少模型的复杂性,比如减少参数数量、层数或特征数量。

  • 增加更多的训练数据。

  • 应用数据增强技术。

  • 使用早停(early stopping)等技术来避免过度训练。

通过这样的分析和调整,学习曲线能够帮助你更有效地优化模型,并提高其在未知数据上的泛化能力。

以上是透過學習曲線辨識過擬合和欠擬合的詳細內容。更多資訊請關注PHP中文網其他相關文章!

來源:51cto.com
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
最新問題
PHP 必須會什麼才能找工作
來自於 1970-01-01 08:00:00
0
0
0
希望能在mac出工具!
來自於 1970-01-01 08:00:00
0
0
0
php工具箱能換mysql版本嗎?
來自於 1970-01-01 08:00:00
0
0
0
php程式設計師工具箱不能下載
來自於 1970-01-01 08:00:00
0
0
0
PHP工具箱和快表不能同時打開
來自於 1970-01-01 08:00:00
0
0
0
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板