Rumah > Peranti teknologi > AI > teks badan

Kenal pasti overfitting dan underfitting melalui lengkung pembelajaran

王林
Lepaskan: 2024-04-29 18:50:15
ke hadapan
1297 orang telah melayarinya

Artikel ini akan memperkenalkan cara mengenal pasti pemasangan lampau dan kekurangan dalam model pembelajaran mesin secara berkesan melalui lengkung pembelajaran.

Kenal pasti overfitting dan underfitting melalui lengkung pembelajaran

Underfitting dan overfitting

1 Overfitting

Jika model dilatih secara berlebihan pada data sehingga ia dipanggil model overfitting daripadanya. Model yang dipasang terlebih dahulu mempelajari setiap contoh dengan sempurna sehingga ia akan salah mengklasifikasikan contoh yang tidak kelihatan/baharu. Untuk model terlampau, kami akan mendapat skor set latihan yang sempurna/hampir sempurna dan set pengesahan/skor ujian yang teruk.

Sedikit diubah suai: "Punca pemasangan berlebihan: Menggunakan model kompleks untuk menyelesaikan masalah mudah mengeluarkan bunyi daripada data. Kerana set data kecil sebagai set latihan mungkin tidak mewakili perwakilan yang betul bagi semua data

2. Underfitting

Jika model tidak dapat mempelajari corak dalam data dengan betul, kami katakan ia kurang sesuai. Model kurang sesuai tidak dapat mempelajari sepenuhnya setiap contoh dalam set data. Dalam kes ini, kita melihat bahawa ralat pada kedua-dua set latihan dan pengesahan adalah rendah. Ini mungkin kerana model terlalu mudah dan tidak mempunyai parameter yang mencukupi untuk memuatkan data. Kita boleh cuba untuk meningkatkan kerumitan model, meningkatkan bilangan lapisan atau neuron, untuk menyelesaikan masalah yang tidak sesuai. Walau bagaimanapun, perlu diingatkan bahawa peningkatan kerumitan model juga meningkatkan risiko overfitting.

Sebab ia tidak sesuai: Gunakan model mudah untuk menyelesaikan masalah yang kompleks Model tidak dapat mempelajari semua corak dalam data, atau model salah mempelajari corak data asas. Dalam analisis data dan pembelajaran mesin, pemilihan model adalah sangat penting. Memilih model yang sesuai untuk masalah anda boleh meningkatkan ketepatan dan kebolehpercayaan ramalan anda. Untuk masalah yang kompleks, model yang lebih kompleks mungkin diperlukan untuk menangkap semua corak dalam data. Selain itu, anda juga perlu mempertimbangkan

Keluk Pembelajaran

Keluk pembelajaran memplot kehilangan latihan dan pengesahan sampel latihan itu sendiri dengan menambah sampel latihan baharu secara berperingkat. Boleh membantu kami menentukan sama ada kami perlu menambah contoh latihan tambahan untuk meningkatkan skor pengesahan (skor pada data yang tidak kelihatan). Jika model terlampau dipasang, maka menambah contoh latihan tambahan boleh meningkatkan prestasi model pada data yang tidak kelihatan. Begitu juga, jika model kurang sesuai, maka menambah contoh latihan mungkin tidak berguna. Kaedah 'learning_curve' boleh diimport daripada modul 'model_selection' Scikit-Learn.

from sklearn.model_selection import learning_curve
Salin selepas log masuk

Kami akan menunjukkan menggunakan regresi logistik dan data Iris. Buat fungsi yang dipanggil "learn_curve" yang sesuai dengan model regresi logistik dan mengembalikan skor pengesahan silang, skor latihan dan data lengkung pembelajaran.

