详解PyTorch批训练及优化器比较
本篇文章主要介绍了详解PyTorch批训练及优化器比较,详细的介绍了什么是PyTorch批训练和PyTorch的Optimizer优化器,非常具有实用价值,需要的朋友可以参考下
一、PyTorch批训练
1. 概述
PyTorch提供了一种将数据包装起来进行批训练的工具——DataLoader。使用的时候,只需要将我们的数据首先转换为torch的tensor形式,再转换成torch可以识别的Dataset格式,然后将Dataset放入DataLoader中就可以啦。
import torch import torch.utils.data as Data torch.manual_seed(1) # 设定随机数种子 BATCH_SIZE = 5 x = torch.linspace(1, 10, 10) y = torch.linspace(0.5, 5, 10) # 将数据转换为torch的dataset格式 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) # 将torch_dataset置入Dataloader中 loader = Data.DataLoader( dataset=torch_dataset, batch_size=BATCH_SIZE, # 批大小 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少 shuffle=True, # 是否随机打乱顺序 num_workers=2, # 多线程读取数据的线程数 ) for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): print('Epoch:', epoch, '|Step:', step, '|batch_x:', batch_x.numpy(), '|batch_y', batch_y.numpy()) ''''' shuffle=True Epoch: 0 |Step: 0 |batch_x: [ 6. 7. 2. 3. 1.] |batch_y [ 3. 3.5 1. 1.5 0.5] Epoch: 0 |Step: 1 |batch_x: [ 9. 10. 4. 8. 5.] |batch_y [ 4.5 5. 2. 4. 2.5] Epoch: 1 |Step: 0 |batch_x: [ 3. 4. 2. 9. 10.] |batch_y [ 1.5 2. 1. 4.5 5. ] Epoch: 1 |Step: 1 |batch_x: [ 1. 7. 8. 5. 6.] |batch_y [ 0.5 3.5 4. 2.5 3. ] Epoch: 2 |Step: 0 |batch_x: [ 3. 9. 2. 6. 7.] |batch_y [ 1.5 4.5 1. 3. 3.5] Epoch: 2 |Step: 1 |batch_x: [ 10. 4. 8. 1. 5.] |batch_y [ 5. 2. 4. 0.5 2.5] shuffle=False Epoch: 0 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1. 1.5 2. 2.5] Epoch: 0 |Step: 1 |batch_x: [ 6. 7. 8. 9. 10.] |batch_y [ 3. 3.5 4. 4.5 5. ] Epoch: 1 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1. 1.5 2. 2.5] Epoch: 1 |Step: 1 |batch_x: [ 6. 7. 8. 9. 10.] |batch_y [ 3. 3.5 4. 4.5 5. ] Epoch: 2 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1. 1.5 2. 2.5] Epoch: 2 |Step: 1 |batch_x: [ 6. 7. 8. 9. 10.] |batch_y [ 3. 3.5 4. 4.5 5. ] '''
2. TensorDataset
classtorch.utils.data.TensorDataset(data_tensor, target_tensor)
TensorDataset类用来将样本及其标签打包成torch的Dataset,data_tensor,和target_tensor都是tensor。
3. DataLoader
复制代码 代码如下:
classtorch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,num_workers=0, collate_fn=
dataset就是Torch的Dataset格式的对象;batch_size即每批训练的样本数量,默认为;shuffle表示是否需要随机取样本;num_workers表示读取样本的线程数。
二、PyTorch的Optimizer优化器
本实验中,首先构造一组数据集,转换格式并置于DataLoader中,备用。定义一个固定结构的默认神经网络,然后为每个优化器构建一个神经网络,每个神经网络的区别仅仅是优化器不同。通过记录训练过程中的loss值,最后在图像上呈现得到各个优化器的优化过程。
代码实现:
import torch import torch.utils.data as Data import torch.nn.functional as F from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed(1) # 设定随机数种子 # 定义超参数 LR = 0.01 # 学习率 BATCH_SIZE = 32 # 批大小 EPOCH = 12 # 迭代次数 x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1) y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size())) #plt.scatter(x.numpy(), y.numpy()) #plt.show() # 将数据转换为torch的dataset格式 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) # 将torch_dataset置入Dataloader中 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2) class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() self.hidden = torch.nn.Linear(1, 20) self.predict = torch.nn.Linear(20, 1) def forward(self, x): x = F.relu(self.hidden(x)) x = self.predict(x) return x # 为每个优化器创建一个Net net_SGD = Net() net_Momentum = Net() net_RMSprop = Net() net_Adam = Net() nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam] # 初始化优化器 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR) opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8) opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9) opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99)) optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam] # 定义损失函数 loss_function = torch.nn.MSELoss() losses_history = [[], [], [], []] # 记录training时不同神经网络的loss值 for epoch in range(EPOCH): print('Epoch:', epoch + 1, 'Training...') for step, (batch_x, batch_y) in enumerate(loader): b_x = Variable(batch_x) b_y = Variable(batch_y) for net, opt, l_his in zip(nets, optimizers, losses_history): output = net(b_x) loss = loss_function(output, b_y) opt.zero_grad() loss.backward() opt.step() l_his.append(loss.data[0]) labels = ['SGD', 'Momentum', 'RMSprop', 'Adam'] for i, l_his in enumerate(losses_history): plt.plot(l_his, label=labels[i]) plt.legend(loc='best') plt.xlabel('Steps') plt.ylabel('Loss') plt.ylim((0, 0.2)) plt.show()
实验结果:
由实验结果可见,SGD的优化效果是最差的,速度很慢;作为SGD的改良版本,Momentum表现就好许多;相比RMSprop和Adam的优化速度就非常好。实验中,针对不同的优化问题,比较各个优化器的效果再来决定使用哪个。
三、其他补充
1. Python的zip函数
zip函数接受任意多个(包括0个和1个)序列作为参数,返回一个tuple列表。
x = [1, 2, 3] y = [4, 5, 6] z = [7, 8, 9] xyz = zip(x, y, z) print xyz [(1, 4, 7), (2, 5, 8), (3, 6, 9)] x = [1, 2, 3] x = zip(x) print x [(1,), (2,), (3,)] x = [1, 2, 3] y = [4, 5, 6, 7] xy = zip(x, y) print xy [(1, 4), (2, 5), (3, 6)]
相关推荐:
Atas ialah kandungan terperinci 详解PyTorch批训练及优化器比较. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Alat AI Hot

Undresser.AI Undress
Apl berkuasa AI untuk mencipta foto bogel yang realistik

AI Clothes Remover
Alat AI dalam talian untuk mengeluarkan pakaian daripada foto.

Undress AI Tool
Gambar buka pakaian secara percuma

Clothoff.io
Penyingkiran pakaian AI

Video Face Swap
Tukar muka dalam mana-mana video dengan mudah menggunakan alat tukar muka AI percuma kami!

Artikel Panas

Alat panas

Notepad++7.3.1
Editor kod yang mudah digunakan dan percuma

SublimeText3 versi Cina
Versi Cina, sangat mudah digunakan

Hantar Studio 13.0.1
Persekitaran pembangunan bersepadu PHP yang berkuasa

Dreamweaver CS6
Alat pembangunan web visual

SublimeText3 versi Mac
Perisian penyuntingan kod peringkat Tuhan (SublimeText3)

Topik panas

Pada masa kini, prestasi dan fungsi telefon bimbit semakin berkuasa Hampir semua telefon bimbit dilengkapi dengan fungsi NFC yang mudah untuk memudahkan pengguna untuk pembayaran mudah alih dan pengesahan identiti. Walau bagaimanapun, sesetengah pengguna Xiaomi 14Pro mungkin tidak tahu cara mendayakan fungsi NFC. Seterusnya, izinkan saya memperkenalkannya kepada anda secara terperinci. Bagaimana untuk mendayakan fungsi nfc pada Xiaomi 14Pro? Langkah 1: Buka menu tetapan telefon anda. Langkah 2: Cari dan klik pilihan "Sambung dan Kongsi" atau "Wayarles & Rangkaian". Langkah 3: Dalam menu Sambungan & Perkongsian atau Wayarles & Rangkaian, cari dan klik "NFC & Pembayaran". Langkah 4: Cari dan klik "NFC Switch". Biasanya, lalai dimatikan. Langkah 5: Pada halaman suis NFC, klik butang suis untuk menghidupkannya.

Meluncur skrin melalui udara adalah ciri Huawei yang sangat dipuji dalam siri Huawei mate60 Ciri ini menggunakan sensor laser pada telefon dan kamera kedalaman 3D kamera hadapan untuk melengkapkan siri fungsi yang tidak memerlukan The. fungsi menyentuh skrin, seperti meleret TikTok dari udara, tetapi bagaimana menggunakan Huawei Pocket 2 untuk meleret TikTok dari udara? Bagaimana untuk mengambil tangkapan skrin dari udara dengan Huawei Pocket2? 1. Buka tetapan Huawei Pocket2 2. Kemudian pilih [Kebolehcapaian]. 3. Klik untuk membuka [Persepsi Pintar]. 4. Hanya hidupkan suis [Air Swipe Screen], [Air Screenshot] dan [Air Press]. 5. Apabila menggunakannya, anda perlu menahannya 20~40CM dari skrin, buka tapak tangan anda dan tunggu sehingga ikon tapak tangan muncul pada skrin.

