不用训练,直接用在ImageNet上训练的RestNet网络就可以做一个简单的以图搜图功能

1、RestNet网络

在这里插入图片描述
上面是resnet网络的结构图,ImageNet是一个有1000类的数据集,我们可以把在该数据集上训练过的Resnet网络当成特征提取器,用来提取图片的特征,然后比对特征的欧式距离,判定两种图片的相似性。
average pool的输出是512x1x1,将其reshape为512x1当作该图片的特征。

2、生成RestNet网络

arch_name = “resnet18"
pretrained = True,加载预训练模型

self.retriever_net = torchvision.models.__dict__[arch_name](pretrained = pretrained)

3、提取average pool层的输出

feature_layer_name = ‘avgpool’
feature_index_in_module = 0
register_forward_hook函数,forward时负责保存某个模块的输出

self.feature_layer_name = feature_layer_name
self.feature_index_in_module = feature_index_in_module   self.retriever_net._modules.get(self.feature_layer_name).register_forward_hook(self.hook_feature)

4、比对两个特征

使用欧式距离

dist = F.pairwise_distance(contrast_features, retrieved_features, p=2)

5、贴一张我简陋的gui界面,要捂脸啦

在这里插入图片描述左边的load是加载一张对比图片,右边的load是加载一个文件夹,点Retriever,开始在加载的文件夹中查找和对比图片top1相似的图片,然后显示出来。

两天里面挤时间写的,所以功能很简单,gui很简陋,自己倒是觉得挺有意思的,给大家抛砖引玉吧,全部代码在这里。
image_retrieval_with_gui

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