Rumah > pembangunan bahagian belakang > Tutorial Python > Rekod praktikal beberapa masalah dalam menyimpan dan memuatkan model pytorch

Rekod praktikal beberapa masalah dalam menyimpan dan memuatkan model pytorch

WBOY
Lepaskan: 2022-11-03 20:39:45
ke hadapan
2719 orang telah melayarinya

Artikel ini membawakan anda pengetahuan yang berkaitan tentang Python terutamanya memperkenalkan rekod praktikal tentang beberapa masalah dalam menyimpan dan memuatkan model pytorch. Saya harap ia dapat membantu semua orang. membantu.

[Cadangan berkaitan: Tutorial video Python3 ]

1. Cara menyimpan dan memuatkan model dalam obor

1. Simpan dan muatkan parameter model dan struktur model

torch.save(model,path)
torch.load(path)
Salin selepas log masuk

2. Hanya simpan dan muatkan parameter model - kaedah ini lebih selamat, tetapi lebih menyusahkan

torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)
Salin selepas log masuk

2. Masalah dalam menyimpan dan memuatkan model dalam obor

1 Masalah dalam memuatkan struktur dan parameter model selepas menyimpannya dalam model kad tunggal

Struktur model akan ditakrifkan apabila model disimpan. laluan fail direkodkan, dan apabila memuatkan, ia akan dihuraikan mengikut laluan dan parameter akan dimuatkan selepas laluan fail definisi model diubah suai, ralat akan dilaporkan apabila menggunakan torch.load(path).

Selepas menukar folder model kepada model, ralat akan dilaporkan apabila memuatkan semula.

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)
Salin selepas log masuk

Dengan cara ini untuk menyimpan struktur dan parameter model lengkap, pastikan anda tidak menukar laluan fail definisi model .

2. Selepas menyimpan model latihan kad tunggal pada mesin berbilang kad, ralat akan dilaporkan semasa memuatkannya pada mesin kad tunggal

Bermula dari 0 pada berbilang -mesin kad dengan berbilang kad grafik, model kini n>= Selepas menyimpan latihan kad grafik pada 1 dan memuatkan salinan pada mesin kad tunggal

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)
Salin selepas log masuk

akan ada masalah ketidakpadanan peranti cuda - segmen kod modul yang anda simpan adalah kecil Jenis komponen menggunakan cuda1, jadi apabila ia dibuka menggunakan torch.load(), ia akan mencari cuda1 secara lalai, dan kemudian memuatkan model. kepada peranti. Pada masa ini, anda boleh terus menggunakan map_location untuk menyelesaikan masalah dan memuatkan model ke CPU.

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
Salin selepas log masuk

3 Masalah yang berlaku selepas menyimpan struktur model dan parameter model latihan berbilang GPU dan kemudian memuatkannya

Apabila menggunakan berbilang GPU untuk melatih model pada masa yang sama. , sama ada struktur model dan parameter disimpan bersama atau Masalah akan berlaku jika anda menyimpan parameter model secara berasingan dan kemudian memuatkannya di bawah satu kad

a Simpan struktur model dan parameter bersama-sama dan kemudian memuatkan model

torch.distributed.init_process_group(backend='nccl')
Salin selepas log masuk

Kaedah berbilang proses di atas digunakan semasa latihan, jadi anda mesti mengisytiharkannya semasa memuatkan, jika tidak ralat akan dilaporkan.

b. Menyimpan parameter model secara berasingan

model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)
Salin selepas log masuk

juga akan menyebabkan masalah, tetapi masalahnya di sini ialah kunci kamus parameter berbeza daripada kunci yang ditakrifkan oleh model

Sebabnya di bawah latihan multi-GPU, apabila menggunakan latihan yang diedarkan, model tersebut akan dibungkus seperti berikut:

model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)
Salin selepas log masuk

Struktur model sebelum pembungkusan:

Model berpakej

mempunyai DistributedDataParallel dan modul di lapisan luar, jadi berat model adalah dimuatkan dalam persekitaran kad tunggal Apabila kekunci berat tidak konsisten.

3. Kaedah menyimpan dan memuatkan model yang betul

    if gpu_count > 1:
        torch.save(model.module.state_dict(),save_path)
    else:
        torch.save(model.state_dict(),save_path)
    model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict)
Salin selepas log masuk

Ini adalah paradigma yang lebih baik, dan tidak akan ada ralat dalam memuatkan.

[Cadangan berkaitan: Tutorial video Python3]

Atas ialah kandungan terperinci Rekod praktikal beberapa masalah dalam menyimpan dan memuatkan model pytorch. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Label berkaitan:
sumber:jb51.net
Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan