squeeze

删除张量维度中值为 1 的项,若原维度为 A*1*B*C*1*D,则删除后维度为 A*B*C*D

x = torch.rand(size=(2, 1, 3))
print(x.shape)
y = x.squeeze()
print(y.shape)

输出结果:

torch.Size([2, 1, 3])
torch.Size([2, 3])

可以指定在某个或某些维度上进行squeeze操作:

x = torch.randn(10, 1, 10, 10, 1)

x.squeeze(0)		# 表示在第0个维度上操作
x.squeeze((1,4))	# 表示在第1和第4维度上进行操作

unsqueeze

在指定位置插入一个值为 1 的维度。

import torch

x = torch.randn(4, 8, 50)
y = torch.unsqueeze(x, dim=1)
print(y.shape)

# torch.Size([4, 1, 8, 50])

transpose

transpose方法用于交换张量的两个维度。这个操作返回一个新的张量。

# 创建一个张量,维度为[3, 4, 5]
tensor = torch.randn(3, 4, 5)
# 交换 Tensor 中 1,2 两个维度
tensor = tensor.transpose(1, 2)

print(tensor.shape) # torch.Size([3, 5, 4])

permute

permute方法用于重新排列Tensor中的所有维度。

import torch

x = torch.randn(16, 5, 10)
x = x.permute(1, 2, 0)	# 重新排列维度

print(x.shape) # torch.Size([5, 10, 16])