关于python:MMSegmentation自定义数据集

49次阅读

共计 15886 个字符,预计需要花费 40 分钟才能阅读完成。

📕前言

该文章次要是简述一下本人为了实现极市平台赛事过程中,应用 MMSegmentation 语义宰割开源库的心得。

在学习一个新的工具之前,肯定须要明确本人是用工具实现什么指标,而不是为了学工具而学,一旦有了目标会给你所作的事件带来意义,然而也要防止急于求成(人总是喜爱简略间接的事件,然而只有真正拉扯过肌肉才会成长),所以保持不上来的时候,只有明确这是你的大脑退缩了,但你依然想学。💪

\(\quad \)

🌳文章构造

本文章将从一下几个方面介绍如何上手 MMsegmentation,并用 MMDeploy 实现简略的部署:

  • 装置 MMSegmentation
  • MMSegmentation 的文件构造
  • MMSegmentation 的配置文件(外围)
  • 如何在 MMSegmentation 中自定义数据集
  • 训练和测试

我强烈建议配合官网文档一起学习:https://mmsegmentation.readthedocs.io/zh_CN/latest/index.html
PS:如此良心的开源库还带中文文档!😭

\(\quad \)

📝注释

装置 MMSegmentation

环境筹备(可选,但举荐)

个别咱们为了环境隔离用 Miniconda(Anaconda)创立一个新的 python 环境,但在某些状况下也能够不必,取决于你的习惯。

从官方网站下载并装置 Miniconda & 创立一个 conda 环境,并激活:

conda create --name openmmlab python=3.8 -y
conda activate openmmlab

\(\quad \)

装置库

  1. 依据官网装置 pytorch,当初更新到 2.0 了,然而举荐装置之前的版本(能够点击页面中上面红框的链接,授之以渔),也能够间接点击 install previous versions of PyTorch(授之以鱼)

    gpu 版本(要对应本人的 cuda 版本,pip 和 conda 二选一)

    # pip 装置
    # CUDA 11.1 
    pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
    
    # 或者
    # conda 装置
    # CUDA 11.3
    conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge 
    

    cpu 版本(看 MMSegmentation 的官网文档吧)

\(\quad \)

  1. 装置 MMCV(OpenMMLab 其余许多库都有这个依赖)
    举荐装置形式 mim,更多形式看 MMCV

    pip install -U openmim
    mim install mmengine
    mim install "mmcv>=2.0.0"

\(\quad \)

  1. 装置 MMsegmentation
    a. 形式一:源码装置,这个比拟容易前期开发,因为可能间接批改并应用源码(本教程装置形式)

    git clone -b main https://github.com/open-mmlab/mmsegmentation.git
    cd mmsegmentation
    pip install -v -e .
    # '-v' 示意具体模式,更多的输入
    # '-e' 示意以可编辑模式装置工程,# 因而对代码所做的任何批改都失效,无需重新安装

    b. 形式二:作为依赖库装置

    pip install "mmsegmentation>=1.0.0"

\(\quad \)

验证装置是否胜利

源码装置测验形式

cd mmsegmentation
python demo/image_demo.py demo/demo.png \\
configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \\
pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \\
--device cuda:0 --out-file result.jpg

您将在以后文件夹中看到一个新图像 result.jpg,其中所有指标都笼罩了宰割 mask

其余更多装置形式见官网文档:https://mmsegmentation.readthedocs.io/zh_CN/latest/get_starte…

\(\quad \)

MMSegmentation 的文件构造

接下来咱们略微看一下 MMsegmentation 的文件构造目录

mmsegmentation
- configs # ** 配置文件,是该库的外围 **
    - _base_ # 根底模块文件,** 但实质上还是配置文件 **,包含数据集,模型,训练配置
        - datasets
        - models
        - schedules    
    - else model config # 除了 _base_ 之外,其余都是通过利用 _base_ 中定义好的模块进行组合的模型文件


- mmseg # ** 这是库外围的实现,下面配置文件的模块都在这里定义 **
    - datasets
    - models

- tools # 这里包含训练、测试、转 onnx 等写好了的工具,间接调用即可
    - train.py
    - test.py

- data # 搁置数据集

- demo # 提供了几个小 demo(可不论)- docker # 容器配置(可不论)- docs # 各种阐明文档(可不论)- projects #(可不论)- requirements #(可不论)- tests #(可不论)

从下面能够看出,其实 MMSegmentation 做了很好的封装,如果只是应用,那是非常容易上手的。

config/_base_ 和 mmseg 中的 datasets、models 等文件有什么区别呢?
上面用 ade 数据集举一个例子(大抵看一下差别,不须要弄懂):

  • config/_base_/datasets/ade20k.py

    # dataset settings
    dataset_type = 'ADE20KDataset'
    data_root = 'data/ade/ADEChallengeData2016'
    crop_size = (512, 512)
    train_pipeline = [dict(type='LoadImageFromFile'),
      dict(type='LoadAnnotations', reduce_zero_label=True),
      dict(
          type='RandomResize',
          scale=(2048, 512),
          ratio_range=(0.5, 2.0),
          keep_ratio=True),
      dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
      dict(type='RandomFlip', prob=0.5),
      dict(type='PhotoMetricDistortion'),
      dict(type='PackSegInputs')
    ]
    test_pipeline = [dict(type='LoadImageFromFile'),
      dict(type='Resize', scale=(2048, 512), keep_ratio=True),
      # add loading annotation after ``Resize`` because ground truth
      # does not need to do resize data transform
      dict(type='LoadAnnotations', reduce_zero_label=True),
      dict(type='PackSegInputs')
    ]
    img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    tta_pipeline = [dict(type='LoadImageFromFile', backend_args=None),
      dict(
          type='TestTimeAug',
          transforms=[
              [dict(type='Resize', scale_factor=r, keep_ratio=True)
                  for r in img_ratios
              ],
              [dict(type='RandomFlip', prob=0., direction='horizontal'),
                  dict(type='RandomFlip', prob=1., direction='horizontal')
              ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
          ])
    ]
    train_dataloader = dict(
      batch_size=4,
      num_workers=4,
      persistent_workers=True,
      sampler=dict(type='InfiniteSampler', shuffle=True),
      dataset=dict(
          type=dataset_type,
          data_root=data_root,
          data_prefix=dict(img_path='images/training', seg_map_path='annotations/training'),
          pipeline=train_pipeline))
    val_dataloader = dict(
      batch_size=1,
      num_workers=4,
      persistent_workers=True,
      sampler=dict(type='DefaultSampler', shuffle=False),
      dataset=dict(
          type=dataset_type,
          data_root=data_root,
          data_prefix=dict(
              img_path='images/validation',
              seg_map_path='annotations/validation'),
          pipeline=test_pipeline))
    test_dataloader = val_dataloader
    
    val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
    test_evaluator = val_evaluator
    
  • mmseg/datasets/ade.py

    # Copyright (c) OpenMMLab. All rights reserved.
    from mmseg.registry import DATASETS
    from .basesegdataset import BaseSegDataset
    
    
    @DATASETS.register_module()
    class ADE20KDataset(BaseSegDataset):
      """ADE20K dataset.
    
      In segmentation map annotation for ADE20K, 0 stands for background, which
      is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
      The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
      '.png'.
      """
      METAINFO = dict(
          classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
                   'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk',
                   'person', 'earth', 'door', 'table', 'mountain', 'plant',
                   'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
                   'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
                   'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
                   'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
                   'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
                   'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
                   'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
                   'screen door', 'stairway', 'river', 'bridge', 'bookcase',
                   'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
                   'bench', 'countertop', 'stove', 'palm', 'kitchen island',
                   'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
                   'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
                   'chandelier', 'awning', 'streetlight', 'booth',
                   'television receiver', 'airplane', 'dirt track', 'apparel',
                   'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
                   'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
                   'conveyer belt', 'canopy', 'washer', 'plaything',
                   'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
                   'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
                   'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
                   'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
                   'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
                   'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
                   'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
                   'clock', 'flag'),
          palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
                   [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
                   [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
                   [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
                   [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
                   [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
                   [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
                   [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
                   [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
                   [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
                   [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
                   [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
                   [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
                   [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
                   [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
                   [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
                   [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
                   [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
                   [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
                   [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
                   [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
                   [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
                   [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
                   [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
                   [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
                   [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
                   [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
                   [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
                   [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
                   [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
                   [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
                   [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
                   [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
                   [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
                   [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
                   [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
                   [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
                   [102, 255, 0], [92, 0, 255]])
    
      def __init__(self,
                   img_suffix='.jpg',
                   seg_map_suffix='.png',
                   reduce_zero_label=True,
                   **kwargs) -> None:
          super().__init__(
              img_suffix=img_suffix,
              seg_map_suffix=seg_map_suffix,
              reduce_zero_label=reduce_zero_label,
              **kwargs)

\(\quad \)

MMSegmentation 的 config 配置文件(外围)

在应用 MMSegmentation 中的模型进行训练和测试的时候就可能看出 config 配置文件的重要性

在单 GPU 上训练和测试

在单 GPU 上训练

tools/train.py 文件提供了在单 GPU 上部署训练任务的办法。

根底用法如下:

python tools/train.py  ${配置文件} [可选参数]
# 要害参数:#    config.py # 必须提供撇脂文件
#     --work-dir ${工作门路} # 从新指定工作门路

更多其余参数详情

举例 pspnet

python tools/train.py \\
configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \\
--work-dir logs/pspnet

configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
该配置文件调用了_base_中定义的 models、dataset、schedules 等配置文件,这种模块化形式就很容易通过重新组合来调整整体模型。

_base_ = [
    '../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(data_preprocessor=data_preprocessor)

其中每个模块的配置文件细节见:https://mmsegmentation.readthedocs.io/zh_CN/latest/user_guide…

\(\quad \)

如何在 MMSegmentation 中自定义数据集

这应该是大家比较关心的局部,重点是。咱们首先看看官网对于一些罕用的数据集的文件目录是怎么样的(拿 CHASE_DB1 数据集(二类别语义宰割)举个例子):

mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── CHASE_DB1
│   │   ├── images
│   │   │   ├── training
│   │   │   ├── validation
│   │   ├── annotations
│   │   │   ├── training
│   │   │   ├── validation

可见其中蕴含:

  • annotations:语义宰割的实在 mark label
  • images:待宰割的 RGB 图像

自定义数据集

依据以上构造咱们能够构建本人的数据集,这里我次要是利用极市平台 写字楼消防门梗塞辨认 二类别语义宰割工作的数据集,其中 门的 label 是 1,背景 label 是 0

并且将其划分为训练集和验证集,在 mmsegmentation/data 中增加以下文件:

mmsegmentation
|   data
|   | xiaofang
│   │   ├── images
│   │   │   ├── training
│   │   │   ├── validation
│   │   ├── annotations
│   │   │   ├── training
│   │   │   ├── validation

增加数据集模块

  1. mmsegmentation/mmseg/datasets 中增加一个 xiaofang.py 定义本人的数据类 XiaoFangDataset
    xiaofang.py

    # Copyright (c) OpenMMLab. All rights reserved.
    
    from .builder import DATASETS
    from .custom import CustomDataset
    
    
    @DATASETS.register_module()
    class XiaoFangDataset(CustomDataset):
        CLASSES = ('background', 'door')
    
        PALETTE = [[120, 120, 120], [6, 230, 230]]
    
        def __init__(self, **kwargs):
            super(XiaoFangDataset, self).__init__(
                img_suffix='.jpg', # 留神门路
                seg_map_suffix='.png',
                reduce_zero_label=False,
                **kwargs)
            assert self.file_client.exists(self.img_dir)
    
  2. mmsegmentation/mmseg/datasets/__init__.py 中申明本人定义的数据类XiaoFangDataset

    # Copyright (c) OpenMMLab. All rights reserved.
    from .ade import ADE20KDataset
    from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
    from .chase_db1 import ChaseDB1Dataset
    from .cityscapes import CityscapesDataset
    from .coco_stuff import COCOStuffDataset
    from .custom import CustomDataset
    from .dark_zurich import DarkZurichDataset
    from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
                                   RepeatDataset)
    from .drive import DRIVEDataset
    from .face import FaceOccludedDataset
    from .hrf import HRFDataset
    from .isaid import iSAIDDataset
    from .isprs import ISPRSDataset
    from .loveda import LoveDADataset
    from .night_driving import NightDrivingDataset
    from .pascal_context import PascalContextDataset, PascalContextDataset59
    from .potsdam import PotsdamDataset
    from .stare import STAREDataset
    from .voc import PascalVOCDataset
    from .xiaofang import XiaoFangDataset
    
    __all__ = [
        'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
        'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
        'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
        'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
        'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
        'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
        'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'FaceOccludedDataset',
        'XiaoFangDataset'
    ]
    
  3. mmsegmentation/mmseg/core/evaluation/class_names.py 中申明本人的标签类别名称

    def xiaofang_classes():
        return ['background','door']
  4. mmsegmentation/configs/_base_/datasets 中增加本人数据集的配置文件 xiaofang.py

    # dataset settings
    dataset_type = 'XiaoFangDataset' # 数据类名称
    data_root = 'data/xiaofang' # 数据寄存地位
    img_norm_cfg = dict(mean=[120.4652, 123.1624, 124.3220], std=[63.5322, 60.6218, 59.2707], to_rgb=True)
    crop_size = (512, 512)
    train_pipeline = [dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(1920, 1080), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
    test_pipeline = [dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            # img_scale=(2048, 512),
            img_scale=(1920, 1080),
            # img_scale=(960, 540),
            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
            flip=False,
            transforms=[dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]
    data = dict(
        samples_per_gpu=4,
        workers_per_gpu=4,
        train=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline),
        val=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/validation',
            ann_dir='annotations/validation',
            pipeline=test_pipeline),
        test=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/validation',
            ann_dir='annotations/validation',
            pipeline=test_pipeline))
    

其中配置文件参数的细节含意仍见:https://mmsegmentation.readthedocs.io/zh_CN/latest/user_guide…

\(\quad \)

训练和测试

在实现了数据集配置后,就须要搭建整体模型的配置文件即可,MMSegmentation 提供了许多开源模型(上面是一部分,更多详情):

个别须要依据本人的 GPU 显存大小抉择模型,点击下面的 config 可能看到对应模型所须要的显存大小,如这里咱们举例抉择一个 STDC 模型:

  1. 批改残缺配置文件:在 mmsegmentation/configs/stdc 中增加上本人的模型 stdc2_512x1024_10k_xiaofang.py

    _base_ = ['../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py', '../_base_/datasets/xiaofang.py']
    
    # checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc1_20220308-5368626c.pth'  # noqa
    
    
    norm_cfg = dict(type='BN', requires_grad=True)
    model = dict(
        type='EncoderDecoder',
        pretrained=None,
        backbone=dict(
            type='STDCContextPathNet',
            backbone_cfg=dict(# init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
                type='STDCNet',
                stdc_type='STDCNet2',
                in_channels=3,
                channels=(32, 64, 256, 512, 1024),
                bottleneck_type='cat',
                num_convs=4,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'),
                with_final_conv=False),
            last_in_channels=(1024, 512),
            out_channels=128,
            ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)),
        decode_head=dict(
            type='FCNHead',
            in_channels=256,
            channels=256,
            num_convs=1,
            num_classes=2,
            in_index=3,
            concat_input=False,
            dropout_ratio=0.1,
            norm_cfg=norm_cfg,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        auxiliary_head=[
            dict(
                type='FCNHead',
                in_channels=128,
                channels=64,
                num_convs=1,
                num_classes=2,
                in_index=2,
                norm_cfg=norm_cfg,
                concat_input=False,
                align_corners=False,
                sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
            dict(
                type='FCNHead',
                in_channels=128,
                channels=64,
                num_convs=1,
                num_classes=2,
                in_index=1,
                norm_cfg=norm_cfg,
                concat_input=False,
                align_corners=False,
                sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
            dict(
                type='STDCHead',
                in_channels=256,
                channels=64,
                num_convs=1,
                num_classes=2,
                boundary_threshold=0.1,
                in_index=0,
                norm_cfg=norm_cfg,
                concat_input=False,
                align_corners=False,
                loss_decode=[
                    dict(
                        type='CrossEntropyLoss',
                        loss_name='loss_ce',
                        use_sigmoid=True,
                        loss_weight=1.0),
                    dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
                ]),
        ],
        # model training and testing settings
        train_cfg=dict(),
        test_cfg=dict(mode='whole'))
    
    
    checkpoint_config = dict(# 设置检查点钩子 (checkpoint hook) 的配置文件。执行时请参考 https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py。by_epoch=False,
        save_last=False,  # 是否依照每个 epoch 去算 runner。interval=2000)  # 保留的距离
    
    evaluation = dict(interval=1000, metric='mIoU', pre_eval=True)
    runner = dict(type='IterBasedRunner', max_iters=10000)
    log_config = dict(
        interval=10,
        hooks=[dict(type='TextLoggerHook', by_epoch=False),
            # dict(type='TensorboardLoggerHook')
            # dict(type='PaviLoggerHook') # for internal services
        ])
    lr_config = dict(warmup='linear', warmup_iters=1000)
  2. 训练

    python tools/train.py \\
    configs/stdc/stdc2_512x1024_10k_xiaofang.py \\
    --work-dir logs/stdc2
  3. 测试后果:MIoU=0.9225,上面别离是 RGB 图像、实在 Label、STDC 模型输入

👏本文参加了 SegmentFault 思否写作挑战赛,欢送正在浏览的你也退出。

正文完
 0