WPS ialah perisian pejabat kami yang biasa digunakan Semasa mengedit artikel panjang, fon selalunya terlalu kecil untuk dilihat dengan jelas, jadi fon dan keseluruhan dokumen dilaraskan. Sebagai contoh: melaraskan jarak baris dokumen akan menjadikan keseluruhan dokumen sangat jelas. Saya cadangkan agar semua rakan mempelajari langkah operasi ini, saya akan berkongsi dengan anda hari ini. Buka fail teks WPS yang anda ingin laraskan, cari bar alat tetapan perenggan dalam menu [Mula], dan anda akan melihat ikon tetapan jarak baris kecil (ditunjukkan sebagai bulatan merah dalam gambar). 2. Klik segi tiga terbalik kecil di sudut kanan bawah tetapan jarak baris, dan nilai jarak baris yang sepadan akan muncul Anda boleh memilih 1 hingga 3 kali jarak baris (seperti yang ditunjukkan oleh anak panah dalam rajah). 3. Atau klik kanan perenggan dan ia akan muncul.

Menurut statistik pada 2 Mac, jumlah TVL rangkaian lapisan kedua Bitcoin MerlinChain telah mencecah AS$3 bilion. Antaranya, aset ekologi Bitcoin menyumbang 90.83%, termasuk BTC bernilai AS$1.596 bilion dan aset BRC-20 bernilai AS$404 juta. Bulan lalu, jumlah TVL MerlinChain mencecah AS$1.97 bilion dalam tempoh 14 hari selepas melancarkan aktiviti mempertaruhkan, mengatasi Blast, yang dilancarkan pada November tahun lepas dan juga yang paling terkini dan sama menarik perhatian. Pada 26 Februari, jumlah nilai NFT dalam ekosistem MerlinChain melebihi AS$420 juta, menjadi projek rantaian awam dengan nilai pasaran NFT tertinggi selain Ethereum. Pengenalan Projek MerlinChain ialah sokongan OKX

Perhatian Pertanyaan Berkumpulan (GroupedQueryAttention) ialah kaedah perhatian berbilang pertanyaan dalam model bahasa besar Matlamatnya adalah untuk mencapai kualiti MHA sambil mengekalkan kelajuan MQA. GroupedQueryAttention kumpulan pertanyaan, dan pertanyaan dalam setiap kumpulan berkongsi berat perhatian yang sama, yang membantu mengurangkan kerumitan pengiraan dan meningkatkan kelajuan inferens. Dalam artikel ini, kami akan menerangkan idea GQA dan cara menterjemahkannya ke dalam kod. GQA ada dalam kertas GQA:TrainingGeneralizedMulti-QueryTransformerModelsfromMulti-HeadCheckpoint

Perbezaan dan Analisis Perbandingan Bahasa C dan PHP Bahasa C dan PHP adalah kedua-dua bahasa pengaturcaraan biasa, tetapi mereka mempunyai perbezaan yang jelas dalam banyak aspek. Artikel ini akan menjalankan analisis perbandingan bahasa C dan PHP dan menggambarkan perbezaan antara mereka melalui contoh kod tertentu. 1. Sintaks dan penggunaan: Bahasa C: Bahasa C ialah bahasa pengaturcaraan berorientasikan proses, terutamanya digunakan untuk pengaturcaraan peringkat sistem dan pembangunan terbenam. Sintaks bahasa C agak mudah dan tahap rendah, boleh mengendalikan memori secara langsung, dan cekap dan fleksibel. Bahasa C menekankan kesempurnaan program pengaturcara

Kerumitan masa mengukur masa pelaksanaan algoritma berbanding saiz input. Petua untuk mengurangkan kerumitan masa program C++ termasuk: memilih bekas yang sesuai (seperti vektor, senarai) untuk mengoptimumkan storan dan pengurusan data. Gunakan algoritma yang cekap seperti isihan pantas untuk mengurangkan masa pengiraan. Hapuskan berbilang operasi untuk mengurangkan pengiraan berganda. Gunakan cawangan bersyarat untuk mengelakkan pengiraan yang tidak perlu. Optimumkan carian linear dengan menggunakan algoritma yang lebih pantas seperti carian binari.

Kemajuan zaman telah menjadikan pendapatan ramai orang lebih tinggi dan lebih tinggi, dan telefon bimbit yang biasa mereka gunakan akan ditukar dengan kerap Xiaomi Mi 14 Ultra baru-baru ini dilancarkan oleh Xiaomi mesti biasa kepada pengguna, dan ia boleh menyediakan pengguna dengan lebih Untuk memberikan pengalaman yang selesa dan lancar, telefon mudah alih baru pasti akan menghadapi banyak fungsi yang tidak digunakan Contohnya, bagaimana untuk menggunakan pengembangan imej pintar Xiaomi 14UltraAI? Datang dan lihat tutorial penggunaan di bawah! Bagaimana untuk menggunakan pengembangan imej pintar Xiaomi 14UltraAI? Mula-mula buka Xiaomi 14Ultra, masukkan album foto, pilih gambar yang ingin anda besarkan, dan masukkan pilihan penyuntingan album foto. Klik Crop Rotate, klik Crop, dan klik Smart Expand dalam pilihan yang dipaparkan. Akhir sekali, pilih cara untuk mengembangkan imej mengikut keperluan anda sendiri.
