Pytorch统计二维张量每一行的非零个数
假设一个2维的pytorch张量a,有m行,n列,想求得a的每一行非零的元素个数,最终得到m行1列的张量
·
一、期望
输入
tensor([[0, 0, 0],
[1, 1, 1],
[0, 0, 1]])
输出
tensor([[0],
[3],
[1]])
二、方法
代码
import torch
# input
a = torch.tensor([[0.0, 0.0, 0.0],
[1.0, 3.0, 10.5],
[0.0, 0.0, 2.0]])
sign_a = torch.sign(a).int()
print(sign_a)
# count
non_zero_a = torch.count_nonzero(sign_a, dim=1).reshape(-1, 1)
# output
print(non_zero_a)
输出
tensor([[0, 0, 0],
[1, 1, 1],
[0, 0, 1]], dtype=torch.int32)
tensor([[0],
[3],
[1]])
更多推荐

所有评论(0)