在 PyTorch 中展平

Patricia Arquette
發布: 2024-11-06 05:58:03
原創
216 人瀏覽過

Flatten in PyTorch

請我喝杯咖啡☕

*備忘錄:

  • 我的貼文解釋了 flatten() 和 ravel()。
  • 我的貼文解釋了 unflatten()。

Flatten() 可以透過從零個或多個元素的0D 或多個D 張量中選擇維度來移除零個或多個維度,得到零個或多個元素的1D 或多個D 張量,如下圖所示:

*備忘錄:

  • 初始化的第一個參數是 start_dim(Optional-Default:1-Type:int)。
  • 初始化的第二個參數是 end_dim(可選-預設:-1-型別:int)。
  • 第一個參數是輸入(必要類型:int、float、complex 或 bool 的張量)。
  • Flatten() 可以將 0D 張量改為 1D 張量。
  • Flatten() 對於一維張量沒有任何作用。
  • Flatten() 和 flatten() 的差別是:
    • Flatten() 的 start_dim 預設值為 1,而 flatten() 的 start_dim 預設值為 0。
    • 基本上,Flatten() 用來定義模型,而 flatten() 不用於定義模型。
import torch
from torch import nn

flatten = nn.Flatten()
flatten
# Flatten(start_dim=1, end_dim=-1)

flatten.start_dim
# 1

flatten.end_dim
# -1

my_tensor = torch.tensor(7)

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7])

my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=1)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=0)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=2)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-3, end_dim=2)
flatten = nn.Flatten(start_dim=-3, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-3)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-2)
flatten = nn.Flatten(start_dim=2, end_dim=2)
flatten = nn.Flatten(start_dim=2, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=2)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=0)
flatten = nn.Flatten(start_dim=-3, end_dim=-3)
flatten(input=my_tensor)
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=1)
flatten = nn.Flatten(start_dim=-3, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7], [1], [-8], [3], [-6], [0]])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=1, end_dim=2)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=2)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7.], [1.], [-8.]], [[3.], [-6.], [0.]]])

flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7., 1., -8.], [3., -6., 0.]])

my_tensor = torch.tensor([[[7.+0.j], [1.+0.j], [-8.+0.j]],
                          [[3.+0.j], [-6.+0.j], [0.+0.j]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])

my_tensor = torch.tensor([[[True], [False], [True]],
                          [[False], [True], [False]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[True, False, True],
#         [False, True, False]])
登入後複製

以上是在 PyTorch 中展平的詳細內容。更多資訊請關注PHP中文網其他相關文章!

來源:dev.to
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
作者最新文章
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板
關於我們 免責聲明 Sitemap
PHP中文網:公益線上PHP培訓,幫助PHP學習者快速成長!