目次
#1. シングル カード モデルでモデル構造とパラメータを保存した後の読み込みの問題
モデルを保存すると、モデル構造定義ファイルへのパスが記録されます。 . をロードすると、パスに従って解析されてパラメータがロードされますが、モデル定義ファイルのパスが変更されると、torch.load(path) を使用するとエラーが報告されます。
ホームページ バックエンド開発 Python チュートリアル pytorch モデルの保存と読み込みにおけるいくつかの問題の実践的な記録

pytorch モデルの保存と読み込みにおけるいくつかの問題の実践的な記録

Nov 03, 2022 pm 05:33 PM
python

この記事は、Python に関する関連知識を提供するもので、主に pytorch モデルの保存と読み込みに関するいくつかの問題の実践的な記録を紹介します。一緒に見てみましょう。皆様のお役に立てれば幸いです。ヘルプ。

#[関連する推奨事項:

Python3 ビデオ チュートリアル ]

1. torch でモデルを保存およびロードする方法

1. モデル パラメーターとモデル構造の保存と読み込み

torch.save(model,path)
torch.load(path)
ログイン後にコピー

2. モデル パラメーターの保存と読み込みのみ - この方法は安全ですが、少し面倒です

torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)
ログイン後にコピー

2. モデル保存の問題

#1. シングル カード モデルでモデル構造とパラメータを保存した後の読み込みの問題

モデルを保存すると、モデル構造定義ファイルへのパスが記録されます。 . をロードすると、パスに従って解析されてパラメータがロードされますが、モデル定義ファイルのパスが変更されると、torch.load(path) を使用するとエラーが報告されます。

#モデルフォルダーをmodelsに変更した後、再度ロードするとエラーが報告されます。

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)
ログイン後にコピー

完全なモデル構造とパラメーターを保存するこの方法では、モデル定義ファイルのパス

を変更しないように注意してください。

2. シングルカード トレーニング モデルをマルチカード マシンに保存した後、シングルカード マシンにロードするとエラーが報告されます。マルチカード マシンでは 0 から開始し、モデルは n>= 1 でのグラフィックス カード トレーニングが保存された後、コピーがシングルカード マシンにロードされます

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)
ログイン後にコピー

cuda デバイスの不一致の問題が発生します - 保存したモデル コード セグメント ウィジェット タイプ cuda1 を使用する場合、torch.load() を使用してそれを開くと、デフォルトで cuda1 が検索され、ロードされます。モデルをデバイスに接続します。現時点では、map_location を直接使用して問題を解決し、モデルを CPU にロードできます。

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
ログイン後にコピー

3. マルチ GPU トレーニング モデルのモデル構造とパラメーターを保存してロードした後に発生する問題

複数の GPU を使用してモデルを同時にトレーニングする場合、モデル構造とパラメータは一緒に保存するか、別々に保存します。モデル パラメータは、単一のカードでロードするときに問題が発生します

#a. モデル構造とパラメータを一緒に保存し、ロード時に使用します

#
torch.distributed.init_process_group(backend='nccl')
ログイン後にコピー

上記のマルチプロセスメソッドなので、ロード時に宣言しないとエラーが報告されます。

#b. モデルパラメータを個別に保存する

model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)
ログイン後にコピー

同じ問題が発生しますが、ここでの問題は、パラメータ辞書のキーがモデルで定義されているキーと異なることです

その理由は、マルチ GPU トレーニングで分散トレーニングを使用すると、モデルがパッケージ化されるためです。コードは次のとおりです:

model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)
ログイン後にコピー

パッケージ化前のモデル構造:

パッケージ化モデル

外層にはさらに DistributedDataParallel とモジュールがあるため、ロード時に重みが表示されますシングルカード環境でのモデルの重み キーが一致していません。

3. モデルを保存およびロードする正しい方法

    if gpu_count > 1:
        torch.save(model.module.state_dict(),save_path)
    else:
        torch.save(model.state_dict(),save_path)
    model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict)
ログイン後にコピー
これはより良いパラダイムであり、ロード時にエラーは発生しません。

【関連する推奨事項:

Python3 ビデオ チュートリアル

]

以上がpytorch モデルの保存と読み込みにおけるいくつかの問題の実践的な記録の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。

ホットAIツール

Undresser.AI Undress

Undresser.AI Undress

リアルなヌード写真を作成する AI 搭載アプリ

AI Clothes Remover

AI Clothes Remover

写真から衣服を削除するオンライン AI ツール。

Undress AI Tool

Undress AI Tool

脱衣画像を無料で

Clothoff.io

Clothoff.io

AI衣類リムーバー

AI Hentai Generator

AI Hentai Generator

AIヘンタイを無料で生成します。

ホットツール

メモ帳++7.3.1

メモ帳++7.3.1

使いやすく無料のコードエディター

SublimeText3 中国語版

SublimeText3 中国語版

中国語版、とても使いやすい

ゼンドスタジオ 13.0.1

ゼンドスタジオ 13.0.1

強力な PHP 統合開発環境

ドリームウィーバー CS6

ドリームウィーバー CS6

ビジュアル Web 開発ツール

SublimeText3 Mac版

SublimeText3 Mac版

神レベルのコード編集ソフト(SublimeText3)

Python:ゲーム、GUIなど Python:ゲーム、GUIなど Apr 13, 2025 am 12:14 AM

PythonはゲームとGUI開発に優れています。 1)ゲーム開発は、2Dゲームの作成に適した図面、オーディオ、その他の機能を提供し、Pygameを使用します。 2)GUI開発は、TKINTERまたはPYQTを選択できます。 TKINTERはシンプルで使いやすく、PYQTは豊富な機能を備えており、専門能力開発に適しています。

PHPとPython:2つの一般的なプログラミング言語を比較します PHPとPython:2つの一般的なプログラミング言語を比較します Apr 14, 2025 am 12:13 AM

PHPとPythonにはそれぞれ独自の利点があり、プロジェクトの要件に従って選択します。 1.PHPは、特にWebサイトの迅速な開発とメンテナンスに適しています。 2。Pythonは、データサイエンス、機械学習、人工知能に適しており、簡潔な構文を備えており、初心者に適しています。

Debian Readdirが他のツールと統合する方法 Debian Readdirが他のツールと統合する方法 Apr 13, 2025 am 09:42 AM

DebianシステムのReadDir関数は、ディレクトリコンテンツの読み取りに使用されるシステムコールであり、Cプログラミングでよく使用されます。この記事では、ReadDirを他のツールと統合して機能を強化する方法について説明します。方法1:C言語プログラムを最初にパイプラインと組み合わせて、cプログラムを作成してreaddir関数を呼び出して結果をinclude#include#include inctargc、char*argv []){dir*dir; structdireant*entry; if(argc!= 2){(argc!= 2){

Pythonと時間:勉強時間を最大限に活用する Pythonと時間:勉強時間を最大限に活用する Apr 14, 2025 am 12:02 AM

限られた時間でPythonの学習効率を最大化するには、PythonのDateTime、時間、およびスケジュールモジュールを使用できます。 1. DateTimeモジュールは、学習時間を記録および計画するために使用されます。 2。時間モジュールは、勉強と休息の時間を設定するのに役立ちます。 3.スケジュールモジュールは、毎週の学習タスクを自動的に配置します。

Nginx SSL証明書更新Debianチュートリアル Nginx SSL証明書更新Debianチュートリアル Apr 13, 2025 am 07:21 AM

この記事では、DebianシステムでNGINXSSL証明書を更新する方法について説明します。ステップ1:最初にCERTBOTをインストールして、システムがCERTBOTおよびPython3-Certbot-Nginxパッケージがインストールされていることを確認してください。インストールされていない場合は、次のコマンドを実行してください。sudoapt-getupdatesudoapt-getinstolcallcertbotthon3-certbot-nginxステップ2:certbotコマンドを取得して構成してlet'sencrypt証明書を取得し、let'sencryptコマンドを取得し、nginx:sudocertbot - nginxを構成します。

DebianのGitlabのプラグイン開発ガイド DebianのGitlabのプラグイン開発ガイド Apr 13, 2025 am 08:24 AM

DebianでGitLabプラグインを開発するには、特定の手順と知識が必要です。このプロセスを始めるのに役立つ基本的なガイドを以下に示します。最初にgitlabをインストールすると、debianシステムにgitlabをインストールする必要があります。 GitLabの公式インストールマニュアルを参照できます。 API統合を実行する前に、APIアクセストークンを取得すると、GitLabのAPIアクセストークンを最初に取得する必要があります。 gitlabダッシュボードを開き、ユーザー設定で「アクセストーケン」オプションを見つけ、新しいアクセストークンを生成します。生成されます

debian opensslでHTTPSサーバーを構成する方法 debian opensslでHTTPSサーバーを構成する方法 Apr 13, 2025 am 11:03 AM

DebianシステムでHTTPSサーバーの構成には、必要なソフトウェアのインストール、SSL証明書の生成、SSL証明書を使用するWebサーバー(ApacheやNginxなど)の構成など、いくつかのステップが含まれます。 Apachewebサーバーを使用していると仮定して、基本的なガイドです。 1.最初に必要なソフトウェアをインストールし、システムが最新であることを確認し、ApacheとOpenSSL:sudoaptupdatesudoaptupgraysudoaptinstaをインストールしてください

Apacheとは何ですか Apacheとは何ですか Apr 13, 2025 pm 12:06 PM

アパッチはインターネットの背後にあるヒーローです。それはWebサーバーであるだけでなく、膨大なトラフィックをサポートし、動的なコンテンツを提供する強力なプラットフォームでもあります。モジュラー設計を通じて非常に高い柔軟性を提供し、必要に応じてさまざまな機能を拡張できるようにします。ただし、モジュール性は、慎重な管理を必要とする構成とパフォーマンスの課題も提示します。 Apacheは、高度にカスタマイズ可能で複雑なニーズを満たす必要があるサーバーシナリオに適しています。

See all articles