关于challenge:详解PyG中的ToSLIC变换

17次阅读

共计 4475 个字符,预计需要花费 12 分钟才能阅读完成。

详解 PyG 中的 ToSLIC 变换

PyG 是一个基于 PyTorch 的图神经网络库,提供了丰盛的数据处理、图转换和图模型的性能。本文将介绍 PyG 中的一个图转换函数 ToSLIC,它能够将一张图片转换为一个超像素图,并生成相应的数据对象。

前言

PyG 是一个开源的 Python 库,用于深度学习工作中的图神经网络(GNN)建模和训练。该库包含多个 GNN 模型和与图相干的数据结构和算法。在本篇文章中,我将介绍 PyG 中的 ToSLIC 模块,它是一个用于图像宰割的超像素宰割算法。

什么是超像素图?

图像宰割是计算机视觉畛域的一个重要问题,它的指标是将图像分成若干个类似的区域,每个区域具备肯定的语义信息。图像宰割在许多畛域都有利用,如医学影像、主动驾驶、图像检索等。

超像素宰割是一种常见的图像宰割办法,它将图像中的像素划分为若干个类似的区域,这些区域被称为超像素。与像素相比,超像素更具备代表性和可解释性,并且能够升高图像宰割的复杂度。

超像素图是一种对图片进行宰割的办法,将图片中类似的像素聚合成一个个小区域,称为超像素。每个超像素能够看作是图片中的一个节点,它具备肯定的特色(如色彩、地位等),并与其余超像素有肯定的关系(如邻接、间隔等)。这样,咱们就能够把一张图片看作是一个图构造,从而利用图神经网络来进行剖析和解决。

如何应用 PyG 中的 ToSLIC 函数?

PyG 中的 ToSLIC 模块实现了一种基于 SLIC(Simple Linear Iterative Clustering)算法的超像素宰割办法。该算法将图像划分为若干个相邻的块,每个块具备雷同的色彩或者灰度级别。ToSLIC 模块应用 PyTorch 实现,并能够间接集成到 PyG 的 GNN 模型中,用于图像宰割工作。

ToSLIC 模块次要由以下几个步骤组成:

  1. 图像预处理:将原始图像转换为 LAB 色调空间,并对其进行归一化解决。
  2. 超像素初始化:在图像上随机选取若干个像素作为超像素核心,依据这些核心像素计算每个像素与哪个超像素最近,将其纳入该超像素。
  3. 超像素迭代:反复进行以下两个步骤,直到收敛:

    a. 计算每个超像素的中心点,并更新其地位;

    b. 对每个像素,计算其与每个超像素中心点之间的间隔,并将其纳入最近的超像素中。

  4. 超像素合并:依据超像素之间的间隔,将相邻的超像素合并成一个更大的超像素。

进一步说,ToSLIC 函数是 PyG 中提供的一个图转换函数,它应用了 skimage 库中的 slic 算法来实现图片到超像素图的转换。ToSLIC 函数承受一个 torch.Tensor 类型的图片作为输出,并返回一个 torch_geometric.data.Data 类型的数据对象作为输入。输入对象蕴含以下属性:

  • x: 一个二维张量,示意每个超像素节点的特征向量。默认状况下,特征向量是每个超像素节点在 RGB 空间下的均匀色彩值。
  • pos: 一个二维张量,示意每个超像素节点在原始图片上的地位坐标。
  • seg: 一个二维张量(可选),示意原始图片上每个像素所属于哪个超像素节点。
  • img: 一个四维张量(可选),示意原始图片。

ToSLIC 函数还能够承受一些额定参数来调整 slic 算法和输入对象:

  • add_seg: 一个布尔值(默认为 False),示意是否在输入对象中增加 seg 属性。
  • add_img: 一个布尔值(默认为 False),示意是否在输入对象中增加 img 属性。
  • **kwargs: 其余参数,用于调整 slic 算法。具体参见 skimage.segmentation.slic 文档。

上面给出一个简略的例子:

from torchvision.datasets import MNIST
import torchvision.transforms as T
from torch_geometric.transforms import ToSLIC

transform = T.Compose([T.ToTensor(),
ToSLIC(n_segments=75, add_seg=True)
])
dataset = MNIST('/tmp/MNIST', download=True, transform=transform)
data = dataset[0] # data is a Data object with x, pos and seg attributes

这段代码首先从 torchvision 库中加载了 MNIST 数据集,并定义了一个组合变换 transform。transform 蕴含两个步骤:第一步是将 PIL.Image 类型的图片转换为 torch.Tensor 类型;第二步是将 torch.Tensor 类型的图片转换为 Data 类型,并指定要生成 75 个超像素节点,并在输入对象中增加 seg 属性。而后咱们从数据集中取出第一张图片,并利用 transform 变换失去 data 对象。

利用

在训练 GNN 模型时,能够应用这个数据集作为输出数据,并将其转换为图形数据格式。例如,能够将图像中的超像素视为节点,超像素之间的邻接关系视为边。在 PyG 中,能够应用 Data 类来示意图形数据,其中包含节点特色、边索引和边特色等信息。

上面是一个应用 ToSLIC 模块进行图像宰割的例子:

import torch
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.data import DataLoader

class GCN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCN(3, 16)
        self.conv2 = GCN(16, 32)
        self.conv3 = GCN(32, 64)
        self.lin = torch.nn.Linear(64, 10)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

dataset = MNISTSuperpixels(root='~/datasets/MNIST', train=True, transform=ToSLIC())
loader = DataLoader(dataset, batch_size=64, shuffle=True)

model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    for data in loader:
        optimizer.zero_grad()
        out = model(data.x.float(), data.edge_index)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()

    model.eval()
    correct = 0
    for data in loader:
        out = model(data.x.float(), data.edge_index)
        pred = out.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
    acc = correct / len(dataset)
    print(f'Epoch: {epoch}, Accuracy: {acc:.4f}')

在这个例子中,咱们应用了一个蕴含三个 GCN 层的模型来对超像素图进行分类,其中输出节点特色的维度为 3,代表超像素的 LAB 色彩空间的三个通道。咱们应用全局池化层对每个超像素的特色进行聚合,最终输入一个 10 维向量,代表对应 MNIST 数字图像的类别概率分布。训练过程中,咱们应用 ToSLIC 将原始图像转换为超像素图,并应用 DataLoader 将其加载到模型中进行训练。

GCN 类中,咱们重写了 message()forward()办法。在 message() 中,咱们将每个节点的特色依照邻接矩阵中的边权重进行加权均匀,以获取该节点的街坊节点特色的信息。在 forward() 中,咱们首先对邻接矩阵增加自环,而后对每个节点的特色进行线性变换。接下来,咱们计算每个边的权重,以便在 message() 中进行加权均匀。最初,咱们调用 propagate() 办法来执行信息传递操作。

Net 类中,咱们定义了一个蕴含三个 GCN 层和一个全连贯层的模型,并在每个 GCN 层后利用了一个 ReLU 激活函数。在模型的最初一层,咱们应用 global_mean_pool() 对所有超像素的特色进行全局均匀池化,以获取整个图像的特色示意。

训练过程中,咱们应用 Adam 优化器来最小化穿插熵损失,并在每个 epoch 完结时计算模型的准确率。因为咱们应用了 ToSLIC 模块对原始图像进行了超像素宰割,因而咱们能够将每个超像素视为一个节点,并应用图卷积神经网络来对其进行分类。

在理论利用中,ToSLIC 模块能够与其余 PyG 中的模块联合应用,例如 SAGEConv、GATConv 等,以实现更简单的图卷积神经网络。此外,咱们还能够应用 ToSLIC 模块将图像宰割利用于其余工作,例如指标检测、图像生成等。

本文参加了 SegmentFault 思否写作挑战赛,欢送正在浏览的你也退出。

正文完
 0