匹配网络 Matching Network
匹配网络 Matching Network匹配网络其实就是引入注意力机制,通过对 embedding 后的特征计算注意力,利用注意力得分进行分析:首先也是对支持集和查询集进行 embedding,然后用查询集样本对每个支持集样本计算注意力:a(x^,xi)=ec(f(x^),g(xi))/∑j=1kec(f(x^),g(xj))a\left(\hat{x}, x_{i}\right)=e^{c\l
匹配网络 Matching Network
匹配网络其实就是引入注意力机制,通过对 embedding 后的特征计算注意力,利用注意力得分进行分析:
首先也是对支持集和查询集进行 embedding,然后用查询集样本对每个支持集样本计算注意力:
a ( x ^ , x i ) = e c ( f ( x ^ ) , g ( x i ) ) / ∑ j = 1 k e c ( f ( x ^ ) , g ( x j ) ) a\left(\hat{x}, x_{i}\right)=e^{c\left(f(\hat{x}), g\left(x_{i}\right)\right)} / \sum_{j=1}^{k} e^{c\left(f(\hat{x}), g\left(x_{j}\right)\right)} a(x^,xi)=ec(f(x^),g(xi))/j=1∑kec(f(x^),g(xj))
其中:
- f 和 g是我们选择的合适的神经网络,一般 f = g,用于输入的 embedding
- x i x_i xi 是支持集, x ^ \hat x x^ 是查询集
- c 是余弦距离
计算了注意力之后,就分析查询集的样本:
P ( y ^ ∣ x ^ , S ) = ∑ i = 1 k a ( x ^ , x i ) y i P(\hat{y} \mid \hat{x}, S)=\sum_{i=1}^{k} a\left(\hat{x}, x_{i}\right) y_{i} P(y^∣x^,S)=i=1∑ka(x^,xi)yi
其中:
- y i y_i yi 是每个类别的标签,其实就是把每个类别根据注意力得分进行线性加权
- P 是计算出对应类别的概率
最后的训练目标为:
θ = arg max θ E L ∼ T [ E S ∼ L , B ∼ L [ ∑ ( x , y ) ∈ B log P θ ( y ∣ x , S ) ] ] \theta=\arg \max _{\theta} E_{L \sim T}\left[E_{S \sim L, B \sim L}\left[\sum_{(x, y) \in B} \log P_{\theta}(y \mid x, S)\right]\right] θ=argθmaxEL∼T⎣⎡ES∼L,B∼L⎣⎡(x,y)∈B∑logPθ(y∣x,S)⎦⎤⎦⎤
个人总结:
总的来说,匹配网络把整个分析的过程都简化到注意力计算过程中,如果某个类别的注意力得分比较高,其实就意味着测试样本属于这个类别的可能性比较大,所以模型的训练重点就回到最初的 embedding 中。
Vinyals O, Blundell C, Lillicrap T, et al. Matching networks for one shot learning[J]. Advances in neural information processing systems, NIPS 2016, 29: 3630-3638.
元学习系列(四):Matching Network(匹配网络)
更多推荐


所有评论(0)