この記事は、Python に関する関連知識を提供するもので、主に pytorch モデルの保存と読み込みに関するいくつかの問題の実践的な記録を紹介します。一緒に見てみましょう。皆様のお役に立てれば幸いです。ヘルプ。
#[関連する推奨事項:Python3 ビデオ チュートリアル ]
1. torch でモデルを保存およびロードする方法1. モデル パラメーターとモデル構造の保存と読み込みtorch.save(model,path) torch.load(path)
torch.save(model.state_dict(),path) model_state_dic = torch.load(path) model.load_state_dic(model_state_dic)
#モデルフォルダーをmodelsに変更した後、再度ロードするとエラーが報告されます。
import torch from model.TextRNN import TextRNN load_model = torch.load('experiment_model_save/textRNN.bin') print('load_model',load_model)
完全なモデル構造とパラメーターを保存するこの方法では、モデル定義ファイルのパス
を変更しないように注意してください。2. シングルカード トレーニング モデルをマルチカード マシンに保存した後、シングルカード マシンにロードするとエラーが報告されます。マルチカード マシンでは 0 から開始し、モデルは n>= 1 でのグラフィックス カード トレーニングが保存された後、コピーがシングルカード マシンにロードされます
import torch from model.TextRNN import TextRNN load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin') print('load_model',load_model)
cuda デバイスの不一致の問題が発生します - 保存したモデル コード セグメント ウィジェット タイプ cuda1 を使用する場合、torch.load() を使用してそれを開くと、デフォルトで cuda1 が検索され、ロードされます。モデルをデバイスに接続します。現時点では、map_location を直接使用して問題を解決し、モデルを CPU にロードできます。
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
3. マルチ GPU トレーニング モデルのモデル構造とパラメーターを保存してロードした後に発生する問題
複数の GPU を使用してモデルを同時にトレーニングする場合、モデル構造とパラメータは一緒に保存するか、別々に保存します。モデル パラメータは、単一のカードでロードするときに問題が発生します#a. モデル構造とパラメータを一緒に保存し、ロード時に使用します
torch.distributed.init_process_group(backend='nccl')
上記のマルチプロセスメソッドなので、ロード時に宣言しないとエラーが報告されます。
#b. モデルパラメータを個別に保存するmodel = Transformer(num_encoder_layers=6,num_decoder_layers=6) state_dict = torch.load('train_model/clip/experiment.pt') model.load_state_dict(state_dict)
model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin') print(model) model.cuda(args.local_rank) 。。。。。。 model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True) print('model',model)
外層にはさらに DistributedDataParallel とモジュールがあるため、ロード時に重みが表示されますシングルカード環境でのモデルの重み キーが一致していません。
3. モデルを保存およびロードする正しい方法if gpu_count > 1: torch.save(model.module.state_dict(),save_path) else: torch.save(model.state_dict(),save_path) model = Transformer(num_encoder_layers=6,num_decoder_layers=6) state_dict = torch.load(save_path) model.load_state_dict(state_dict)
【関連する推奨事項:
Python3 ビデオ チュートリアル]
以上がpytorch モデルの保存と読み込みにおけるいくつかの問題の実践的な記録の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。