2. torch.chunk(tensor, chunks, dim)

说明:在给定的维度上讲张量进行分块。

参数

  • tensor(Tensor) -- 待分块的输入张量
  • chunks(int) -- 分块的个数
  • dim(int) -- 维度,沿着此维度进行分块
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 1.0103,  2.3358, -1.9236],
        [-0.3890,  0.6594,  0.6664],
        [ 0.5240, -1.4193,  0.1681]])
>>> torch.chunk(x, 3, dim=0)
(tensor([[ 1.0103,  2.3358, -1.9236]]), tensor([[-0.3890,  0.6594,  0.6664]]), tensor([[ 0.5240, -1.4193,  0.1681]]))
>>> torch.chunk(x, 3, dim=1)
(tensor([[ 1.0103],
        [-0.3890],
        [ 0.5240]]), tensor([[ 2.3358],
        [ 0.6594],
        [-1.4193]]), tensor([[-1.9236],
        [ 0.6664],
        [ 0.1681]]))
>>> torch.chunk(x, 2, dim=1)
(tensor([[ 1.0103,  2.3358],
        [-0.3890,  0.6594],
        [ 0.5240, -1.4193]]), tensor([[-1.9236],
        [ 0.6664],
        [ 0.1681]]))

 

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