#The function below builds the model and returns cross validation scores, train score and learning curve data def learn_curve(X,y,c): ''' param X: Matrix of input featuresparam y: Vector of Target/Labelc: Inverse Regularization variable to control overfitting (high value causes overfitting, low value causes underfitting)''' '''We aren't splitting the data into train and test because we will use StratifiedKFoldCV.KFold CV is a preferred method compared to hold out CV, since the model is tested on all the examples.Hold out CV is preferred when the model takes too long to train and we have a huge test set that truly represents the universe'''  le = LabelEncoder() # Label encoding the target sc = StandardScaler() # Scaling the input features y = le.fit_transform(y)#Label Encoding the target log_reg = LogisticRegression(max_iter=200,random_state=11,C=c) # LogisticRegression model # Pipeline with scaling and classification as steps, must use a pipelne since we are using KFoldCV lr = Pipeline(steps=(['scaler',sc],['classifier',log_reg]))   cv = StratifiedKFold(n_splits=5,random_state=11,shuffle=True) # Creating a StratifiedKFold object with 5 folds cv_scores = cross_val_score(lr,X,y,scoring="accuracy",cv=cv) # Storing the CV scores (accuracy) of each fold   lr.fit(X,y) # Fitting the model  train_score = lr.score(X,y) # Scoring the model on train set  #Building the learning curve train_size,train_scores,test_scores =learning_curve(estimator=lr,X=X,y=y,cv=cv,scoring="accuracy",random_state=11) train_scores = 1-np.mean(train_scores,axis=1)#converting the accuracy score to misclassification rate test_scores = 1-np.mean(test_scores,axis=1)#converting the accuracy score to misclassification rate lc =pd.DataFrame({"Training_size":train_size,"Training_loss":train_scores,"Validation_loss":test_scores}).melt(id_vars="Training_size") return {"cv_scores":cv_scores,"train_score":train_score,"learning_curve":lc}
Salin selepas log masuk

Kod di atas adalah sangat mudah, ia adalah proses latihan harian kami Sekarang kami mula memperkenalkan penggunaan keluk pembelajaran

1 gunakan fungsi 'learn_curve' Model pemasangan yang baik diperoleh dengan menetapkan pembolehubah anti-penyaturan/parameter 'c' kepada 1 (iaitu kami tidak melakukan sebarang regularisasi).

lc = learn_curve(X,y,1) print(f'Cross Validation Accuracies:\n{"-"*25}\n{list(lc["cv_scores"])}\n\n\ Mean Cross Validation Accuracy:\n{"-"*25}\n{np.mean(lc["cv_scores"])}\n\n\ Standard Deviation of Deep HUB Cross Validation Accuracy:\n{"-"*25}\n{np.std(lc["cv_scores"])}\n\n\ Training Accuracy:\n{"-"*15}\n{lc["train_score"]}\n\n') sns.lineplot(data=lc["learning_curve"],x="Training_size",y="value",hue="variable") plt.title("Learning Curve of Good Fit Model") plt.ylabel("Misclassification Rate/Loss");
Salin selepas log masuk

Dalam keputusan di atas, ketepatan pengesahan silang adalah hampir dengan ketepatan latihan. Kenal pasti overfitting dan underfitting melalui lengkung pembelajaran

Kehilangan latihan (biru): Keluk pembelajaran model yang dipasang yang baik akan beransur-ansur berkurangan dan mendatar apabila bilangan contoh latihan bertambah, menunjukkan bahawa menambah lebih banyak contoh latihan akan Ia tidak dapat meningkatkan prestasi model pada data latihan. Kenal pasti overfitting dan underfitting melalui lengkung pembelajaran

Kehilangan pengesahan (kuning): Keluk pembelajaran model yang dipasang dengan baik mempunyai kehilangan pengesahan yang tinggi pada permulaannya, yang secara beransur-ansur berkurangan dan mendatar apabila bilangan sampel latihan meningkat, menunjukkan bahawa semakin banyak sampel yang ada, anda boleh mempelajari lebih banyak corak, yang akan membantu untuk data "ghaib"

Akhir sekali, anda juga boleh melihat bahawa selepas menambah bilangan contoh latihan yang munasabah, kehilangan latihan dan kehilangan pengesahan adalah hampir antara satu sama lain . .

lc = learn_curve(X,y,10000) print(f'Cross Validation Accuracies:\n{"-"*25}\n{list(lc["cv_scores"])}\n\n\ Mean Cross Validation Deep HUB Accuracy:\n{"-"*25}\n{np.mean(lc["cv_scores"])}\n\n\ Standard Deviation of Cross Validation Accuracy:\n{"-"*25}\n{np.std(lc["cv_scores"])} (High Variance)\n\n\ Training Accuracy:\n{"-"*15}\n{lc["train_score"]}\n\n') sns.lineplot(data=lc["learning_curve"],x="Training_size",y="value",hue="variable") plt.title("Learning Curve of an Overfit Model") plt.ylabel("Misclassification Rate/Loss");
Salin selepas log masuk

