VGG提取特征向量结合ES实现以图搜图

2/5/2022 VGG卷积神经网络提取特征向量向量检索以图搜图Milvus

# 1.前言

# 1.1 以图搜图原理概述

图像检索过程简单说来就是对图片数据库的每张图片抽取特征(一般为特征向量),存储于数据库中,对于待检索图片,抽取同样的特征向量,然后并对该向量和数据库中向量的距离(相似度计算),找出最接近的一些特征向量,其对应的图片即为检索结果。

以图搜图原理

# 1.2 VGG卷积神经网络

# 1.2.1 什么是VGG

随着AlexNet在2012年的ImageNet大赛上大放异彩后,卷积神经网络进入了飞速发展的阶段。2014年,由Simonyan和Zisserman提出的VGG网络在ImageNet上取得了亚军的成绩。

VGG的命名来源于论文作者所在的实验室Visual Geometry Group,其对卷积神经网络进行了改良,探索了网络深度与性能的关系,用更小的卷积核和更深的网络结构,取得了较好的效果,成为了CNN发展史上较为重要的一个网络。

VGG中使用了一系列大小为3x3的小尺寸卷积核和池化层构造深度卷积神经网络,因为其结构简单、应用性极强而广受研究者欢迎,尤其是它的网络结构设计方法,为构建深度神经网络提供了方向。

VGG有两种结构,分别是VGG16和VGG19,两者并没有本质上的区别,只是网络深度不一样。

论文地址:https://arxiv.org/abs/1409.1556 (opens new window)

# 1.2.2 VGG基本原理

下图是VGG-16的网络结构示意图,有13层卷积和3层全连接层。VGG网络的设计严格使用3×33×3的卷积层和池化层来提取特征,并在网络的最后面使用三层全连接层,将最后一层全连接层的输出作为分类的预测。

VGG模型网络结构示意图

VGG中还有一个显著特点:每次经过池化层(maxpooling)后特征图的尺寸减小一倍,而通道数增加一倍(最后一个池化层除外)。

在VGG中每层卷积将使用ReLU作为激活函数,在全连接层之后添加dropout来抑制过拟合。使用小的卷积核能够有效地减少参数的个数,使得训练和测试变得更加有效。比如使用两层3×33×3 卷积层,可以得到感受野为5的特征图,而比使用5×55×5的卷积层需要更少的参数。由于卷积核比较小,可以堆叠更多的卷积层,加深网络的深度,这对于图像分类任务来说是有利的。VGG模型的成功证明了增加网络的深度,可以更好的学习图像中的特征模式。

# 1.2.3 VGG模型特点

  • 整个网络都使用了同样大小的卷积核尺寸3×33×3和最大池化尺寸2×22×2。
  • 1×11×1卷积的意义主要在于线性变换,而输入通道数和输出通道数不变,没有发生降维。
  • 两个3×33×3的卷积层串联相当于1个5×55×5的卷积层,感受野大小为5×55×5。同样地,3个3×33×3的卷积层串联的效果则相当于1个7×77×7的卷积层。这样的连接方式使得网络参数量更小,而且多层的激活函数令网络对特征的学习能力更强。
  • VGGNet在训练时有一个小技巧,先训练浅层的的简单网络VGG11,再复用VGG11的权重来初始化VGG13,如此反复训练并初始化VGG19,能够使训练时收敛的速度更快。
  • 在训练过程中使用多尺度的变换对原始数据做数据增强,使得模型不易过拟合。

# 1.2.4 VGG模型指标

VGG 在 2014 年的 ImageNet 比赛上取得了亚军的好成绩,具体指标如下图所示。第一行为在 ImageNet 比赛中的指标,测试集的Error rate达到了7.3%,在论文中,作者对算法又进行了一定的优化,最终可以达到 6.8% 的Error rate。

VGG模型指标

# 2. VGG提取特征向量实现以图搜图

# 2.1 VGG提取特征向量结合余弦相似度实现

可以基于VGG提取图片特征向量,然后计算余弦相似度,直接实现本地的以图搜图。

# -*- coding: utf-8 -*-

import os
import numpy as np
from keras.preprocessing import image
from keras.applications.vgg19 import VGG19
from keras.applications.vgg19 import preprocess_input


# 使用VGG19提取图片特征向量
class VGG19Net:
    def __init__(self):
        # weights: 'imagenet'
        # pooling: 'max' or 'avg'
        # input_shape: (width, height, 3), width and height should >= 48
        # self.input_shape = (224, 224, 3)
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.include_top = False
        # include_top:是否保留顶层的3个全连接网络
        # weights:None代表随机初始化,即不加载预训练权重。'imagenet' 代表加载预训练权重
        # input_tensor:可填入Keras tensor作为模型的图像输出tensor
        # input_shape:可选,仅当include_top=False有效,应为长为3的tuple,指明输入图片的shape,图片的宽高必须大于48,如(200,200,3)
        # pooling:当include_top = False时,该参数指定了池化方式。None代表不池化,最后一个卷积层的输出为4D张量。‘avg’代表全局平均池化,‘max’代表全局最大值池化。
        # classes:可选,图片分类的类别数,仅当include_top = True并且不加载预训练权重时可用。
        self.model_vgg = VGG19(include_top=self.include_top, weights=self.weight,
                               input_shape=self.input_shape, pooling=self.pooling)
        self.model_vgg.predict(np.zeros((1, 224, 224, 3)))

    # 提取vgg19最后一层卷积特征
    def my_model(self, img_path):
        img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        feat = self.model_vgg.predict(img)
        # print(feat.shape)
        # norm_feat = feat[0] / LA.norm(feat[0])
        return feat[0]

    def main_model(self, img):
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        feat = self.model_vgg.predict(img)
        # print(feat.shape)
        # norm_feat = feat[0] / LA.norm(feat[0])
        return feat[0]


# 计算余弦相似度
def cosine_similarity(A, B):
    return np.dot(A, B) / (np.linalg.norm(A) * np.linalg.norm(B))


if __name__ == '__main__':
    vgg19_net = VGG19Net()

    # 提取目标图像的特征向量
    target_image_dir = "./images"
    target_images = [os.path.join(target_image_dir, name) for name in os.listdir(target_image_dir)]
    target_features = []
    for path in target_images:
        feat = vgg19_net.my_model(path)
        target_features.append(feat)

    # 提取查询图像的特征向量
    query_path = './query.jpg'
    query_feat = vgg19_net.my_model(query_path)

    # 计算余弦相似度并筛选结果
    threshold = 0.6
    results_with_similarity = []
    for i, target_feat in enumerate(target_features):
        similarity = cosine_similarity(query_feat, target_feat)
        if similarity >= threshold:
            results_with_similarity.append((target_images[i], similarity))

    # 按相似度从大到小排序
    sorted_results = sorted(results_with_similarity, key=lambda x: x[1], reverse=True)

    # 打印排序后的结果
    for path, similarity in sorted_results:
        print(f"Image: {path}, Similarity: {similarity}")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85

# 2.2 封装VGG服务结合ES的向量检索实现

封装代码我已在Github上开源,项目地址:https://github.com/Logistic98/yoyo-algorithm/tree/master/image-feature-vector (opens new window)

VGG算法实现及生成测试数据存入ES的部分,项目结构如下:

.
├── Dockerfile
├── build.sh
├── code.py
├── extract_vgg19_keras.py
├── log.py
├── requirements.txt
├── response.py
├── server.py
└── test_code
    ├── clear_index_data.sh
    ├── create_index.sh
    ├── image-feature-vector-test.py
    ├── save_feature_vector_to_es.py
    └── test_img
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

# 2.2.1 使用Flask封装VGG特征向量提取服务

[1] VGG服务封装

extract_vgg19_keras.py

# -*- coding: utf-8 -*-

import numpy as np
from keras.preprocessing import image
from keras.applications.vgg19 import VGG19
from keras.applications.vgg19 import preprocess_input


class VGG19Net:
    def __init__(self):
        # weights: 'imagenet'
        # pooling: 'max' or 'avg'
        # input_shape: (width, height, 3), width and height should >= 48
        # self.input_shape = (224, 224, 3)
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.include_top = False
        # include_top:是否保留顶层的3个全连接网络
        # weights:None代表随机初始化,即不加载预训练权重。'imagenet' 代表加载预训练权重
        # input_tensor:可填入Keras tensor作为模型的图像输出tensor
        # input_shape:可选,仅当include_top=False有效,应为长为3的tuple,指明输入图片的shape,图片的宽高必须大于48,如(200,200,3)
        # pooling:当include_top = False时,该参数指定了池化方式。None代表不池化,最后一个卷积层的输出为4D张量。‘avg’代表全局平均池化,‘max’代表全局最大值池化。
        # classes:可选,图片分类的类别数,仅当include_top = True并且不加载预训练权重时可用。
        self.model_vgg = VGG19(include_top=self.include_top, weights=self.weight,
                               input_shape=self.input_shape, pooling=self.pooling)
        self.model_vgg.predict(np.zeros((1, 224, 224, 3)))

    # 提取vgg19最后一层卷积特征
    def my_model(self, img_path):
        img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        feat = self.model_vgg.predict(img)
        # print(feat.shape)
        # norm_feat = feat[0] / LA.norm(feat[0])
        return feat[0]

    def main_model(self, img):
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        feat = self.model_vgg.predict(img)
        # print(feat.shape)
        # norm_feat = feat[0] / LA.norm(feat[0])
        return feat[0]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

server.py

# -*- coding: utf-8 -*-

from flask import Flask, jsonify
from flask_cors import CORS
from pre_request import pre, Rule
import base64
import numpy as np
import cv2
import extract_vgg19_keras as search

from code import ResponseCode, ResponseMessage
from log import logger

app = Flask(__name__)
CORS(app, supports_credentials=True)
model = search.VGG19Net()


def base64_cv2(base64_str):
    imgString = base64.b64decode(base64_str)
    nparr = np.fromstring(imgString, np.uint8)
    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (224, 224))
    return image


@app.route('/imageFeatureVector/calFeatureVector', methods=['post'])
def imageFeatureVector():

    # 参数校验并获取参数
    rule = {
        "img": Rule(type=str, required=True)
    }
    try:
        params = pre.parse(rule=rule)
        image_b64 = params.get("img")
    except Exception as e:
        logger.error(e)
        fail_response = dict(code=ResponseCode.RARAM_FAIL, msg=ResponseMessage.RARAM_FAIL, data=None)
        logger.error(fail_response)
        return jsonify(fail_response)

    # 计算图片特征向量
    try:
        image_file = base64_cv2(image_b64)
        newvector = model.main_model(image_file).tolist()
    except Exception as e:
        logger.error(e)
        fail_response = dict(code=ResponseCode.BUSINESS_FAIL, msg=ResponseMessage.BUSINESS_FAIL, data=None)
        logger.error(fail_response)
        return jsonify(fail_response)

    # 成功的结果返回
    success_response = dict(code=ResponseCode.SUCCESS, msg=ResponseMessage.SUCCESS, data=newvector)
    logger.info(success_response)
    return jsonify(success_response)


if __name__ == '__main__':
    # 解决中文乱码问题
    app.config['JSON_AS_ASCII'] = False
    # 启动服务 指定主机和端口
    app.run(host='0.0.0.0', port=5004, debug=False)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

log.py

# -*- coding: utf-8 -*-

import logging

logger = logging.getLogger(__name__)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# 输出到控制台
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(formatter)
logger.addHandler(console)

# 输出到文件
logger.setLevel(level=logging.INFO)
handler = logging.FileHandler("./server.log")
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)
logger.addHandler(handler)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

response.py

# -*- coding: utf-8 -*-

from code import ResponseMessage, ResponseCode


class ResMsg(object):
    """
    封装响应文本
    """
    def __init__(self, data=None, code=ResponseCode.SUCCESS, msg=ResponseMessage.SUCCESS):
        self._data = data
        self._msg = msg
        self._code = code

    def update(self, code=None, data=None, msg=None):
        """
        更新默认响应文本
        :param code:响应状态码
        :param data: 响应数据
        :param msg: 响应消息
        :return:
        """
        if code is not None:
            self._code = code
        if data is not None:
            self._data = data
        if msg is not None:
            self._msg = msg

    def add_field(self, name=None, value=None):
        """
        在响应文本中加入新的字段,方便使用
        :param name: 变量名
        :param value: 变量值
        :return:
        """
        if name is not None and value is not None:
            self.__dict__[name] = value

    @property
    def data(self):
        """
        输出响应文本内容
        :return:
        """
        body = self.__dict__
        body["data"] = body.pop("_data")
        body["msg"] = body.pop("_msg")
        body["code"] = body.pop("_code")
        return body
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

code.py

# -*- coding: utf-8 -*-


class ResponseCode(object):
    SUCCESS = 200
    RARAM_FAIL = 400
    BUSINESS_FAIL = 500


class ResponseMessage(object):
    SUCCESS = "请求成功"
    RARAM_FAIL = "参数校验失败"
    BUSINESS_FAIL = "业务处理失败"
1
2
3
4
5
6
7
8
9
10
11
12
13

[2] 编写测试程序

image-feature-vector-test.py

# -*- coding: utf-8 -*-

import base64
import requests
import json


def image_feature_vector_test():
    # 测试请求
    url = 'http://{0}:{1}/imageFeatureVector/calFeatureVector'.format("127.0.0.1", "5004")
    f = open('./test_img/001.jpg', 'rb')
    # base64编码
    base64_data = base64.b64encode(f.read())
    f.close()
    base64_data = base64_data.decode()
    # 传输的数据格式
    data = {'img': base64_data}
    # post传递数据
    headers = {'Content-Type': 'application/json'}
    r = requests.post(url, headers=headers, data=json.dumps(data))
    print(r.text)


if __name__ == '__main__':
    image_feature_vector_test()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

[3] 编写部署脚本

requirements.txt

opencv_python
pre-request==2.1.5
numpy==1.20.1
Flask==2.0.2
Flask_Cors==3.0.6
keras==2.6.0
requests==2.26.0
1
2
3
4
5
6
7

Dockerfile

# 基于python3.7镜像创建新镜像
FROM python:3.7
# 创建容器内部目录
RUN mkdir /code
# 将项目复制到内部目录
ADD . /code/
# 切换到工作目录
WORKDIR /code
# 安装项目依赖
RUN pip install -r requirements.txt
RUN pip install https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-2.6.0-cp37-cp37m-manylinux2010_x86_64.whl
RUN apt update && apt install libgl1-mesa-glx -y
# 放行端口
EXPOSE 5004
# 启动项目
ENTRYPOINT ["nohup","python","server.py","&"]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

build.sh

docker build -t image-feature-vector-image .
docker run -d -p 5004:5004 --name image-feature-vector image-feature-vector-image:latest
docker update image-feature-vector --restart=always
1
2
3

# 2.2.2 生成测试数据并存入ES作为检索库

[1] ElasticSearch向量检索

ES 7.X 版本引入了向量类型dense_vector,用于存储浮点类型的密集向量,其最大维度为2048。其用作是可以将待查询向量和文档内存储向量之间的距离作为查询评分使用,即越相似的向量评分越高。使用方式为在 query 的script_score中指定向量的计算方式,具体有四种:

cosineSimilarity – 余弦函数
dotProduct – 向量点积
l1norm – 曼哈顿距离
l2norm - 欧几里得距离
1
2
3
4

[2] 创建ElasticSearch索引

create_index.sh

curl -u your_user:your_password -XPUT 'http://127.0.0.1:9200/search_by_image_index' -H 'Content-Type: application/json' -d '{ "mappings": { "properties": { "file_path":{ "type": "keyword" }, "feature_vector":{ "type": "dense_vector", "dims": 512 } } } }'
1

用于测试ES向量检索实现以图搜图的索引

注:用于存储图片特征向量的那个字段,类型必须是dense_vector,需要手动指定,不能用插入时的默认类型,不然查询时会出错。

[3] 生成测试数据并存入ES

save_feature_vector_to_es.py

# -*- coding: utf-8 -*-

import base64
import logging

from elasticsearch import Elasticsearch, helpers
import requests
import json
import os

# 生成日志文件
logging.basicConfig(filename='save_feature_vector_to_es.log', level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def find_filepaths(dir):
    """
    级联遍历目录,获取目录下的所有文件路径
    """
    result = []
    for root, dirs, files in os.walk(dir):
        for name in files:
            filepath = os.path.join(root, name)
            if os.path.exists(filepath):
                result.append(filepath)
    return result


def write_list_to_json(list, json_file_name, json_file_save_path):
    """
    将list写入到json文件
    :param list:
    :param json_file_name: 写入的json文件名字
    :param json_file_save_path: json文件存储路径
    :return:
    """
    if not os.path.exists(json_file_save_path):
        os.makedirs(json_file_save_path)
    os.chdir(json_file_save_path)
    with open(json_file_name, 'w', encoding='utf-8') as f:
        json.dump(list, f, ensure_ascii=False)


# 将构造好的json列表写入ES数据库
def save_data_to_es(json_list, index_name):
    Es = Elasticsearch(
        hosts=['127.0.0.1:9200'],
        http_auth=('your_user', 'your_password'),
        timeout=60
    )
    # 按照步长分批插入数据库,缓解插入数据库时的压力
    length = len(json_list)
    # 步长为1000,缓解批量写入的压力
    step = 1000
    for i in range(0, length, step):
        # 要写入的数据长度大于步长,那么久分批写入
        if i + step < length:
            actions = []
            for j in range(i, i + step):
                action = {
                    "_index": str(index_name),
                    "_source": json_list[j]
                }
                actions.append(action)
            helpers.bulk(Es, actions, request_timeout=120)
        # 要写入的数据小于步长,那么久一次性写入
        else:
            actions = []
            for j in range(i, length):
                action = {
                    "_index": str(index_name),
                    "_source": json_list[j]
                }
                actions.append(action)
            helpers.bulk(Es, actions, request_timeout=120)


# 调用VGG算法获取图片特征向量
def image_feature_vector(img_path, vgg_url):
    f = open(img_path, 'rb')
    # base64编码
    base64_data = base64.b64encode(f.read())
    f.close()
    base64_data = base64_data.decode()
    # 传输的数据格式
    data = {'img': base64_data}
    # post传递数据
    headers = {'Content-Type': 'application/json'}
    r = requests.post(vgg_url, headers=headers, data=json.dumps(data))
    return json.loads(r.text)['data']


if __name__ == '__main__':

    img_base_path = './test_img'
    img_path_list = find_filepaths(img_base_path)
    vgg_url = 'http://{0}:{1}/imageFeatureVector/calFeatureVector'.format("127.0.0.1", "5004")
    json_list = []
    for img_path in img_path_list:
        feature_vector = image_feature_vector(img_path, vgg_url)
        # print(feature_vector)
        json_item = {}
        json_item['file_path'] = img_path
        json_item['feature_vector'] = feature_vector
        json_list.append(json_item)
    index_name = 'search_by_image_index'
    save_data_to_es(json_list, index_name)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108

在test_img目录下放置用于生成检索库的测试图片,执行上面的python程序即可。

用于测试ES向量检索实现以图搜图的数据

如果需要清空数据(不删除索引),可以使用如下脚本:

clear_index_data.sh

curl -u your_user:your_password -XPOST 'http://127.0.0.1:9200/search_by_image_index/_delete_by_query?refresh&slices=5&pretty' -H 'Content-Type: application/json' -d'{"query": {"match_all": {}}}'
1

# 2.2.3 在业务系统中使用ES实现以图搜图

关于Elasticsearch的搭建及基本使用,这里就不赘述了,如果不会的话见我的另一篇博客:Elasticsearch整合及基本操作示例 (opens new window)

[1] 使用ES向量检索实现以图搜图

application.properties 添加VGG算法服务的配置

settings.service.vgg-url=http://127.0.0.1:5004/imageFeatureVector/calFeatureVector
1

ElasticSearchService.java 添加向量检索的通用封装

    /**
     * 根据图片特征向量实现以图搜图(字段类型需为dense_vector)
     * @param indexName
     * @param scriptScoreQueryBuilder
     * @param score
     * @param page
     * @param rows
     * @param includeFields
     * @param excludeFields
     * @return
     */
    public SearchResponse imageSearch(String indexName, ScriptScoreQueryBuilder scriptScoreQueryBuilder, Float score, Integer page, Integer rows, String[] includeFields, String[] excludeFields) {
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().trackTotalHits(true);
        sourceBuilder.query(scriptScoreQueryBuilder);
        sourceBuilder.from((page - 1) * rows);
        sourceBuilder.size(rows);
        sourceBuilder.minScore(score);
        sourceBuilder.timeout(new TimeValue(120, TimeUnit.SECONDS));
        sourceBuilder.fetchSource(includeFields, excludeFields);
        return pageQuerySearchResponse(sourceBuilder, indexName);
    }
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

ElasticSearchDemoController.java 添加以图搜图业务实现的Controller

    @ApiOperation("根据图片特征向量实现以图搜图")
    @ApiImplicitParams({
            @ApiImplicitParam(name = "file", value = "图片文件", dataType = "file", paramType = "query"),
            @ApiImplicitParam(name = "score", value = "相似度值", dataType = "float", paramType = "query"),
            @ApiImplicitParam(name = "page", value = "第几页", dataType = "Integer", paramType = "query"),
            @ApiImplicitParam(name = "pageSize", value = "分页大小", dataType = "Integer", paramType = "query")
    })
    @RequestMapping(value = "/searchByImage", method = RequestMethod.POST)
    public ResponseEntity<?> searchByImage(@RequestParam MultipartFile file,
                                           @RequestParam(defaultValue = "0.8") Float score,
                                           @RequestParam(defaultValue = "1") Integer page,
                                           @RequestParam(defaultValue = "10") Integer pageSize) {
        try {
            return ResultDataUtils.success(elasticSearchDemoService.searchByImage(file, score, page, pageSize));
        } catch (Exception ex) {
            ex.printStackTrace();
            return ResultDataUtils.error(ex.getMessage());
        }
    }
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

ElasticSearchDemoService.java 添加以图搜图业务实现的Service

    @Value("${settings.service.vgg-url}")
    private String vggUrl;
    
    ...
    
    /**
     * 根据图片特征向量实现以图搜图
     * @param file
     * @param score
     * @param page
     * @param pageSize
     * @return
     */
    public HashMap<String, Object> searchByImage(MultipartFile file, Float score, Integer page, Integer pageSize) {

        String base64 = null;
        JSONObject paramObject = new JSONObject();
        try {
            base64 = Base64.encode(file.getInputStream());
        } catch (IOException e) {
            e.printStackTrace();
            System.out.println("上传图片转base64错误");
        }
        if (ObjectUtil.isNotEmpty(base64)) {
            paramObject.put("img", base64);
        }

        String vggResult = "";
        try {
            vggResult = HttpUtil.post(vggUrl, paramObject.toString());
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("获取上传图片的特征向量错误");
        }

        SearchResponse searchResponse = null;
        try {
            if (vggResult.contains("code") && vggResult.contains("data")) {
                JSONObject resultObject = JSONUtil.parseObj(vggResult);
                JSONArray featureVector = resultObject.getJSONArray("data");
                HashMap<String, Object> esParams = new HashMap();
                esParams.put("queryVector", featureVector);
                Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "cosineSimilarity(params.queryVector, 'feature_vector')", esParams);
                ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(QueryBuilders.existsQuery("feature_vector"), script);
                searchResponse = elasticSearchService.imageSearch("search_by_image_index", scriptScoreQueryBuilder, score, page, pageSize, null, null);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        SearchHit[] searchHits = searchResponse.getHits().getHits();
        HashMap<String, Object> result = new HashMap<>();
        Long total = searchResponse.getHits().getTotalHits().value;
        List<Map<String, Object>> dataList = new ArrayList<>();
        for (SearchHit searchHit : searchHits) {
            Map<String, Object> sourceAsMap = searchHit.getSourceAsMap();
            dataList.add(sourceAsMap);
        }
        result.put("dataList", dataList);
        result.put("total", total);
        return result;

    }
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

[2] 测试以图搜图功能的实现

从原先的检索底库里找张图,简单处理一下,然后拿来测试。接口共接收4个参数,其中file是图片文件,score是相似度值,另外两个是分页参数。

测试以图搜图接口

# 3. 基于Milvus搭建以图搜图系统

# 3.1 Milvus基本介绍

# 3.1.1 什么是Milvus

Milvus 是一款云原生向量数据库,它具备高可用、高性能、易拓展的特点。Milvus 基于 FAISS、Annoy、HNSW 等向量搜索库构建,核心是解决稠密向量相似度检索的问题。在向量检索库的基础上,Milvus 支持数据分区分片、数据持久化、增量数据摄取、标量向量混合查询、time travel 等功能,同时大幅优化了向量检索的性能,可满足任何向量检索场景的应用需求。通常,建议用户使用 Kubernetes 部署 Milvus,以获得最佳可用性和弹性。

Milvus 采用共享存储架构,存储计算完全分离,计算节点支持横向扩展。从架构上来看,Milvus 遵循数据流和控制流分离,整体分为了四个层次,分别为接入层(access layer)、协调服务(coordinator service)、执行节点(worker node)和存储层(storage)。各个层次相互独立,独立扩展和容灾。

Milvus工作流程

# 3.1.2 为什么选择Milvus

Milvus 向量数据库专为向量查询与检索设计,能够为万亿级向量数据建立索引,它具备以下特点。

  • 高性能:性能高超,可对海量数据集进行向量相似度检索。
  • 高可用、高可靠:Milvus 支持在云上扩展,其容灾能力能够保证服务高可用。
  • 混合查询:Milvus 支持在向量相似度检索过程中进行标量字段过滤,实现混合查询。
  • 开发者友好:支持多语言、多工具的 Milvus 生态系统。

# 3.1.3 Milvus系统架构

Milvus 2.0 是一款云原生向量数据库,采用存储与计算分离的架构设计,所有组件均为无状态组件,极大地增强了系统弹性和灵活性。

Milvus系统架构

整个系统分为四个层次,各个层次相互独立,独立扩展和容灾。

  • 接入层(Access Layer):系统的门面,由一组无状态 proxy 组成。对外提供用户连接的 endpoint,负责验证客户端请求并合并返回结果。
  • 协调服务(Coordinator Service):系统的大脑,负责分配任务给执行节点。协调服务共有四种角色,分别为 root coord、data coord、query coord 和 index coord。
  • 执行节点(Worker Node):系统的四肢,负责完成协调服务下发的指令和 proxy 发起的数据操作语言(DML)命令。执行节点分为三种角色,分别为 data node、query node 和 index node。
  • 存储服务 (Storage): 系统的骨骼,负责 Milvus 数据的持久化,分为元数据存储(meta store)、消息存储(log broker)和对象存储(object storage)三个部分。

# 3.1.4 Milvus应用场景

可以使用 Milvus 搭建符合自己场景需求的向量相似度检索系统,Milvus 的使用场景如下所示:

  • 图片检索系统:以图搜图,从海量数据库中即时返回与上传图片最相似的图片。
  • 视频检索系统:将视频关键帧转化为向量并插入 Milvus,便可检索相似视频,或进行实时视频推荐。
  • 音频检索系统:快速检索海量演讲、音乐、音效等音频数据,并返回相似音频。
  • 分子式检索系统:超高速检索相似化学分子结构、超结构、子结构。
  • 推荐系统:根据用户行为及需求推荐相关信息或商品。
  • 智能问答机器人:交互式智能问答机器人可自动为用户答疑解惑。
  • DNA 序列分类系统:通过对比相似 DNA 序列,仅需几毫秒便可精确对基因进行分类。
  • 文本搜索引擎:帮助用户从文本数据库中通过关键词搜索所需信息。

# 3.2 基于Milvus搭建以图搜图系统

# 3.2.1 搭建以图搜图系统

基于Milvus搭建以图搜图系统的步骤较为繁琐,这里就不自己搭建了,直接使用 苏洋 (opens new window) 大佬搭建好的镜像。

简化后的系统架构

使用以下命令即可一键搭建一个本地的图片搜索引擎,实现快速的以图搜图:

$ docker run -itd --name=milvus -p 3000:3000 -v `pwd`/images:/images soulteary/image-search-app:2.1.0
1

Chrome访问http://ip:3000地址即可访问。

基于Milvus的以图搜图系统

# 3.2.2 使用以图搜图系统

在服务器的 images 目录上传图片底库。之后点击页面上的“+”,页面会自动变灰,提示我们正在使用模型对图片进行编码,以及将抽取的特征向量存入Milvus里。这里变灰时间和我们本地机器的性能、刚刚在文件夹内放置图片数有关,设备性能越强,图片数据相对少,可以减少等待的时间。如果要清空底库,点击“CLEAR ALL”即可。

底库图片预处理之后 xxx images in this set处会显示出具体的底库数量,show top xxx results处我们拖动选择保留几张最相似的图片,之后点击 click to upload / drag a image here上传图片即可,右侧便会根据相似度显示出相似图片。

使用Milvus以图搜图系统

# 4. 参考资料

[1] 基于VGG-16的海量图像检索系统 from 博客园 (opens new window)

[2] 《深度学习》之 VGG卷积神经网络 原理 详解 from CSDN (opens new window)

[3] 深度学习之VGG19模型简介 from 台部落 (opens new window)

[4] VGG网络原理介绍 from CSDN (opens new window)

[5] VGG基本介绍 from PP飞桨 (opens new window)

[6] Pytorch基于VGG cosine similarity实现简单的以图搜图(图像检索) from CSDN (opens new window)

[7] ES 向量检索 dense_vector 类型 from CSDN (opens new window)

[8] Milvus实战 | 轻松搭建以图搜图系统 from 知乎 (opens new window)

[9] 向量数据库入坑:使用 Docker 和 Milvus 快速构建本地轻量图片搜索引擎 from 苏洋博客 (opens new window)

[10] 云原生向量数据库Milvus(一)-简述、系统架构及应用场景 from 稀土掘金 (opens new window)

Last Updated: 2/16/2024, 9:20:25 PM