
Nine key operations of PyTorch!

PHPz
Release: 2024-01-06 16:45:47

Today we'll talk about PyTorch. I have summarized the nine most important PyTorch operations to give you an overall picture of the framework.


Tensor creation and basic operations

PyTorch tensors are similar to NumPy arrays, but they additionally support GPU acceleration and automatic differentiation. We can create tensors with torch.tensor, or use convenience functions such as torch.zeros and torch.ones.

import torch

# Create tensors
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# Tensor addition
c = a + b
print(c)  # tensor([5, 7, 9])
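
The same module also provides constructors such as torch.zeros and torch.ones, and tensors can be moved to the GPU for accelerated computation. A minimal sketch (the CUDA line is guarded so it only runs when a GPU is available):

zeros = torch.zeros(2, 3)   # 2x3 tensor filled with zeros
ones = torch.ones(2, 3)     # 2x3 tensor filled with ones

if torch.cuda.is_available():
    a_gpu = a.to("cuda")    # move an existing tensor to the GPU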

Automatic differentiation (Autograd)

The torch.autograd module provides automatic differentiation: it records the operations performed on tensors and computes gradients with respect to them.

x = torch.tensor([1.0], requires_grad=True)
y = x**2
y.backward()
print(x.grad)  # dy/dx = 2x, so tensor([2.])
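
For a slightly richer illustration, gradients propagate through any chain of differentiable operations. The sketch below is illustrative, not from the original article:

a = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([3.0], requires_grad=True)
z = a * b + a**2          # z = a*b + a^2
z.backward()
print(a.grad)             # dz/da = b + 2a = tensor([7.])
print(b.grad)             # dz/db = a      = tensor([2.])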

Neural network layer (nn.Module)

torch.nn.Module is the basic building block for constructing neural networks. A module can contain various layers, such as linear layers (nn.Linear) and convolutional layers (nn.Conv2d).

import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(10, 5)   # linear layer: 10 input features -> 5 outputs

    def forward(self, x):
        return self.fc(x)

model = SimpleNN()
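
As a quick usage sketch, the model can be called on a batch whose last dimension matches the layer's in_features (the batch size of 32 below is arbitrary):

x = torch.randn(32, 10)   # batch of 32 samples, 10 features each
out = model(x)            # invokes SimpleNN.forward
print(out.shape)          # torch.Size([32, 5])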

Optimizer

The optimizer adjusts the model parameters to minimize the loss function. Below is an example using the stochastic gradient descent (SGD) optimizer.

import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=0.01)
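
For context, a single optimization step typically follows the pattern sketched below; the loss here is just a placeholder to produce gradients (a real loss function is introduced in the next section):

optimizer.zero_grad()               # clear gradients from the previous step
output = model(torch.randn(8, 10))  # forward pass on a dummy batch
loss = output.sum()                 # placeholder loss
loss.backward()                     # compute gradients for model parameters
optimizer.step()                    # apply the SGD update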

Loss Function

The loss function measures the difference between the model's output and the target. For example, cross-entropy loss is suitable for classification problems.

loss_function = nn.CrossEntropyLoss()
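
nn.CrossEntropyLoss expects raw logits of shape (batch, num_classes) and integer class labels. A minimal sketch using the SimpleNN model from above (batch size and label values are illustrative):

logits = model(torch.randn(4, 10))     # shape (4, 5): raw scores for 5 classes
targets = torch.tensor([0, 2, 4, 1])   # one class index per sample
loss = loss_function(logits, targets)
print(loss.item())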

Data loading and preprocessing

PyTorch's torch.utils.data module provides the Dataset and DataLoader classes for loading and preprocessing data. The Dataset class can be subclassed to suit different data formats and tasks.

from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    # Implement the dataset's __init__, __len__, and __getitem__ methods here
    ...

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)   # dataset: an instance of CustomDataset
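
A minimal runnable sketch of such a custom dataset; RandomDataset and its synthetic tensors are illustrative stand-ins for real data:

import torch
from torch.utils.data import DataLoader, Dataset

class RandomDataset(Dataset):
    def __init__(self, num_samples=256):
        self.x = torch.randn(num_samples, 10)          # features
        self.y = torch.randint(0, 5, (num_samples,))   # class labels in [0, 5)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

dataloader = DataLoader(RandomDataset(), batch_size=64, shuffle=True)
for batch_x, batch_y in dataloader:
    print(batch_x.shape, batch_y.shape)   # torch.Size([64, 10]) torch.Size([64])
    break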

Model saving and loading

You can use torch.save to save the model's state dictionary, then use torch.load together with load_state_dict to restore it.

# Save the model
torch.save(model.state_dict(), 'model.pth')

# Load the model
loaded_model = SimpleNN()
loaded_model.load_state_dict(torch.load('model.pth'))
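
Beyond the bare state dictionary, a common pattern (sketched here; the key names are arbitrary) is to save a checkpoint that also carries the optimizer state so training can be resumed:

checkpoint = {
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "epoch": 10,   # illustrative value
}
torch.save(checkpoint, "checkpoint.pth")

# Restore both model and optimizer later
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])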

Learning rate adjustment

The torch.optim.lr_scheduler module provides tools for adjusting the learning rate during training. For example, StepLR reduces the learning rate by a factor of gamma every step_size epochs.

from torch.optim import lr_scheduler

scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
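
The scheduler is typically advanced once per epoch, after the optimizer updates. A minimal sketch (the per-epoch training work is omitted):

for epoch in range(20):
    # ... run one epoch of training with optimizer.step() ...
    scheduler.step()                                # decay the lr by gamma every 5 epochs
    print(epoch, optimizer.param_groups[0]["lr"])   # current learning rate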

Model evaluation

After training is complete, the model's performance needs to be evaluated. When evaluating, switch the model to evaluation mode with model.eval() and wrap the evaluation code in the torch.no_grad() context manager to avoid gradient computation.

model.eval()
with torch.no_grad():
    # Run the model and compute performance metrics
    ...
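
A minimal sketch of an evaluation loop computing classification accuracy; it assumes the dataloader and labels from the data-loading sketch above:

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch_x, batch_y in dataloader:
        preds = model(batch_x).argmax(dim=1)        # predicted class per sample
        correct += (preds == batch_y).sum().item()
        total += batch_y.size(0)
print(f"accuracy: {correct / total:.3f}")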
