详解 PyG 中的 ToSLIC 变换
PyG 是一个基于 PyTorch 的图神经网络库,提供了丰盛的数据处理、图转换和图模型的性能。本文将介绍 PyG 中的一个图转换函数 ToSLIC,它能够将一张图片转换为一个超像素图,并生成相应的数据对象。
前言
PyG 是一个开源的 Python 库,用于深度学习工作中的图神经网络(GNN)建模和训练。该库包含多个 GNN 模型和与图相干的数据结构和算法。在本篇文章中,我将介绍 PyG 中的 ToSLIC 模块,它是一个用于图像宰割的超像素宰割算法。
什么是超像素图?
图像宰割是计算机视觉畛域的一个重要问题,它的指标是将图像分成若干个类似的区域,每个区域具备肯定的语义信息。图像宰割在许多畛域都有利用,如医学影像、主动驾驶、图像检索等。
超像素宰割是一种常见的图像宰割办法,它将图像中的像素划分为若干个类似的区域,这些区域被称为超像素。与像素相比,超像素更具备代表性和可解释性,并且能够升高图像宰割的复杂度。
超像素图是一种对图片进行宰割的办法,将图片中类似的像素聚合成一个个小区域,称为超像素。每个超像素能够看作是图片中的一个节点,它具备肯定的特色(如色彩、地位等),并与其余超像素有肯定的关系(如邻接、间隔等)。这样,咱们就能够把一张图片看作是一个图构造,从而利用图神经网络来进行剖析和解决。
如何应用 PyG 中的 ToSLIC 函数?
PyG 中的 ToSLIC 模块实现了一种基于 SLIC(Simple Linear Iterative Clustering)算法的超像素宰割办法。该算法将图像划分为若干个相邻的块,每个块具备雷同的色彩或者灰度级别。ToSLIC 模块应用 PyTorch 实现,并能够间接集成到 PyG 的 GNN 模型中,用于图像宰割工作。
ToSLIC 模块次要由以下几个步骤组成:
- 图像预处理:将原始图像转换为 LAB 色调空间,并对其进行归一化解决。
- 超像素初始化:在图像上随机选取若干个像素作为超像素核心,依据这些核心像素计算每个像素与哪个超像素最近,将其纳入该超像素。
-
超像素迭代:反复进行以下两个步骤,直到收敛:
a. 计算每个超像素的中心点,并更新其地位;
b. 对每个像素,计算其与每个超像素中心点之间的间隔,并将其纳入最近的超像素中。
- 超像素合并:依据超像素之间的间隔,将相邻的超像素合并成一个更大的超像素。
进一步说,ToSLIC 函数是 PyG 中提供的一个图转换函数,它应用了 skimage 库中的 slic 算法来实现图片到超像素图的转换。ToSLIC 函数承受一个 torch.Tensor 类型的图片作为输出,并返回一个 torch_geometric.data.Data 类型的数据对象作为输入。输入对象蕴含以下属性:
- x: 一个二维张量,示意每个超像素节点的特征向量。默认状况下,特征向量是每个超像素节点在 RGB 空间下的均匀色彩值。
- pos: 一个二维张量,示意每个超像素节点在原始图片上的地位坐标。
- seg: 一个二维张量(可选),示意原始图片上每个像素所属于哪个超像素节点。
- img: 一个四维张量(可选),示意原始图片。
ToSLIC 函数还能够承受一些额定参数来调整 slic 算法和输入对象:
- add_seg: 一个布尔值(默认为 False),示意是否在输入对象中增加 seg 属性。
- add_img: 一个布尔值(默认为 False),示意是否在输入对象中增加 img 属性。
- **kwargs: 其余参数,用于调整 slic 算法。具体参见 skimage.segmentation.slic 文档。
上面给出一个简略的例子:
from torchvision.datasets import MNIST
import torchvision.transforms as T
from torch_geometric.transforms import ToSLIC
transform = T.Compose([T.ToTensor(),
ToSLIC(n_segments=75, add_seg=True)
])
dataset = MNIST('/tmp/MNIST', download=True, transform=transform)
data = dataset[0] # data is a Data object with x, pos and seg attributes
这段代码首先从 torchvision 库中加载了 MNIST 数据集,并定义了一个组合变换 transform。transform 蕴含两个步骤:第一步是将 PIL.Image 类型的图片转换为 torch.Tensor 类型;第二步是将 torch.Tensor 类型的图片转换为 Data 类型,并指定要生成 75 个超像素节点,并在输入对象中增加 seg 属性。而后咱们从数据集中取出第一张图片,并利用 transform 变换失去 data 对象。
利用
在训练 GNN 模型时,能够应用这个数据集作为输出数据,并将其转换为图形数据格式。例如,能够将图像中的超像素视为节点,超像素之间的邻接关系视为边。在 PyG 中,能够应用 Data
类来示意图形数据,其中包含节点特色、边索引和边特色等信息。
上面是一个应用 ToSLIC 模块进行图像宰割的例子:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.data import DataLoader
class GCN(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCN, self).__init__(aggr='add')
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCN(3, 16)
self.conv2 = GCN(16, 32)
self.conv3 = GCN(32, 64)
self.lin = torch.nn.Linear(64, 10)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = self.conv3(x, edge_index)
x = F.relu(x)
x = global_mean_pool(x, batch)
x = self.lin(x)
return F.log_softmax(x, dim=1)
dataset = MNISTSuperpixels(root='~/datasets/MNIST', train=True, transform=ToSLIC())
loader = DataLoader(dataset, batch_size=64, shuffle=True)
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
model.train()
for data in loader:
optimizer.zero_grad()
out = model(data.x.float(), data.edge_index)
loss = F.nll_loss(out, data.y)
loss.backward()
optimizer.step()
model.eval()
correct = 0
for data in loader:
out = model(data.x.float(), data.edge_index)
pred = out.argmax(dim=1)
correct += pred.eq(data.y).sum().item()
acc = correct / len(dataset)
print(f'Epoch: {epoch}, Accuracy: {acc:.4f}')
在这个例子中,咱们应用了一个蕴含三个 GCN 层的模型来对超像素图进行分类,其中输出节点特色的维度为 3,代表超像素的 LAB 色彩空间的三个通道。咱们应用全局池化层对每个超像素的特色进行聚合,最终输入一个 10 维向量,代表对应 MNIST 数字图像的类别概率分布。训练过程中,咱们应用 ToSLIC
将原始图像转换为超像素图,并应用 DataLoader
将其加载到模型中进行训练。
在 GCN
类中,咱们重写了 message()
和forward()
办法。在 message()
中,咱们将每个节点的特色依照邻接矩阵中的边权重进行加权均匀,以获取该节点的街坊节点特色的信息。在 forward()
中,咱们首先对邻接矩阵增加自环,而后对每个节点的特色进行线性变换。接下来,咱们计算每个边的权重,以便在 message()
中进行加权均匀。最初,咱们调用 propagate()
办法来执行信息传递操作。
在 Net
类中,咱们定义了一个蕴含三个 GCN
层和一个全连贯层的模型,并在每个 GCN
层后利用了一个 ReLU 激活函数。在模型的最初一层,咱们应用 global_mean_pool()
对所有超像素的特色进行全局均匀池化,以获取整个图像的特色示意。
训练过程中,咱们应用 Adam 优化器来最小化穿插熵损失,并在每个 epoch 完结时计算模型的准确率。因为咱们应用了 ToSLIC 模块对原始图像进行了超像素宰割,因而咱们能够将每个超像素视为一个节点,并应用图卷积神经网络来对其进行分类。
在理论利用中,ToSLIC 模块能够与其余 PyG 中的模块联合应用,例如 SAGEConv、GATConv 等,以实现更简单的图卷积神经网络。此外,咱们还能够应用 ToSLIC 模块将图像宰割利用于其余工作,例如指标检测、图像生成等。
本文参加了 SegmentFault 思否写作挑战赛,欢送正在浏览的你也退出。