In PyTorch reduzieren

Patricia Arquette
Freigeben: 2024-11-06 05:58:03
Original
216 Leute haben es durchsucht

Flatten in PyTorch

Kauf mir einen Kaffee☕

*Memos:

  • Mein Beitrag erklärt flatten() und ravel().
  • Mein Beitrag erklärt unflatten().

Flatten() kann null oder mehr Dimensionen entfernen, indem es Dimensionen aus dem 0D- oder mehr D-Tensor von null oder mehr Elementen auswählt und so den 1D- oder mehr D-Tensor von null oder mehr Elementen erhält, wie unten gezeigt:

*Memos:

  • Das erste Argument für die Initialisierung ist start_dim(Optional-Default:1-Type:int).
  • Das 2. Argument für die Initialisierung ist end_dim(Optional-Default:-1-Type:int).
  • Das 1. Argument ist eine Eingabe (Erforderlicher Typ: Tensor von int, float, complex oder bool).
  • Flatten() kann einen 0D-Tensor in einen 1D-Tensor ändern.
  • Flatten() bewirkt nichts für einen 1D-Tensor.
  • Der Unterschied zwischen Flatten() und flatten() ist:
    • Der Standardwert von start_dim für Flatten() ist 1, während der Standardwert von start_dim für flatten() 0 ist.
    • Grundsätzlich wird Flatten() zum Definieren eines Modells verwendet, während flatten() nicht zum Definieren eines Modells verwendet wird.
import torch
from torch import nn

flatten = nn.Flatten()
flatten
# Flatten(start_dim=1, end_dim=-1)

flatten.start_dim
# 1

flatten.end_dim
# -1

my_tensor = torch.tensor(7)

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7])

my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=1)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=0)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=2)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-3, end_dim=2)
flatten = nn.Flatten(start_dim=-3, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-3)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-2)
flatten = nn.Flatten(start_dim=2, end_dim=2)
flatten = nn.Flatten(start_dim=2, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=2)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=0)
flatten = nn.Flatten(start_dim=-3, end_dim=-3)
flatten(input=my_tensor)
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=1)
flatten = nn.Flatten(start_dim=-3, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7], [1], [-8], [3], [-6], [0]])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=1, end_dim=2)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=2)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7.], [1.], [-8.]], [[3.], [-6.], [0.]]])

flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7., 1., -8.], [3., -6., 0.]])

my_tensor = torch.tensor([[[7.+0.j], [1.+0.j], [-8.+0.j]],
                          [[3.+0.j], [-6.+0.j], [0.+0.j]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])

my_tensor = torch.tensor([[[True], [False], [True]],
                          [[False], [True], [False]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[True, False, True],
#         [False, True, False]])
Nach dem Login kopieren

Das obige ist der detaillierte Inhalt vonIn PyTorch reduzieren. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Quelle:dev.to
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Neueste Artikel des Autors
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage
Über uns Haftungsausschluss Sitemap
Chinesische PHP-Website:Online-PHP-Schulung für das Gemeinwohl,Helfen Sie PHP-Lernenden, sich schnell weiterzuentwickeln!