MMSegmentation V0.27.0: Training and Inference on Your Own Dataset (Part 2)


1. Converting Official Pretrained Models to MMSegmentation Style

If you want to convert the keys of a pretrained model from the official repository yourself, MMSegmentation also provides a script, swin2mmseg.py, in the tools directory, which converts model keys from the official repo to MMSegmentation style.

python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
python tools/model_converters/swin2mmseg.py https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth pretrain/swin_base_patch4_window7_224.pth

This script reads the model at PRETRAIN_PATH and stores the converted model at STORE_PATH.
In the default setup, each MMSegmentation pretrained model is paired with its corresponding original model from the official repository.
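For intuition, a converter of this kind is essentially a key-renaming pass over the checkpoint's state_dict. Below is a minimal, hypothetical sketch; the real rules live in tools/model_converters/swin2mmseg.py, and the renames shown here are illustrative rather than the script's exact mapping:

import torch

def convert_keys(src_path, dst_path):
    # Hypothetical sketch of a checkpoint key converter; see
    # tools/model_converters/swin2mmseg.py for the actual rules.
    ckpt = torch.load(src_path, map_location='cpu')
    # Official checkpoints usually nest weights under 'model' or 'state_dict'.
    state_dict = ckpt.get('model', ckpt.get('state_dict', ckpt))
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        # Illustrative renames: MMSegmentation's SwinTransformer uses
        # different attribute names than the official repo.
        new_key = new_key.replace('patch_embed.proj', 'patch_embed.projection')
        new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
        new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
        new_state_dict[new_key] = value
    torch.save(dict(state_dict=new_state_dict, meta=dict()), dst_path)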

2. Downloading the ADE20K Models

https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015-ee2fff1c.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340-593b0e13.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K_20210526_211650-762e2178.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k_20220318_015320-48d180dd.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k_20220318_091743-9ba68901.pth
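If you prefer scripting the download, a small sketch using torch.hub (the URL is the tiny-model checkpoint from the list above; the cache directory is a choice, not a requirement):

import torch

url = ('https://download.openmmlab.com/mmsegmentation/v0.5/swin/'
       'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/'
       'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth')
# Downloads once into ./checkpoints and caches; returns the loaded state dict.
state_dict = torch.hub.load_state_dict_from_url(
    url, model_dir='checkpoints', map_location='cpu')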

3. Downloading the Swin Transformer Pretrained Models

# tiny
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth
# small
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth
# base
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_20220317-e9b98025.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth
# large
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220412-aeecf2aa.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth

4. Building a Data Directory with the ADE20K Structure

ADE20K has more than 25,000 images (20k train, 2k val, 3k test), densely annotated with an open-dictionary label set. For the 2017 Places Challenge 2, 100 thing and 50 stuff categories covering 89% of all pixels were selected, for a total of 150 categories, listed below.

Idx	Pixel Ratio	Train Imgs	Val Imgs	Name
1	0.1576	11664	1172	wall
2	0.1072	6046	612	building, edifice
3	0.0878	8265	796	sky
4	0.0621	9336	917	floor, flooring
5	0.0480	6678	641	tree
6	0.0450	6604	643	ceiling
7	0.0398	4023	408	road, route
8	0.0231	1906	199	bed 
9	0.0198	4688	460	windowpane, window 
10	0.0183	2423	225	grass
11	0.0181	2874	294	cabinet
12	0.0166	3068	310	sidewalk, pavement
13	0.0160	5075	526	person, individual, someone, somebody, mortal, soul
14	0.0151	1804	190	earth, ground
15	0.0118	6666	796	door, double door
16	0.0110	4269	411	table
17	0.0109	1691	160	mountain, mount
18	0.0104	3999	441	plant, flora, plant life
19	0.0104	2149	217	curtain, drape, drapery, mantle, pall
20	0.0103	3261	318	chair
21	0.0098	3164	306	car, auto, automobile, machine, motorcar
22 	0.0074	709	75	water
23	0.0067	3296	315	painting, picture
24 	0.0065	1191	106	sofa, couch, lounge
25 	0.0061	1516	162	shelf
26 	0.0060	667	69	house
27 	0.0053	651	57	sea
28	0.0052	1847	224	mirror
29	0.0046	1158	128	rug, carpet, carpeting
30	0.0044	480	44	field
31	0.0044	1172	98	armchair
32	0.0044	1292	184	seat
33	0.0033	1386	138	fence, fencing
34	0.0031	698	61	desk
35	0.0030	781	73	rock, stone
36	0.0027	380	43	wardrobe, closet, press
37	0.0026	3089	302	lamp
38	0.0024	404	37	bathtub, bathing tub, bath, tub
39	0.0024	804	99	railing, rail
40	0.0023	1453	153	cushion
41	0.0023	411	37	base, pedestal, stand
42	0.0022	1440	162	box
43	0.0022	800	77	column, pillar
44	0.0020	2650	298	signboard, sign
45	0.0019	549	46	chest of drawers, chest, bureau, dresser
46	0.0019	367	36	counter
47	0.0018	311	30	sand
48	0.0018	1181	122	sink
49	0.0018	287	23	skyscraper
50	0.0018	468	38	fireplace, hearth, open fireplace
51	0.0018	402	43	refrigerator, icebox
52	0.0018	130	12	grandstand, covered stand
53	0.0018	561	64	path
54	0.0017	880	102	stairs, steps
55	0.0017	86	12	runway
56	0.0017	172	11	case, display case, showcase, vitrine
57	0.0017	198	18	pool table, billiard table, snooker table
58	0.0017	930	109	pillow
59	0.0015	139	18	screen door, screen
60	0.0015	564	52	stairway, staircase
61	0.0015	320	26	river
62	0.0015	261	29	bridge, span
63	0.0014	275	22	bookcase
64	0.0014	335	60	blind, screen
65	0.0014	792	75	coffee table, cocktail table
66	0.0014	395	49	toilet, can, commode, crapper, pot, potty, stool, throne
67	0.0014	1309	138	flower
68	0.0013	1112	113	book
69	0.0013	266	27	hill
70	0.0013	659	66	bench
71	0.0012	331	31	countertop
72	0.0012	531	56	stove, kitchen stove, range, kitchen range, cooking stove
73	0.0012	369	36	palm, palm tree
74	0.0012	144	9	kitchen island
75	0.0011	265	29	computer, computing machine, computing device, data processor, electronic computer, information processing system
76	0.0010	324	33	swivel chair
77	0.0009	304	27	boat
78	0.0009	170	20	bar
79	0.0009	68	6	arcade machine
80	0.0009	65	8	hovel, hut, hutch, shack, shanty
81	0.0009	248	25	bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle
82	0.0008	492	49	towel
83	0.0008	2510	269	light, light source
84	0.0008	440	39	truck, motortruck
85	0.0008	147	18	tower
86	0.0008	583	56	chandelier, pendant, pendent
87	0.0007	533	61	awning, sunshade, sunblind
88	0.0007	1989	239	streetlight, street lamp
89	0.0007	71	5	booth, cubicle, stall, kiosk
90	0.0007	618	53	television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box
91	0.0007	135	12	airplane, aeroplane, plane
92	0.0007	83	5	dirt track
93	0.0007	178	17	apparel, wearing apparel, dress, clothes
94	0.0006	1003	104	pole
95	0.0006	182	12	land, ground, soil
96	0.0006	452	50	bannister, banister, balustrade, balusters, handrail
97	0.0006	42	6	escalator, moving staircase, moving stairway
98	0.0006	307	31	ottoman, pouf, pouffe, puff, hassock
99	0.0006	965	114	bottle
100	0.0006	117	13	buffet, counter, sideboard
101	0.0006	354	35	poster, posting, placard, notice, bill, card
102	0.0006	108	9	stage
103	0.0006	557	55	van
104	0.0006	52	4	ship
105	0.0005	99	5	fountain
106	0.0005	57	4	conveyer belt, conveyor belt, conveyer, conveyor, transporter
107	0.0005	292	31	canopy
108	0.0005	77	9	washer, automatic washer, washing machine
109	0.0005	340	38	plaything, toy
110	0.0005	66	3	swimming pool, swimming bath, natatorium
111	0.0005	465	49	stool
112	0.0005	50	4	barrel, cask
113	0.0005	622	75	basket, handbasket
114	0.0005	80	9	waterfall, falls
115	0.0005	59	3	tent, collapsible shelter
116	0.0005	531	72	bag
117	0.0005	282	30	minibike, motorbike
118	0.0005	73	7	cradle
119	0.0005	435	44	oven
120	0.0005	136	25	ball
121	0.0005	116	24	food, solid food
122	0.0004	266	31	step, stair
123	0.0004	58	12	tank, storage tank
124	0.0004	418	83	trade name, brand name, brand, marque
125	0.0004	319	43	microwave, microwave oven
126	0.0004	1193	139	pot, flowerpot
127	0.0004	97	23	animal, animate being, beast, brute, creature, fauna
128	0.0004	347	36	bicycle, bike, wheel, cycle 
129	0.0004	52	5	lake
130	0.0004	246	22	dishwasher, dish washer, dishwashing machine
131	0.0004	108	13	screen, silver screen, projection screen
132	0.0004	201	30	blanket, cover
133	0.0004	285	21	sculpture
134	0.0004	268	27	hood, exhaust hood
135	0.0003	1020	108	sconce
136	0.0003	1282	122	vase
137	0.0003	528	65	traffic light, traffic signal, stoplight
138	0.0003	453	57	tray
139	0.0003	671	100	ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
140	0.0003	397	44	fan
141	0.0003	92	8	pier, wharf, wharfage, dock
142	0.0003	228	18	crt screen
143	0.0003	570	59	plate
144	0.0003	217	22	monitor, monitoring device
145	0.0003	206	19	bulletin board, notice board
146	0.0003	130	14	shower
147	0.0003	178	28	radiator
148	0.0002	504	57	glass, drinking glass
149	0.0002	775	96	clock
150	0.0002	421	56	flag

mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── cityscapes
│ │ ├── leftImg8bit
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── gtFine
│ │ │ ├── train
│ │ │ ├── val
│ ├── VOCdevkit
│ │ ├── VOC2012
│ │ │ ├── JPEGImages
│ │ │ ├── SegmentationClass
│ │ │ ├── ImageSets
│ │ │ │ ├── Segmentation
│ │ ├── VOC2010
│ │ │ ├── JPEGImages
│ │ │ ├── SegmentationClassContext
│ │ │ ├── ImageSets
│ │ │ │ ├── SegmentationContext
│ │ │ │ │ ├── train.txt
│ │ │ │ │ ├── val.txt
│ │ │ ├── trainval_merged.json
│ │ ├── VOCaug
│ │ │ ├── dataset
│ │ │ │ ├── cls
│ ├── ade
│ │ ├── ADEChallengeData2016
│ │ │ ├── annotations
│ │ │ │ ├── training
│ │ │ │ ├── validation
│ │ │ ├── images
│ │ │ │ ├── training
│ │ │ │ ├── validation
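To reuse the ADE20K config machinery for your own data, one option is to copy your images and masks into the same layout. A minimal sketch, assuming your images and masks sit in flat folders with matching file stems (the folder names my_images/ and my_masks/ are placeholders):

import os
import shutil

# Hypothetical flat source folders; adjust to your own data.
SRC_IMG, SRC_ANN = 'my_images', 'my_masks'
DST = 'data/ade/ADEChallengeData2016'

stems = sorted(os.path.splitext(f)[0] for f in os.listdir(SRC_IMG))
n_val = max(1, len(stems) // 10)  # simple 90/10 train/val split
splits = {'training': stems[n_val:], 'validation': stems[:n_val]}

for split, names in splits.items():
    img_dst = os.path.join(DST, 'images', split)
    ann_dst = os.path.join(DST, 'annotations', split)
    os.makedirs(img_dst, exist_ok=True)
    os.makedirs(ann_dst, exist_ok=True)
    for stem in names:
        shutil.copy(os.path.join(SRC_IMG, stem + '.jpg'), img_dst)
        shutil.copy(os.path.join(SRC_ANN, stem + '.png'), ann_dst)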

5. Modifying the Base Config Files

This time we train the upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K model; the corresponding config file is shown below.
The full configuration is as follows:

_base_ = [
    '../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth'  # noqa
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        use_abs_pos_embed=False,
        drop_path_rate=0.3,
        patch_norm=True),
    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
    auxiliary_head=dict(in_channels=384, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

(1) Set the number of classes and load the pretrained model (model config file upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py)

_base_ = [
    '../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
# You can also download the checkpoint and point this at the local path.
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth'  # noqa
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        use_abs_pos_embed=False,
        drop_path_rate=0.3,
        patch_norm=True),
    # Set num_classes to the number of classes in your own dataset, excluding
    # the background; the background is automatically 0.
    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
    auxiliary_head=dict(in_channels=384, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
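After editing, it is worth a quick sanity check that the merged config resolves as expected. In MMSeg 0.x, configs are plain mmcv Config objects:

from mmcv import Config

cfg = Config.fromfile(
    'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py')
print(cfg.model.decode_head.num_classes)       # should print your class count
print(cfg.model.backbone.init_cfg.checkpoint)  # pretrained checkpoint path/URL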

(2) Modify the dataset settings (dataset type, data root, batch size, etc.) in ../_base_/datasets/ade20k.py

# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'  # 1. change to your own data path
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)  # 2. change to your own image size
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    # adjust img_scale to match your crop_size
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
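A quick way to confirm the paths and pipeline are wired correctly is to build the training dataset and inspect one sample; a sketch using the MMSeg 0.x API:

from mmcv import Config
from mmseg.datasets import build_dataset

cfg = Config.fromfile(
    'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py')
dataset = build_dataset(cfg.data.train)
print(len(dataset))  # number of training images found
sample = dataset[0]
print(sample['img'].data.shape)              # e.g. torch.Size([3, 512, 512])
print(sample['gt_semantic_seg'].data.shape)  # e.g. torch.Size([1, 512, 512])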

(3) Modify the class names (CLASSES) and the file suffixes (mmseg/datasets/ade.py, mmseg/datasets/custom.py)

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
from PIL import Image

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class ADE20KDataset(CustomDataset):
    """ADE20K dataset.

    In segmentation map annotation for ADE20K, 0 stands for background, which
    is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
    '.png'.
    """
    # change CLASSES to the class names of your own dataset
    CLASSES = (
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
        'clock', 'flag')

    # the colors can be changed in the same way
    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
               [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
               [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
               [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
               [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
               [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
               [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
               [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
               [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
               [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
               [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
               [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
               [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
               [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
               [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
               [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
               [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
               [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
               [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
               [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
               [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
               [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
               [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
               [102, 255, 0], [92, 0, 255]]

    def __init__(self, **kwargs):
        super(ADE20KDataset, self).__init__(
            img_suffix='.jpg',       # the image suffix can be changed here
            seg_map_suffix='.png',   # the label suffix can be changed here
            reduce_zero_label=True,
            **kwargs)

    def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
        """Write the segmentation results to images.

        Args:
            results (list[ndarray]): Testing results of the dataset.
            imgfile_prefix (str): The filename prefix of the png files.
                If the prefix is "somepath/xxx", the png files will be
                named "somepath/xxx.png".
            to_label_id (bool): whether convert output to label_id for
                submission.
            indices (list[int], optional): Indices of input results, if not
                set, all the indices of the dataset will be used.
                Default: None.

        Returns:
            list[str: str]: result txt files which contains corresponding
                semantic segmentation images.
        """
        if indices is None:
            indices = list(range(len(self)))

        mmcv.mkdir_or_exist(imgfile_prefix)
        result_files = []
        for result, idx in zip(results, indices):
            filename = self.img_infos[idx]['filename']
            basename = osp.splitext(osp.basename(filename))[0]
            # the '.png' here can be changed
            png_filename = osp.join(imgfile_prefix, f'{basename}.png')
            # The index range of official requirement is from 0 to 150.
            # But the index range of output is from 0 to 149.
            # That is because we set reduce_zero_label=True.
            result = result + 1
            output = Image.fromarray(result.astype(np.uint8))
            output.save(png_filename)
            result_files.append(png_filename)

        return result_files

    def format_results(self,
                       results,
                       imgfile_prefix,
                       to_label_id=True,
                       indices=None):
        """Format the results into dir (standard format for ade20k
        evaluation).

        Args:
            results (list): Testing results of the dataset.
            imgfile_prefix (str | None): The prefix of images files. It
                includes the file path and the prefix of filename, e.g.,
                "a/b/prefix".
            to_label_id (bool): whether convert output to label_id for
                submission. Default: False
            indices (list[int], optional): Indices of input results, if not
                set, all the indices of the dataset will be used.
                Default: None.

        Returns:
            tuple: (result_files, tmp_dir), result_files is a list containing
                the image paths, tmp_dir is the temporal directory created
                for saving json/png files when img_prefix is not specified.
        """
        if indices is None:
            indices = list(range(len(self)))

        assert isinstance(results, list), 'results must be a list.'
        assert isinstance(indices, list), 'indices must be a list.'

        result_files = self.results2img(results, imgfile_prefix, to_label_id,
                                        indices)
        return result_files

One thing to note: if your images are .jpg and your masks are .png, this works as-is; if they are in other formats, you need to change the file suffixes in mmseg/datasets/custom.py.
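Alternatively, rather than editing ade.py or custom.py in place, you can register a small dataset class of your own and point dataset_type at it; a sketch (the class name, CLASSES, PALETTE and suffixes below are placeholders for your data):

# e.g. mmseg/datasets/my_dataset.py -- remember to import it in
# mmseg/datasets/__init__.py so the registration runs.
from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class MyDataset(CustomDataset):
    """Hypothetical two-class example; replace with your own names/colors."""
    CLASSES = ('background', 'foreground')
    PALETTE = [[0, 0, 0], [255, 0, 0]]

    def __init__(self, **kwargs):
        super(MyDataset, self).__init__(
            img_suffix='.jpg',
            seg_map_suffix='.png',
            reduce_zero_label=False,
            **kwargs)

For reference, the relevant file mmseg/datasets/custom.py is shown in full below.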


# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from prettytable import PrettyTable
from torch.utils.data import Dataset

from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose, LoadAnnotations


@DATASETS.register_module()
class CustomDataset(Dataset):
    """Custom dataset for semantic segmentation. An example of file structure
    is as followed.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of CustomDataset should be of the same
    except suffix. A valid img/gt_semantic_seg filename pair should be like
    ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
    in the suffix). If split is given, then ``xxx`` is specified in txt file.
    Otherwise, all files in ``img_dir/`` and ``ann_dir`` will be loaded.
    Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.

    Args:
        pipeline (list[dict]): Processing pipeline
        img_dir (str): Path to image directory
        img_suffix (str): Suffix of images. Default: '.jpg'
        ann_dir (str, optional): Path to annotation directory. Default: None
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        split (str, optional): Split txt file. If split is specified, only
            file with suffix in the splits will be loaded. Otherwise, all
            images in img_dir/ann_dir will be loaded. Default: None
        data_root (str, optional): Data root for img_dir/ann_dir. Default:
            None.
        test_mode (bool): If test_mode=True, gt wouldn't be loaded.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default: False
        classes (str | Sequence[str], optional): Specify classes to load.
            If is None, ``cls.CLASSES`` will be used. Default: None.
        palette (Sequence[Sequence[int]]] | np.ndarray | None):
            The palette of segmentation map. If None is given, and
            self.PALETTE is None, random palette will be generated.
            Default: None
        gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
            load gt for evaluation, load from disk by default. Default: None.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    CLASSES = None

    PALETTE = None

    def __init__(self,
                 pipeline,
                 img_dir,
                 img_suffix='.jpg',      # change the image suffix here
                 ann_dir=None,
                 seg_map_suffix='.png',  # change the label suffix here
                 split=None,
                 data_root=None,
                 test_mode=False,
                 ignore_index=255,
                 reduce_zero_label=False,
                 classes=None,
                 palette=None,
                 gt_seg_map_loader_cfg=None,
                 file_client_args=dict(backend='disk')):
        self.pipeline = Compose(pipeline)
        self.img_dir = img_dir
        self.img_suffix = img_suffix
        self.ann_dir = ann_dir
        self.seg_map_suffix = seg_map_suffix
        self.split = split
        self.data_root = data_root
        self.test_mode = test_mode
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.label_map = None
        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
            classes, palette)
        self.gt_seg_map_loader = LoadAnnotations(
        ) if gt_seg_map_loader_cfg is None else LoadAnnotations(
            **gt_seg_map_loader_cfg)

        self.file_client_args = file_client_args
        self.file_client = mmcv.FileClient.infer_client(self.file_client_args)

        if test_mode:
            assert self.CLASSES is not None, \
                '`cls.CLASSES` or `classes` should be specified when testing'

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.img_dir):
                self.img_dir = osp.join(self.data_root, self.img_dir)
            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
                self.ann_dir = osp.join(self.data_root, self.ann_dir)
            if not (self.split is None or osp.isabs(self.split)):
                self.split = osp.join(self.data_root, self.split)

        # load annotations
        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
                                               self.ann_dir,
                                               self.seg_map_suffix, self.split)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.img_infos)

    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
                         split):
        """Load annotation from directory.

        Args:
            img_dir (str): Path to image directory
            img_suffix (str): Suffix of images.
            ann_dir (str|None): Path to annotation directory.
            seg_map_suffix (str|None): Suffix of segmentation maps.
            split (str|None): Split txt file. If split is specified, only file
                with suffix in the splits will be loaded. Otherwise, all images
                in img_dir/ann_dir will be loaded. Default: None

        Returns:
            list[dict]: All image info of dataset.
        """
        img_infos = []
        if split is not None:
            lines = mmcv.list_from_file(
                split, file_client_args=self.file_client_args)
            for line in lines:
                img_name = line.strip()
                img_info = dict(filename=img_name + img_suffix)
                if ann_dir is not None:
                    seg_map = img_name + seg_map_suffix
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)
        else:
            for img in self.file_client.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=img_suffix,
                    recursive=True):
                img_info = dict(filename=img)
                if ann_dir is not None:
                    seg_map = img.replace(img_suffix, seg_map_suffix)
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)
            img_infos = sorted(img_infos, key=lambda x: x['filename'])

        print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
        return img_infos

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        return self.img_infos[idx]['ann']

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['seg_fields'] = []
        results['img_prefix'] = self.img_dir
        results['seg_prefix'] = self.ann_dir
        if self.custom_classes:
            results['label_map'] = self.label_map

    def __getitem__(self, idx):
        """Get training/test data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training/test data (with annotation if `test_mode` is set
                False).
        """
        if self.test_mode:
            return self.prepare_test_img(idx)
        else:
            return self.prepare_train_img(idx)

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_info = self.img_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Get testing data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
                pipeline.
        """
        img_info = self.img_infos[idx]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
        """Place holder to format result to dataset specific output."""
        raise NotImplementedError

    def get_gt_seg_map_by_idx(self, index):
        """Get one ground truth segmentation map for evaluation."""
        ann_info = self.get_ann_info(index)
        results = dict(ann_info=ann_info)
        self.pre_pipeline(results)
        self.gt_seg_map_loader(results)
        return results['gt_semantic_seg']

    def get_gt_seg_maps(self, efficient_test=None):
        """Get ground truth segmentation maps for evaluation."""
        if efficient_test is not None:
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` has been deprecated '
                'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
                'friendly by default. ')

        for idx in range(len(self)):
            ann_info = self.get_ann_info(idx)
            results = dict(ann_info=ann_info)
            self.pre_pipeline(results)
            self.gt_seg_map_loader(results)
            yield results['gt_semantic_seg']

    def pre_eval(self, preds, indices):
        """Collect eval result from each iteration.

        Args:
            preds (list[torch.Tensor] | torch.Tensor): the segmentation logit
                after argmax, shape (N, H, W).
            indices (list[int] | int): the prediction related ground truth
                indices.

        Returns:
            list[torch.Tensor]: (area_intersect, area_union, area_prediction,
                area_ground_truth).
        """
        # In order to compat with batch inference
        if not isinstance(indices, list):
            indices = [indices]
        if not isinstance(preds, list):
            preds = [preds]

        pre_eval_results = []

        for pred, index in zip(preds, indices):
            seg_map = self.get_gt_seg_map_by_idx(index)
            pre_eval_results.append(
                intersect_and_union(
                    pred,
                    seg_map,
                    len(self.CLASSES),
                    self.ignore_index,
                    # as the labels have been converted when the dataset was
                    # initialized in `get_palette_for_custom_classes`, this
                    # `label_map` should be `dict()`, see
                    # https://github.com/open-mmlab/mmsegmentation/issues/1415
                    # for more details
                    label_map=dict(),
                    reduce_zero_label=self.reduce_zero_label))

        return pre_eval_results

    def get_classes_and_palette(self, classes=None, palette=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.
            palette (Sequence[Sequence[int]]] | np.ndarray | None):
                The palette of segmentation map. If None is given, random
                palette will be generated. Default: None
        """
        if classes is None:
            self.custom_classes = False
            return self.CLASSES, self.PALETTE

        self.custom_classes = True
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if self.CLASSES:
            if not set(class_names).issubset(self.CLASSES):
                raise ValueError('classes is not a subset of CLASSES.')

            # dictionary, its keys are the old label ids and its values
            # are the new label ids.
            # used for changing pixel labels in load_annotations.
            self.label_map = {}
            for i, c in enumerate(self.CLASSES):
                if c not in class_names:
                    self.label_map[i] = -1
                else:
                    self.label_map[i] = class_names.index(c)

        palette = self.get_palette_for_custom_classes(class_names, palette)

        return class_names, palette

    def get_palette_for_custom_classes(self, class_names, palette=None):

        if self.label_map is not None:
            # return subset of palette
            palette = []
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != -1:
                    palette.append(self.PALETTE[old_id])
            palette = type(self.PALETTE)(palette)

        elif palette is None:
            if self.PALETTE is None:
                # Get random state before set seed, and restore
                # random state later.
                # It will prevent loss of randomness, as the palette
                # may be different in each iteration if not specified.
                # See: https://github.com/open-mmlab/mmdetection/issues/5844
                state = np.random.get_state()
                np.random.seed(42)
                # random palette
                palette = np.random.randint(0, 255, size=(len(class_names), 3))
                np.random.set_state(state)
            else:
                palette = self.PALETTE

        return palette

    def evaluate(self,
                 results,
                 metric='mIoU',
                 logger=None,
                 gt_seg_maps=None,
                 **kwargs):
        """Evaluate the dataset.

        Args:
            results (list[tuple[torch.Tensor]] | list[str]): per image
                pre_eval results or predict segmentation map for computing
                evaluation metric.
            metric (str | list[str]): Metrics to be evaluated. 'mIoU',
                'mDice' and 'mFscore' are supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
                used in ConcatDataset

        Returns:
            dict[str, float]: Default metrics.
        """
        if isinstance(metric, str):
            metric = [metric]
        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
        if not set(metric).issubset(set(allowed_metrics)):
            raise KeyError('metric {} is not supported'.format(metric))

        eval_results = {}
        # test a list of files
        if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
                results, str):
            if gt_seg_maps is None:
                gt_seg_maps = self.get_gt_seg_maps()
            num_classes = len(self.CLASSES)
            ret_metrics = eval_metrics(
                results,
                gt_seg_maps,
                num_classes,
                self.ignore_index,
                metric,
                label_map=dict(),
                reduce_zero_label=self.reduce_zero_label)
        # test a list of pre_eval_results
        else:
            ret_metrics = pre_eval_to_metrics(results, metric)

        # Because dataset.CLASSES is required for per-eval.
        if self.CLASSES is None:
            class_names = tuple(range(num_classes))
        else:
            class_names = self.CLASSES

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)

        # for logger
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        summary_table_data = PrettyTable()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                summary_table_data.add_column(key, [val])
            else:
                summary_table_data.add_column('m' + key, [val])

        print_log('per class results:', logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)
        print_log('Summary:', logger)
        print_log('\n' + summary_table_data.get_string(), logger=logger)

        # each metric dict
        for key, value in ret_metrics_summary.items():
            if key == 'aAcc':
                eval_results[key] = value / 100.0
            else:
                eval_results['m' + key] = value / 100.0

        ret_metrics_class.pop('Class', None)
        for key, value in ret_metrics_class.items():
            eval_results.update({
                key + '.' + str(name): value[idx] / 100.0
                for idx, name in enumerate(class_names)
            })

        return eval_results

(4) Modify the runtime config (loading pretrained weights, resuming from checkpoints) (configs/_base_/default_runtime.py)

# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook')  # enable the TensorboardLoggerHook
        # dict(type='PaviLoggerHook')  # for internal services
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None    # load the model at this path as a pretrained model; does not resume training
resume_from = None  # load the model at this path as a checkpoint and resume training from it
workflow = [('train', 1)]
cudnn_benchmark = True
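For fine-tuning on your own data, point load_from at the ADE20K checkpoint downloaded in step 2, and use resume_from only to continue an interrupted run (it also restores the iteration count and optimizer state); the paths below are examples:

# fine-tune from a downloaded checkpoint (example path)
load_from = 'checkpoints/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth'
resume_from = None  # or e.g. 'work_dirs/my_exp/latest.pth' to resume training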

(5) Modify the schedule config (maximum training iterations, checkpoint-saving interval, evaluation interval and metric, keeping the best model) (configs/_base_/schedules/schedule_40k.py, ../_base_/schedules/schedule_160k.py)

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)  # max_iters: maximum number of training iterations
checkpoint_config = dict(by_epoch=False, interval=16000)  # interval: save a checkpoint every 16000 iterations
evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)  # evaluate every 16000 iterations with mIoU; add save_best='auto' to keep the best checkpoint
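For a shorter fine-tuning run on a small dataset, the same keys can simply be scaled down; save_best is a standard EvalHook option in this version (the values below are illustrative, not recommendations):

runner = dict(type='IterBasedRunner', max_iters=20000)
checkpoint_config = dict(by_epoch=False, interval=2000)
evaluation = dict(interval=2000, metric='mIoU', pre_eval=True, save_best='mIoU')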

When changing the number of GPUs or the batch size, scale the learning rate linearly: lr = LR × (total_batch_size / 16), where LR is the learning rate in the default config (tuned for a total batch size of 16, e.g. 8 GPUs × 2 images per GPU). For example, a single GPU with 4 images per GPU gives lr = 0.00006 × 4/16 = 0.000015.

(6) Modify the model's inference mode and norm_cfg (../_base_/models/upernet_swin.py)

# model settings
# For multi-GPU training use 'SyncBN'; for single-GPU training change type to 'BN'.
norm_cfg = dict(type='SyncBN', requires_grad=True)
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=96,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,
        qk_scale=None,
        patch_norm=True,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=backbone_norm_cfg),
    decode_head=dict(
        type='UPerHead',
        in_channels=[96, 192, 384, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=384,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))  # 'whole' means whole-image inference
# For overlapping sliding-window inference, change it to:
# test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))
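Once training is done, inference can go through the high-level MMSeg 0.x Python API; a short sketch (the checkpoint path is a placeholder for your trained model, demo/demo.png ships with the repo):

from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot

config_file = 'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py'
checkpoint_file = 'work_dirs/my_exp/latest.pth'  # placeholder: your trained checkpoint

model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
result = inference_segmentor(model, 'demo/demo.png')  # list with one (H, W) array of class ids
show_result_pyplot(model, 'demo/demo.png', result, opacity=0.5)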

The sliding-window implementation lives in mmsegmentation/mmseg/models/segmentors/encoder_decoder.py:

    # TODO refactor
    def slide_inference(self, img, img_meta, rescale):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = img.size()
        num_classes = self.num_classes
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                crop_seg_logit = self.encode_decode(crop_img, img_meta)
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        if torch.onnx.is_in_onnx_export():
            # cast count_mat to constant while exporting to ONNX
            count_mat = torch.from_numpy(
                count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = preds / count_mat
        if rescale:
            # remove padding area
            resize_shape = img_meta[0]['img_shape'][:2]
            preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
            preds = resize(
                preds,
                size=img_meta[0]['ori_shape'][:2],
                mode='bilinear',
                align_corners=self.align_corners,
                warning=False)
        return preds
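As a concrete example: with crop_size=(512, 512) and stride=(341, 341) on a 2048×1024 input, h_grids = max(1024 - 512 + 340, 0) // 341 + 1 = 3 and w_grids = 6, so 18 overlapping crops are predicted; count_mat records how many crops cover each pixel, and dividing by it averages the overlapping logits.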

6. Model Optimization Tricks

(1) Different learning rates for backbone and heads

In semantic segmentation, some methods give the heads a larger LR than the backbone for better performance or faster convergence.
In MMSegmentation, you can add the following lines to the config to make the LR of the heads 10 times that of the backbone. With this modification, the LR of any parameter group whose name contains 'head' is multiplied by 10.

optimizer = dict(
    paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)}))

(2) Online Hard Example Mining (OHEM)

MMSegmentation implements a pixel sampler for training-time sampling. Below is an example config for training PSPNet with OHEM enabled.
With it, only pixels whose confidence score is under 0.7 are used for training, and at least 100,000 pixels are kept during training. If thresh is not specified, the pixels with the top min_kept losses are selected instead.

_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    decode_head=dict(
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)))

(3) Class-balanced loss

For datasets with an imbalanced class distribution, you can change the loss weight of each class. Here is an example for the Cityscapes dataset; class_weight is passed to CrossEntropyLoss as its weight argument.

_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    decode_head=dict(
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
            # DeepLab used this class weight for cityscapes
            class_weight=[
                0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
                1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
                1.0865, 1.0955, 1.0865, 1.1529, 1.0507
            ])))
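If you have no published weights for your own dataset, a common heuristic is inverse, log-damped frequency over the training masks (the ENet-style weighting 1/ln(1.02 + p)). A minimal sketch, assuming masks are single-channel PNGs holding class ids; the mask folder and NUM_CLASSES are placeholders:

import glob

import numpy as np
from PIL import Image

NUM_CLASSES = 19  # placeholder: set to your class count
counts = np.zeros(NUM_CLASSES, dtype=np.int64)
for path in glob.glob('data/my_dataset/annotations/training/*.png'):
    label = np.array(Image.open(path))
    valid = label[label < NUM_CLASSES]  # skip the ignore index (e.g. 255)
    counts += np.bincount(valid.ravel(), minlength=NUM_CLASSES)

freq = counts / counts.sum()
class_weight = (1.0 / np.log(1.02 + freq)).round(4).tolist()
print(class_weight)  # paste into loss_decode's class_weight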

(4) Multiple losses

For loss computation, training with multiple losses simultaneously is supported. Here is an example config that trains unet on the DRIVE dataset, where the loss function is a 1:3 weighted sum of CrossEntropyLoss and DiceLoss:

_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    decode_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
    auxiliary_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
)

This way, loss_weight and loss_name are, respectively, the weight and the name of the corresponding loss in the training log.
Note: for a loss item to be included in the backward graph, its loss_name must start with the prefix loss_.

(5) Ignoring specified label indices in loss computation

In the default setting, avg_non_ignore=False, which means every pixel counts in the loss computation, even pixels labeled with the ignore index.
For loss computation, ignoring certain label indices is supported via avg_non_ignore and ignore_index. This way, the loss is averaged only over non-ignored labels, which can give better performance (see the reference in the MMSegmentation docs). Here is an example config for training unet on the Cityscapes dataset: in loss computation it ignores label 0 (the background) and averages the loss only over non-ignored labels:

_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
    decode_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
            avg_non_ignore=True)),
    auxiliary_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
            avg_non_ignore=True)))

In short, add ignore_index to the decode head and/or auxiliary head and set avg_non_ignore=True:

# model settings
...
loss_decode=dict(
    type='CrossEntropyLoss',
    use_sigmoid=False,
    loss_weight=1.0,
    avg_non_ignore=True),
...
