PyTorch搜索Tensor指定维度的前K大个(K小个)元素--------(torch.topk)命令参数详解及举例
torch.topk语法torch.topk(input, k, dim=None, largest=True, sorted=True, *, out = None)作用返回输入tensorinput中,在给定的维度dim上k个最大的元素。如果dim没有给定,那么选择输入input的最后一维。如果largest = False,那么返回k个最小的元素。返回一个namedtuple类型的元组(va
torch.topk
语法
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out = None)
作用
返回输入tensorinput中,在给定的维度dim上k个最大的元素。
如果dim没有给定,那么选择输入input的最后一维。
如果largest = False,那么返回k个最小的元素。
返回一个namedtuple类型的元组(values, indices),其中indices是指元素在原数组中的索引。
sorted = True, 则返回的k个元素是有序的。
Parameters
-
input (Tensor) – the input tensor
输入的张量 -
k (int) – the k in “top-k”
返回的k的值 -
dim(int, optional) – the dimension to sort along
指定的排序的维度 ,如果dim没有给定,那么选择输入input的最后一维。dim若为-1,文档未说明,但是根据实操效果,应该也是对最后一维进行search。
如shape为Batch_size x p x q,返回结果为Batch_size x p x k。 -
largest(bool, optional) – controls whether to return largest or smallest elements
True返回最大值,False返回最小值。 -
sorted(bool, optional) – controls whether to return the elements in sorted order
控制返回的元素是否排序。
例子
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1., 2., 3., 4., 5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
更多推荐


所有评论(0)