ホームページ > バックエンド開発 > Python チュートリアル > pytorch モデルの保存と読み込みにおけるいくつかの問題の実践的な記録

pytorch モデルの保存と読み込みにおけるいくつかの問題の実践的な記録

WBOY
リリース: 2022-11-03 20:39:45
転載
2629 人が閲覧しました

この記事は、Python に関する関連知識を提供するもので、主に pytorch モデルの保存と読み込みに関するいくつかの問題の実践的な記録を紹介します。一緒に見てみましょう。皆様のお役に立てれば幸いです。ヘルプ。

#[関連する推奨事項:

Python3 ビデオ チュートリアル ]

1. torch でモデルを保存およびロードする方法

1. モデル パラメーターとモデル構造の保存と読み込み

torch.save(model,path)
torch.load(path)
ログイン後にコピー

2. モデル パラメーターの保存と読み込みのみ - この方法は安全ですが、少し面倒です

torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)
ログイン後にコピー

2. モデル保存の問題

#1. シングル カード モデルでモデル構造とパラメータを保存した後の読み込みの問題

モデルを保存すると、モデル構造定義ファイルへのパスが記録されます。 . をロードすると、パスに従って解析されてパラメータがロードされますが、モデル定義ファイルのパスが変更されると、torch.load(path) を使用するとエラーが報告されます。

#モデルフォルダーをmodelsに変更した後、再度ロードするとエラーが報告されます。

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)
ログイン後にコピー

完全なモデル構造とパラメーターを保存するこの方法では、モデル定義ファイルのパス

を変更しないように注意してください。

2. シングルカード トレーニング モデルをマルチカード マシンに保存した後、シングルカード マシンにロードするとエラーが報告されます。マルチカード マシンでは 0 から開始し、モデルは n>= 1 でのグラフィックス カード トレーニングが保存された後、コピーがシングルカード マシンにロードされます

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)
ログイン後にコピー

cuda デバイスの不一致の問題が発生します - 保存したモデル コード セグメント ウィジェット タイプ cuda1 を使用する場合、torch.load() を使用してそれを開くと、デフォルトで cuda1 が検索され、ロードされます。モデルをデバイスに接続します。現時点では、map_location を直接使用して問題を解決し、モデルを CPU にロードできます。

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
ログイン後にコピー

3. マルチ GPU トレーニング モデルのモデル構造とパラメーターを保存してロードした後に発生する問題

複数の GPU を使用してモデルを同時にトレーニングする場合、モデル構造とパラメータは一緒に保存するか、別々に保存します。モデル パラメータは、単一のカードでロードするときに問題が発生します

#a. モデル構造とパラメータを一緒に保存し、ロード時に使用します

#
torch.distributed.init_process_group(backend='nccl')
ログイン後にコピー

上記のマルチプロセスメソッドなので、ロード時に宣言しないとエラーが報告されます。

#b. モデルパラメータを個別に保存する

model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)
ログイン後にコピー

同じ問題が発生しますが、ここでの問題は、パラメータ辞書のキーがモデルで定義されているキーと異なることです

その理由は、マルチ GPU トレーニングで分散トレーニングを使用すると、モデルがパッケージ化されるためです。コードは次のとおりです:

model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)
ログイン後にコピー

パッケージ化前のモデル構造:

パッケージ化モデル

外層にはさらに DistributedDataParallel とモジュールがあるため、ロード時に重みが表示されますシングルカード環境でのモデルの重み キーが一致していません。

3. モデルを保存およびロードする正しい方法

    if gpu_count > 1:
        torch.save(model.module.state_dict(),save_path)
    else:
        torch.save(model.state_dict(),save_path)
    model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict)
ログイン後にコピー
これはより良いパラダイムであり、ロード時にエラーは発生しません。

【関連する推奨事項:

Python3 ビデオ チュートリアル

]

以上がpytorch モデルの保存と読み込みにおけるいくつかの問題の実践的な記録の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

関連ラベル:
ソース:jb51.net
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
最新の問題
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート