HF-Net Usage Notes



HF-Net is used for relocalization: given a query frame, it searches a database of images to find the one that matches it.

Project address: https://github.com/ethz-asl/hfnet

Environment Setup

  • First, the authors provide a pretrained model for download.
  • The download is slow, so I am sharing a Baidu Cloud mirror (extraction code: qiis).
  • After downloading, simply place it in the saved_models directory.
  • The required TensorFlow version is 1.12.
  • In addition, cuDNN must be 7.1.4; this is what the runtime error message indicated.
  • CUDA 9.0 is fine (a quick version check is sketched right after this list).
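
Before running the demo, it can help to confirm that the installed TensorFlow is really the GPU build compiled against the expected CUDA/cuDNN. The following is only a rough sketch I am adding here (the build_info attribute names are an assumption and may differ between TensorFlow builds):

import tensorflow as tf

print('TensorFlow version:', tf.__version__)          # expected: 1.12.x
print('Built with CUDA:', tf.test.is_built_with_cuda())
print('GPU available:', tf.test.is_gpu_available())

# The CUDA/cuDNN versions TensorFlow was built against are recorded in the
# build info; attribute names can differ between builds, hence the getattr.
try:
    from tensorflow.python.platform import build_info
    print('CUDA:', getattr(build_info, 'cuda_version_number', 'unknown'))
    print('cuDNN:', getattr(build_info, 'cudnn_version_number', 'unknown'))
except ImportError:
    print('build_info not available in this TensorFlow build')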

Demo Script


# coding: utf-8
import cv2
import numpy as np
from pathlib import Path
import pickle

from hfnet.settings import EXPER_PATH
from notebooks.utils import plot_images, plot_matches, add_frame

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
import tensorflow.contrib.resampler  # imports the C++ op -- !! do not delete this line

# Load query (night) and database (day) images
query_idx = 1  # also try with 2 and 3
read_image = lambda n: cv2.imread('./doc/demo/' + n)[:, :, ::-1]
image_query = read_image(f'query{query_idx}.jpg')
images_db = [read_image(f'db{i}.jpg') for i in range(1, 5)]

plot_images([image_query] + images_db, dpi=50)

# Create HF-Net model for inference
class HFNet:
    def __init__(self, model_path, outputs):
        self.session = tf.Session()
        self.image_ph = tf.placeholder(tf.float32, shape=(None, None, 3))

        net_input = tf.image.rgb_to_grayscale(self.image_ph[None])
        tf.saved_model.loader.load(
            self.session, [tag_constants.SERVING], str(model_path),
            clear_devices=True,
            input_map={'image:0': net_input})

        graph = tf.get_default_graph()
        self.outputs = {n: graph.get_tensor_by_name(n + ':0')[0] for n in outputs}
        self.nms_radius_op = graph.get_tensor_by_name('pred/simple_nms/radius:0')
        self.num_keypoints_op = graph.get_tensor_by_name('pred/top_k_keypoints/k:0')

    def inference(self, image, nms_radius=4, num_keypoints=1000):
        inputs = {
            self.image_ph: image[..., ::-1].astype(np.float),
            self.nms_radius_op: nms_radius,
            self.num_keypoints_op: num_keypoints,
        }
        return self.session.run(self.outputs, feed_dict=inputs)

model_path = Path(EXPER_PATH, 'saved_models/hfnet')
outputs = ['global_descriptor', 'keypoints', 'local_descriptors']
hfnet = HFNet(model_path, outputs)

# Compute global descriptors and local features for query and database
db = [hfnet.inference(i) for i in images_db]
global_index = np.stack([d['global_descriptor'] for d in db])
query = hfnet.inference(image_query)

# Perform a global search in the database
def compute_distance(desc1, desc2):
    # For normalized descriptors, computing the distance is a simple matrix multiplication.
    return 2 * (1 - desc1 @ desc2.T)

nearest = np.argmin(compute_distance(query['global_descriptor'], global_index))
disp_db = [add_frame(im, (0, 255, 0)) if i == nearest else im
           for i, im in enumerate(images_db)]
plot_images([image_query] + disp_db, dpi=50)

# Perform local matching with the retrieved image
def match_with_ratio_test(desc1, desc2, thresh):
    dist = compute_distance(desc1, desc2)
    print(dist.shape)
    nearest = np.argpartition(dist, 2, axis=-1)[:, :2]
    dist_nearest = np.take_along_axis(dist, nearest, axis=-1)
    valid_mask = dist_nearest[:, 0] <= (thresh**2) * dist_nearest[:, 1]
    matches = np.stack([np.where(valid_mask)[0], nearest[valid_mask][:, 0]], 1)
    return matches

matches = match_with_ratio_test(query['local_descriptors'],
                                db[nearest]['local_descriptors'], 0.8)
plot_matches(image_query, query['keypoints'],
             images_db[nearest], db[nearest]['keypoints'],
             matches, color=(0, 1, 0), dpi=50)
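
Extracting the database descriptors is the slow part and only needs to happen once, so it can be convenient to cache the inference results with pickle (which the demo imports but never uses). Below is a minimal caching variant of the db/global_index lines above; the db_cache.pkl path is my own choice, not part of the original demo:

# Cache the per-image outputs of hfnet.inference() so the database
# only has to be processed once.
cache_path = Path('db_cache.pkl')  # arbitrary cache location

if cache_path.exists():
    with open(cache_path, 'rb') as f:
        db = pickle.load(f)
else:
    db = [hfnet.inference(i) for i in images_db]
    with open(cache_path, 'wb') as f:
        pickle.dump(db, f)

# Rebuild the global retrieval index from the cached descriptors.
global_index = np.stack([d['global_descriptor'] for d in db])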

Errors

  • On my machine the retrieval step runs fine, but during local feature matching the np.take_along_axis call always fails; this is most likely a NumPy version issue. A possible workaround is sketched below.
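
np.take_along_axis was only added in NumPy 1.15, so an older NumPy raises an AttributeError on that line. A possible workaround (my own sketch, not part of the original demo) is to replace the call with plain integer fancy indexing, which is equivalent here because dist is 2-D:

# Drop-in replacement for match_with_ratio_test that avoids np.take_along_axis.
def match_with_ratio_test(desc1, desc2, thresh):
    dist = compute_distance(desc1, desc2)
    nearest = np.argpartition(dist, 2, axis=-1)[:, :2]
    # Equivalent to np.take_along_axis(dist, nearest, axis=-1) for a 2-D array.
    dist_nearest = dist[np.arange(dist.shape[0])[:, None], nearest]
    valid_mask = dist_nearest[:, 0] <= (thresh**2) * dist_nearest[:, 1]
    matches = np.stack([np.where(valid_mask)[0], nearest[valid_mask][:, 0]], 1)
    return matches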
