共计 2393 个字符,预计需要花费 6 分钟才能阅读完成。
有一个需要:计算图片的类似度
须要解决两个问题:
- 生成 ahash
- 存储和计算 ahash 之间的间隔
生成 ahash
『生成 ahash』选用 python 上面的一个 imagehash 库。(github:https://github.com/JohannesBu…)
from io import BytesIO | |
import numpy | |
import imagehash | |
from PIL import Image | |
def create_vector(file: BytesIO) -> bytes: | |
image = Image.open(file) | |
hash = imagehash.average_hash(image) | |
_vector = [] | |
for h in hash.hash: | |
_vector.extend(h) | |
vector = bytes( | |
numpy.packbits( | |
[int(v) | |
for v in _vector | |
], | |
axis=-1 | |
).tolist()) | |
return vector |
create_vector 函数输入的类型是 bytes,就是二进制序列
imagehash.average_hash(image) 输入的 hash 对象,hash 对象有一个 hash 属性,这个属性的类型是
list[list[bool]]
打印进去就是长上面这样子,其实就是一个 8×8=64 bit 的序列
[[False False False False False False False False] [True False False False True False False False] [False False True True True True False False] [False False False True True False True True] [False False True True True False False False] [False True True True True False False False] [False True True True True False True True] [False False False True True False True True]]
向量数据库
『存储和计算 ahash 之间的间隔』选用 milvus
创立汇合
定义汇合:
import settings | |
from pymilvus import ( | |
connections, | |
Collection, | |
FieldSchema, | |
CollectionSchema, | |
DataType, | |
) | |
from loggers import logger | |
connections.connect( | |
host=settings.MILVUS_CONFIG.host, | |
port=settings.MILVUS_CONFIG.port, | |
) | |
schema = CollectionSchema([FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True), | |
FieldSchema("meta_id", DataType.INT64), | |
FieldSchema("company_id", DataType.INT64), | |
FieldSchema("image_vector", dtype=DataType.BINARY_VECTOR, dim=64) | |
]) | |
# 汇合不存在,则会主动创立汇合;已存在,不会反复创立 | |
collection = Collection(settings.MILVUS_CONFIG.collection.name, schema) |
应用的向量类型是 dtype=DataType.BINARY_VECTOR
,
为什么不选 float 是因为我不晓得怎么把 ahash 转成 float
插入 ahash 到 milvus
class TestVector(unittest.TestCase): | |
def test_insert_vector(self): | |
""" | |
插入 ahash 到 milvus | |
python -m unittest testing.test_milvus.TestVector.test_insert_vector | |
"""oss_file_path ='image_hash/testing/WechatIMG193.jpeg' | |
file = BytesIO(bucket.get_object(oss_file_path).read()) | |
vector = create_vector(file) | |
m_pk = insert_vector(vector, meta_id=2, company_id=1) | |
logger.debug(f'milvus pk: {m_pk}') |
查问 ahash from milvus
def test_search(self): | |
""" | |
批量调用后端接口入库 | |
python -m unittest testing.test_milvus.TestVector.test_search | |
"""oss_file_path ='image_hash/testing/WechatIMG193.jpeg'file = BytesIO(open(BASE_DIR/'testing'/'resource'/'WechatIMG193.jpeg','rb').read()) | |
vector = create_vector(file) | |
logger.debug(vector) | |
rows: list[dict[str, Any]] = collection.search(data=[vector], | |
param={"metric_type": 'L2', "params": {"nprobe": 32}}, | |
anns_field='image_vector', | |
output_fields=['id', 'meta_id', 'company_id'], | |
limit=10, | |
) | |
logger.debug(rows) | |
logger.debug(type(rows)) |
正文完