Maison > développement back-end > Tutoriel Python > Aplatir dans PyTorch

Aplatir dans PyTorch

Patricia Arquette
Libérer: 2024-11-06 05:58:03
original
361 Les gens l'ont consulté

Flatten in PyTorch

Achetez-moi un café☕

*Mémos :

  • Mon article explique flatten() et ravel().
  • Mon message explique unflatten().

Flatten() peut supprimer zéro ou plusieurs dimensions en sélectionnant les dimensions du tenseur 0D ou plus D de zéro ou plusieurs éléments, obtenant le tenseur 1D ou plus D de zéro ou plusieurs éléments comme indiqué ci-dessous :

*Mémos :

  • Le 1er argument pour l'initialisation est start_dim(Optional-Default:1-Type:int).
  • Le 2ème argument pour l'initialisation est end_dim(Optional-Default:-1-Type:int).
  • Le 1er argument est input(Required-Type : tensor of int, float, complex ou bool).
  • Flatten() peut changer un tenseur 0D en tenseur 1D.
  • Flatten() ne fait rien pour un tenseur 1D.
  • La différence entre Flatten() et flatten() est :
    • La valeur par défaut de start_dim pour Flatten() est 1 tandis que la valeur par défaut de start_dim pour flatten() est 0.
    • Fondamentalement, Flatten() est utilisé pour définir un modèle tandis que flatten() n'est pas utilisé pour définir un modèle.
import torch
from torch import nn

flatten = nn.Flatten()
flatten
# Flatten(start_dim=1, end_dim=-1)

flatten.start_dim
# 1

flatten.end_dim
# -1

my_tensor = torch.tensor(7)

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7])

my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=1)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=0)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=2)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-3, end_dim=2)
flatten = nn.Flatten(start_dim=-3, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-3)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-2)
flatten = nn.Flatten(start_dim=2, end_dim=2)
flatten = nn.Flatten(start_dim=2, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=2)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=0)
flatten = nn.Flatten(start_dim=-3, end_dim=-3)
flatten(input=my_tensor)
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=1)
flatten = nn.Flatten(start_dim=-3, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7], [1], [-8], [3], [-6], [0]])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=1, end_dim=2)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=2)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7.], [1.], [-8.]], [[3.], [-6.], [0.]]])

flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7., 1., -8.], [3., -6., 0.]])

my_tensor = torch.tensor([[[7.+0.j], [1.+0.j], [-8.+0.j]],
                          [[3.+0.j], [-6.+0.j], [0.+0.j]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])

my_tensor = torch.tensor([[[True], [False], [True]],
                          [[False], [True], [False]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[True, False, True],
#         [False, True, False]])
Copier après la connexion

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Derniers articles par auteur
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal