上三角 triu

Pytorch上三角和下三角的调用与numpy是相同的。

np.triu(np.ones((5,5)),k=0) # k控制对角线开始的位置
Out[25]: 
array([[1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 1.]])

构建一个上三角mask

torch.triu(torch.ones(5,5),diagonal=0)
Out[17]: 
tensor([[1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.]])

dianonal控制上三角的对角线开始位置

torch.triu(torch.ones(5,5),diagonal=1)
Out[20]: 
tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])

下三角 tril

torch.tril(torch.ones(5,5),diagonal=0)
Out[21]: 
tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

torch.tril(torch.ones(5,5),diagonal=1)
Out[22]: 
tensor([[1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