<>torch.nn.flatten

torch.nn.flatten是一个类,作用为将连续的几个维度展平成一个tensor(将一些维度合并)

*
参数为合并开始的维度,合并结束的维度(维度就是索引,从 0 开始)

* 开始维度默认为 1。因为其被用在神经网络中,输入为一批数据,第 0
维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。
* 结束维度默认为 -1,也就是一直合并到最后一维

*
默认参数情况
x = torch.ones(2, 2, 2, 2) F = torch.nn.Flatten() y = F(x) print(y) print(y.
shape) >>tensor([[1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1.,
1.]]) >>torch.Size([2, 8])
*
输入一个参数情况:该参数为合并开始的维度
x = torch.ones(2, 2, 2, 2) F = torch.nn.Flatten(2) y = F(x) print(y) print(y.
shape) >>tensor([[[1., 1., 1., 1.], [1., 1., 1., 1.]], [[1., 1., 1., 1.], [1., 1
., 1., 1.]]]) >>torch.Size([2, 2, 4])
*
输入两个参数情况:第一个参数代表合并开始维度,第二个参数代表合并结束维度(合并范围包含开始维度和结束维度)
x = torch.ones(2, 2, 2, 2) F = torch.nn.Flatten(1, 2) y = F(x) print(y) print(y
.shape) >>tensor([[[1., 1.], [1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.],
[1., 1.], [1., 1.]]]) >>torch.Size([2, 4, 2])
<>torch.flatten

作用与 torch.nn.flatten 类似,都是用于展平 tensor 的,只是 torch.flatten 是 function
而不是类,其默认开始维度为第 0 维
t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) print(t.shape) >>torch.
Size([2, 2, 2]) print(torch.flatten(t)) >>tensor([1, 2, 3, 4, 5, 6, 7, 8]) print
(torch.flatten(t, 1)) >>tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) print(torch.flatten
(t, 0, 1).shape) >>torch.Size([4, 2])
若输入是 0 维 tensor,则输出的是一维 tensor
t = torch.tensor(1) print("before flatten:") print(t) print(t.shape) >>before
flatten: tensor(1) torch.Size([]) print("\n") print("after flatten:") print(
torch.flatten(t)) print(torch.flatten(t).shape) >>after flatten: tensor([1])
torch.Size([1])

技术
下载桌面版
GitHub
百度网盘(提取码:draw)
Gitee
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:[email protected]
QQ群:766591547
关注微信