乐趣区

关于人工智能:AI-Milvus将时尚应用搭建进行到底

在上一篇文章中,咱们学习了如何利用人工智能技术(例如开源 AI 向量数据库 Milvus 和 Hugging Face 模型)寻找与本人穿搭格调类似的明星。在这篇文章中,咱们将进一步介绍如何通过对上篇文章中的我的项目代码稍作批改,取得更具体和精确的后果,文末附赠彩蛋。

注:试用此我的项目利用,须要点击下载并应用 notebook

01. 回顾前文

在深入探讨前,先简要回顾一下前一篇教程文章。

导入所需的图像处理库和工具

首先导入所有必要的图像处理库,包含用于特征提取的 torchtransformers 中的 segformer 对象、matplotlibtorchvision 中的 Resizemasks_to_boxescrop 等。

import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop

预处理明星照片

在导入所有必要的图像处理库和工具后,就能够开始解决图像。以下三个函数 get_segmentationget_maskscrop_images 用于宰割并裁剪图片中的时尚单品,以供后续应用。

import torch
def get_segmentation(extractor, model, image):
    inputs = extractor(images=image, return_tensors="pt")

    outputs = model(**inputs)
    logits = outputs.logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return pred_seg

# 返回两个 masks(tensor)列表和 obj_ids(int)# 来自 Hugging Face 的 mattmdjaga/segformer_b2_clothes 模型
def get_masks(segmentation):
    obj_ids = torch.unique(segmentation)
    obj_ids = obj_ids[1:]
    masks = segmentation == obj_ids[:, None, None]
    return masks, obj_ids

def crop_images(masks, obj_ids, img):
    boxes = masks_to_boxes(masks)
    crop_boxes = []
    for box in boxes:
        crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])
        crop_boxes.append(crop_box)
    preprocess = T.Compose([T.Resize(size=(256, 256)),
        T.ToTensor()])
    cropped_images = {}
    for i in range(len(crop_boxes)):
        crop_box = crop_boxes[i]
        cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())
        cropped_images[obj_ids[i].item()] = preprocess(cropped)
    return cropped_images

将图像数据存储到向量数据库中

抉择开源向量数据库 Milvus 来存储图像数据。开始前,须要先解压蕴含照片的 zip 文件,并在 notebook 雷同的根目录中创立照片文件夹。实现后,能够运行以下代码来将图像数据存储在 Milvus 中。

import os
image_paths = []
for celeb in os.listdir("./photos"):
    for image in os.listdir(f"./photos/{celeb}/"):
        image_paths.append(f"./photos/{celeb}/{image}")

from milvus import default_server
from pymilvus import utility, connections
default_server.start()
connections.connect(host="127.0.0.1", port=default_server.listen_port)
DIMENSION = 2048
BATCH_SIZE = 128
COLLECTION_NAME = "fashion"
TOP_K = 3
from pymilvus import FieldSchema, CollectionSchema, Collection, DataType

fields = [FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="name", dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="seg_id", dtype=DataType.INT64),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]

schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)
index_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()

接着,运行以下代码,应用来自 Hugging Face 的 Nvidia ResNet 50 模型生成 embedding 向量。

# 如遇 SSL 证书 URL 谬误,请在导入 resnet50 模型前运行此步骤
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 并删除最初一层模型输入
embeddings_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
embeddings_model = torch.nn.Sequential(*(list(embeddings_model.children())[:-1]))
embeddings_model.eval()

以下函数定义了如何将图像转换为向量并插入到 Milvus 向量数据库中。代码会循环遍历所有图像。(留神:如果须要开启 Milvus 全新个性动静 Schema,须要批改代码。)

def embed_insert(data, collection, model):
    with torch.no_grad():
        output = model(torch.stack(data[0])).squeeze()
        collection.insert([data[1], data[2], data[3], output.tolist()])
from PIL import Image
data_batch = [[], [], [], []]

for path in image_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = "".join(path_split[2].split("_"))
    segmentation = get_segmentation(extractor, model, image)
    masks, ids = get_masks(segmentation)
    cropped_images = crop_images(masks, ids, image)for key, image in cropped_images.items():
        data_batch[0].append(image)
        data_batch[1].append(path)
        data_batch[2].append(name)
        data_batch[3].append(key)
        if len(data_batch[0]) % BATCH_SIZE == 0:
            embed_insert(data_batch, collection, embeddings_model)
            data_batch = [[], [], [], []]

if len(data_batch[0]) != 0:
    embed_insert(data_batch, collection, embeddings_model)

collection.flush()

查问向量数据库

以下代码演示了如何应用输出图像查问 Milvus 向量数据库,以检索和上传衣服图像最类似的的前三个后果。

def embed_search_images(data, model):
    with torch.no_grad():
    output = model(torch.stack(data))
    if len(output) > 1:
        return output.squeeze().tolist()
    else:
        return torch.flatten(output, start_dim=1).tolist()
# data_batch[0] 是 tensor 列表
# data_batch[1] 是图像文件的文件门路(字符串)# data_batch[2] 是图像中人物的名称列表(字符串)# data_batch[3] 是宰割键值列表(int)data_batch = [[], [], [], []]

search_paths = ["./photos/Taylor_Swift/Taylor_Swift_3.jpg", "./photos/Taylor_Swift/Taylor_Swift_8.jpg"]

for path in search_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = "".join(path_split[2].split("_"))
    segmentation = get_segmentation(extractor, model, image)
    masks, ids = get_masks(segmentation)
    cropped_images = crop_images(masks, ids, image)
    for key, image in cropped_images.items():
        data_batch[0].append(image)
        data_batch[1].append(path)
        data_batch[2].append(name)
        data_batch[3].append(key)

embeds = embed_search_images(data_batch[0], embeddings_model)
import time
start = time.time()
res = collection.search(embeds,
    anns_field='embedding',
    param={"metric_type": "L2",
        "params": {"nprobe": 10}},
    limit=TOP_K,
    output_fields=['filepath'])
finish = time.time()
print(finish - start)
for index, result in enumerate(res):
    print(index)
    print(result)

02. 匹配更多格调:标示每张图像中的时尚单品

除了间接应用上述代码,查找与你着装格调最类似的 3 位明星以外,咱们还能够略微批改一下代码,拓展我的项目的利用场景。能够批改代码获取如下所示,不蕴含边界框的图像。

接下来,将为大家介绍如何批改上述代码寻找更多匹配的穿衣格调。

导入所需的图像处理库和工具

同样,须要先导入所有必要的图像处理库。如果曾经实现导入,请跳过此步骤。

import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop

预处理图像

这个步骤波及三个函数:get_segmentationget_maskscrop_images

无需批改 get_segmentation 函数局部的代码。

对于 get_masks 函数,只须要获取与 wanted 列表中的宰割 ID 绝对应的宰割图像即可。

crop_image 函数做出更改。在前一篇文的教程中,此函数返回裁剪图像的列表。这里,咱们进行一些调整,使函返回三个对象:裁剪图像对应的 embedding 向量、边界框在原始图像上的坐标列表,以及宰割 ID 列表。这一改变将转化 embedding 向量的步骤提前了。

wanted = [1, 3, 4, 5, 6, 7, 8, 9, 10, 16, 17]
def get_segmentation(image):
    inputs = extractor(images=image, return_tensors="pt")

    outputs = segmentation_model(**inputs)
    logits = outputs.logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return pred_seg

# returns two lists masks (tensor) and obj_ids (int)
# "mattmdjaga/segformer_b2_clothes" from hugging face
def get_masks(segmentation):
    obj_ids = torch.unique(segmentation)
    obj_ids = obj_ids[1:]
    wanted_ids = [x.item() for x in obj_ids if x in wanted]
    wanted_ids = torch.Tensor(wanted_ids)
    masks = segmentation == wanted_ids[:, None, None]
    return masks, obj_ids

def crop_images(masks, obj_ids, img):
    boxes = masks_to_boxes(masks)
    crop_boxes = []
    for box in boxes:
        crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])
        crop_boxes.append(crop_box)
    preprocess = T.Compose([T.Resize(size=(256, 256)),
        T.ToTensor()])
    cropped_images = []
    seg_ids = []
    for i in range(len(crop_boxes)):
        crop_box = crop_boxes[i]
        cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())
        cropped_images.append(preprocess(cropped))
        seg_ids.append(obj_ids[i].item())
    with torch.no_grad():
        embeddings = embeddings_model(torch.stack(cropped_images)).squeeze().tolist()
    return embeddings, boxes.tolist(), seg_ids

有了图像数据之后,就能够加载数据了。这一步骤须要应用到批量插入性能,上篇文章的教程中也有波及,但不同点在于,本文的教程中将数据作为 dictionary 列表一次性插入。这种插入方式更简洁,同时还容许咱们在插入数据时动静新增 Schema 字段。

for path in image_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = "".join(path_split[2].split("_"))
    segmentation = get_segmentation(image)
    masks, ids = get_masks(segmentation)
    embeddings, crop_corners, seg_ids = crop_images(masks, ids, image)
    inserts = [{"embedding": embeddings[x], "seg_id": seg_ids[x], "name": name, "filepath": path, "crop_corner": crop_corners[x]} for x in range(len(embeddings))]
    collection.insert(inserts)
    collection.flush()

查问向量数据库

当初能够开始在向量数据库 Milvus 中查问数据了。本文与上篇文章的教程有以下几点区别:

  • 将一张图像中匹配的时尚单品数量限度到 5 件。
  • 指定查问返回最类似的 3 张图像。
  • 增加函数获取图片的色调图。

随后,在 matplotlib 中设置 figures 和 axes,代码会循环遍历所有图像,将上文的 3 个函数利用到所有图像上,以获取宰割后果和边界框。

查问数据时,能够依据每张图像中匹配的时尚单品数量来取得最类似的 3 张图像。

最终返回的后果图像中会带有标示出匹配单品的边界框。

from pprint import pprint
from PIL import ImageDraw
from collections import Counter
import matplotlib.patches as patches

LIMIT = 5 # 每张图像中匹配的时尚单品件数
CLOSEST = 3 # 返回的最类似图像数量。CLOSEST <= Limit

search_paths = ["./photos/Taylor_Swift/Taylor_Swift_2.jpg", "./photos/Jenna_Ortega/Jenna_Ortega_6.jpg"] # Images to search fordef get_cmap(n, name='hsv'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
    RGB color; the keyword argument name must be a standard mpl colormap name.
    Sourced from <https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib>'''return plt.cm.get_cmap(name, n)

# 创立后果 subplot
f, axarr = plt.subplots(max(len(search_paths), 2), CLOSEST + 1)

for search_i, path in enumerate(search_paths):
    # Generate crops and embeddings for all items found
    image = Image.open(path)
    segmentation = get_segmentation(image)
    masks, ids = get_masks(segmentation)
    embeddings, crop_corners, _ = crop_images(masks, ids, image)

# 生成色调图
    cmap = get_cmap(len(crop_corners))

    # Display the first box with image being searched for
    axarr[search_i][0].imshow(image)
    axarr[search_i][0].set_title('Search Image')
    axarr[search_i][0].axis('off')
    for i, (x0, y0, x1, y1) in enumerate(crop_corners):
        rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(i), facecolor='none')
        axarr[search_i][0].add_patch(rect)

    # 查问向量数据库
    start = time.time()
    res = collection.search(embeddings,
        anns_field='embedding',
        param={"metric_type": "L2",
        "params": {"nprobe": 10}, "offset": 0},
        limit=LIMIT,
        output_fields=['filepath', 'crop_corner'])
    finish = time.time()

    print("Total Search Time:", finish - start)

    # 依据地位给查问后果减少不同的权重
    filepaths = []
    for hits in res:
        seen = set()
        for i, hit in enumerate(hits):
            if hit.entity.get("filepath") not in seen:
                seen.add(hit.entity.get("filepath"))
                filepaths.extend([hit.entity.get("filepath") for _ in range(len(hits) - i)])
    # 查找排名最高的图像
    counts = Counter(filepaths)
    most_common = [path for path, _ in counts.most_common(CLOSEST)]
    # 提取每张图像中与查问图像相干的时尚单品
    matches = {}
    for i, hits in enumerate(res):
        matches[i] = {}
        tracker = set(most_common)
        for hit in hits:
            if hit.entity.get("filepath") in tracker:
                matches[i][hit.entity.get("filepath")] = hit.entity.get("crop_corner")
                tracker.remove(hit.entity.get("filepath"))
        # 返回最类似图像:# 返回与查问图像邻近的图像
        image = Image.open(res_path)
        axarr[search_i][res_i+1].imshow(image)
        axarr[search_i][res_i+1].set_title("".join(res_path.split("/")[2].split("_")))
        axarr[search_i][res_i+1].axis('off')
# 为匹配单品增加边界框
        if res_path in value:
            x0, y0, x1, y1 = value[res_path]
            rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(key), facecolor='none')
            axarr[search_i][res_i+1].add_patch(rect)

运行上述步骤后,后果如下所示:

03. 我的项目后续:摸索更多利用场景

欢送大家基于本我的项目拓展更多、更丰盛的利用场景,例如:

  • 进一步延长比照性能,例如将不同的单品归类到一起。同样,也能够上传更多图像到数据库中,丰盛查问后果。
  • 将本我的项目转变为时尚探测仪或者时尚举荐零碎。例如,将明星图像替换成可购买的衣服图像。这样一来,用户上传照片后,能够查问与他的衣服格调类似的其余衣服。
  • 还能够基于本我的项目搭建一个穿搭生成零碎,很多办法都能够实现这个利用,但这个利用的搭建相对而言更有难度!本文提供了一种思路,零碎能够依据用户上传的多张照片相应举荐穿搭。这里须要用到生成式图像模型,从而提供穿搭倡议。

总之,不要限度你的想象力,搭建更丰盛的利用。Milvus 之类的向量数据库为相似性搜寻利用提供了有限可能。

04. 总结

本文教程中,咱们进一步拓展了时尚 AI 我的项目的利用场景。

本次教程应用了 Milvus 全新的 动静 Schema 性能,筛选了宰割 ID,在返回图像中保留了边界框。同时,咱们在查问中指定 Milvus 依据每张图像中匹配的时尚单品件数返回最类似的 3 张图像。Milvus 全新的动静 Schema 性能反对在上传数据时增加新的字段,扭转了咱们批量上传数据的形式。应用这个性能后,在上传数据时,无需改变 Schema 即可增加裁剪。在图像预处理步骤中,剔除了一些辨认到的非着装类元素。同时,本教程保留了边界框,将转化向量的步骤提前至了裁剪图片的步骤。

当然,通过进一步调整代码,咱们还能够搭建更多相干利用,例如:时尚举荐零碎、帮忙用户搭配着装的零碎,甚至是生成式的时尚 AI 利用!

🌟「寻找 AIGC 时代的 CVP 实际之星」专题流动行将启动!

Zilliz 将联合国内头部大模型厂商一起甄选利用场景,由单方提供向量数据库与大模型顶级技术专家为用户赋能,一起打磨利用,晋升落地成果,赋能业务自身。

如果你的利用也适宜 CVP 框架,且正为利用落地和实际效果发愁,可间接申请参加流动,取得最业余的帮忙和领导!分割邮箱为 business@zilliz.com。

本文由 mdnice 多平台公布

退出移动版