Kenal pasti overfitting dan underfitting melalui lengkung pembelajaran

与拟合模型相比,交叉验证精度的标准差较高。

Kenal pasti overfitting dan underfitting melalui lengkung pembelajaran

过拟合模型的学习曲线一开始的训练损失很低,随着训练样例的增加,学习曲线逐渐增加,但不会变平。过拟合模型的学习曲线在开始时具有较高的验证损失,随着训练样例的增加逐渐减小并且不趋于平坦,说明增加更多的训练样例可以提高模型在未知数据上的性能。同时还可以看到,训练损失和验证损失彼此相差很远,在增加额外的训练数据时,它们可能会彼此接近。

3、欠拟合模型的学习曲线

将反正则化变量/参数' c '设置为1/10000来获得欠拟合模型(' c '的低值导致欠拟合)。

lc = learn_curve(X,y,1/10000) print(f'Cross Validation Accuracies:\n{"-"*25}\n{list(lc["cv_scores"])}\n\n\ Mean Cross Validation Accuracy:\n{"-"*25}\n{np.mean(lc["cv_scores"])}\n\n\ Standard Deviation of Cross Validation Accuracy:\n{"-"*25}\n{np.std(lc["cv_scores"])} (Low variance)\n\n\ Training Deep HUB Accuracy:\n{"-"*15}\n{lc["train_score"]}\n\n') sns.lineplot(data=lc["learning_curve"],x="Training_size",y="value",hue="variable") plt.title("Learning Curve of an Underfit Model") plt.ylabel("Misclassification Rate/Loss");
Salin selepas log masuk

Kenal pasti overfitting dan underfitting melalui lengkung pembelajaran

与过拟合和良好拟合模型相比,交叉验证精度的标准差较低。

Kenal pasti overfitting dan underfitting melalui lengkung pembelajaran

欠拟合模型的学习曲线在开始时具有较低的训练损失,随着训练样例的增加逐渐增加,并在最后突然下降到任意最小点(最小并不意味着零损失)。这种最后的突然下跌可能并不总是会发生。这表明增加更多的训练样例并不能提高模型在未知数据上的性能。

总结

在机器学习和统计建模中,过拟合(Overfitting)和欠拟合(Underfitting)是两种常见的问题,它们描述了模型与训练数据的拟合程度如何影响模型在新数据上的表现。

分析生成的学习曲线时,可以关注以下几个方面:

  • 欠拟合:如果学习曲线显示训练集和验证集的性能都比较低,或者两者都随着训练样本数量的增加而缓慢提升,这通常表明模型欠拟合。这种情况下,模型可能太简单,无法捕捉数据中的基本模式。
  • 过拟合:如果训练集的性能随着样本数量的增加而提高,而验证集的性能在一定点后开始下降或停滞不前,这通常表示模型过拟合。在这种情况下,模型可能太复杂,过度适应了训练数据中的噪声而非潜在的数据模式。

根据学习曲线的分析,你可以采取以下策略进行调整:

  • 对于欠拟合
  • 增加模型复杂度,例如使用更多的特征、更深的网络或更多的参数。
  • 改善特征工程,尝试不同的特征组合或转换。
  • 增加迭代次数或调整学习率。
  • 对于过拟合
  • 使用正则化技术(如L1、L2正则化)。

  • 减少模型的复杂性,比如减少参数数量、层数或特征数量。

  • 增加更多的训练数据。

  • 应用数据增强技术。

  • 使用早停(early stopping)等技术来避免过度训练。

通过这样的分析和调整,学习曲线能够帮助你更有效地优化模型,并提高其在未知数据上的泛化能力。

Atas ialah kandungan terperinci Kenal pasti overfitting dan underfitting melalui lengkung pembelajaran. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

sumber:51cto.com
Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan
Tentang kita Penafian Sitemap
Laman web PHP Cina:Latihan PHP dalam talian kebajikan awam,Bantu pelajar PHP berkembang dengan cepat!