PyTorch の 9 つの主要な操作!

PHPz
リリース: 2024-01-06 16:45:47
転載
429 人が閲覧しました

今日は PyTorch について話します。全体的な概念を示す 9 つの最も重要な PyTorch 操作をまとめました。

PyTorch の 9 つの主要な操作!

テンソルの作成と基本操作

PyTorch テンソルは NumPy 配列に似ていますが、GPU アクセラレーションと自動導出機能を備えています。 torch.tensor 関数を使用してテンソルを作成することも、torch.zeros、torch.ones およびその他の関数を使用してテンソルを作成することもできます。これらの関数は、テンソルをより簡単に作成するのに役立ちます。

import torch# 创建张量a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])# 张量加法c = a + bprint(c)
ログイン後にコピー

Autograd (Autograd)

torch.autograd モジュールは自動導出メカニズムを提供し、操作の記録と勾配の計算を可能にします。

x = torch.tensor([1.0], requires_grad=True)y = x**2y.backward()print(x.grad)
ログイン後にコピー

ニューラル ネットワーク層 (nn.Module)

torch.nn.Module は、ニューラル ネットワークを構築するための基本コンポーネントです。線形層 (nn.Linear など) など、さまざまな層を含めることができます。 )、畳み込み層(nn.Conv2d)など。

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self): super(SimpleNN, self).__init__() self.fc = nn.Linear(10, 5)def forward(self, x): return self.fc(x)model = SimpleNN()
ログイン後にコピー

オプティマイザー

オプティマイザーは、モデル パラメーターを調整して損失関数を削減するために使用されます。以下は、確率的勾配降下 (SGD) オプティマイザーを使用した例です。

import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=0.01)
ログイン後にコピー

損失関数

損失関数は、モデルの出力とターゲットの差を測定するために使用されます。たとえば、クロスエントロピー損失は分類問題に適しています。

loss_function = nn.CrossEntropyLoss()
ログイン後にコピー

データのロードと前処理

PyTorch の torch.utils.data モジュールは、データのロードと前処理のための Dataset クラスと DataLoader クラスを提供します。データセット クラスは、さまざまなデータ形式やタスクに合わせてカスタマイズできます。

from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):# 实现数据集的初始化和__getitem__方法dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
ログイン後にコピー

モデルの保存とロード

torch.save を使用してモデルの状態辞書を保存し、torch.load を使用してモデルをロードできます。

# 保存模型torch.save(model.state_dict(), 'model.pth')# 加载模型loaded_model = SimpleNN()loaded_model.load_state_dict(torch.load('model.pth'))
ログイン後にコピー

学習率調整

torch.optim.lr_scheduler モジュールは、学習率調整のためのツールを提供します。たとえば、StepLR を使用すると、各エポック後の学習率を下げることができます。

from torch.optim import lr_schedulerscheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
ログイン後にコピー

モデルの評価

モデルのトレーニングが完了したら、モデルのパフォーマンスを評価する必要があります。評価するときは、モデルを評価モード (model.eval()) に切り替え、torch.no_grad() コンテキスト マネージャーを使用して勾配計算を回避する必要があります。

rree

以上がPyTorch の 9 つの主要な操作!の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

関連ラベル:
ソース:51cto.com
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
最新の問題
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート
私たちについて 免責事項 Sitemap
PHP中国語ウェブサイト:福祉オンライン PHP トレーニング,PHP 学習者の迅速な成長を支援します!