segformer의 offical github를 가져와서 custom data로 학습시키는 것을 해보겠다
0. 우선 기본적인 세팅은 끝마쳤다고 생각하겠다
(apt update, upgrade, conda 또는 docker를 이용해 환경 설정 등)
https://daeun-computer-uneasy.tistory.com/126 이 링크 참조하시면 좋을 것 같습니다.
mmcv 같은 경우, cuda 및 torch 버전에 엄청난 영향을 받아서,, 꼬일 대로 꼬이면 답도 없습니다 ㅜㅜ
1. custom data set 폴더 설정
data/custom/images/train/*.jpg
data/custom/images/val/*.jpg
data/custom/annotations/train/*.png
data/custom/annotations/val/*.png
annotations 이미지들은 png로 확장자를 변경해주자. (확장자 변경하는 코드도 이 블로그의 파이썬 카테고리에 있다!)
2. local_configs/_base_/datasets/custom.py 생성(or 수정)
이미 있다면 수정해주고, 없다면 만들어주자.
# dataset settings
dataset_type = 'CustomDataset'
data_root = 'data/custom/'
img_norm_cfg = dict(
mean=[], std=[], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/train',
ann_dir='annotations/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/val',
ann_dir='annotations/val',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/val',
ann_dir='annotations/val',
pipeline=test_pipeline))
다른 파일들과 다른 점은 없지만, 어떤 데이터셋을 쓸건지, 경로는 어디인지, mean과 std는 무엇인지 등 만을 수정해주었다.
(mean이랑 std 구하는 코드도 이 블로그의 파이썬 카테고리에 있다!)
여기서 samples per gpu는 배치사이즈고, workers per gpu은 데이터셋을 gpu로 넘기는 subprocess를 말한다.
3. config 파일 수정
segformer/B0~5/segformer.512x512.~~~.py 등이 있는데 원하는 것 아무거나 골라서 수정해주면 된다.
_base_ = [
#'../../_base_/models/segformer.py',
'../../_base_/datasets/custom.py',
'../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_160k_adamw.py'
]
# model settings
norm_cfg = dict(type='BN', requires_grad=True)
find_unused_parameters = True
model = dict(
type='EncoderDecoder',
# pretrained='local_configs/segformer/B0/pretrained/mit_b0.pth',
pretrained= None,
backbone=dict(
type='mit_b0',
style='pytorch'),
decode_head=dict(
type='SegFormerHead',
in_channels=[32, 64, 160, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=0.1,
num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
decoder_params=dict(embed_dim=256),
#decoder_params=dict(),
loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=True, use_mask=False,loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=10.)
}))
lr_config = dict(_delete_=True, policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0, min_lr=0.0, by_epoch=False)
이것 역시 크게 수정한 부분은 없지만, 모델의 config가 되는 부분이므로 바꿀 부분이 있다면 바꿔주자
ex) class 수, optimizer, pretrained 등 ,,
나는 binary segmentation을 진행해서 class 수를 2로 바꾸는 등 여러 작업을 했다.
4. 학습 돌리기!
원하는 config 파일명을 넣어서 돌려주면 된다.
python tools/train.py local_configs/segformer/B0/segformer.b1.512x512.ade.160k.py --gpus 1
----단일 gpu로 학습시킨다면, synBN을 BN으로 바꿔줄 것!-----
하기 전에 항상 데이터의 값을 찍어보자!
grayscale 이미지는 uint8의 범위 (0~255)를 가진다. 만약 binary segmentation을 할 경우, 이를 0과 1로 치환해야하므로 데이터를 load할 때 convert하는 코드를 넣어주면 된다.
아니면 loss단에
label = torch.where(label > 125, 1, 0)
이런 코드를 적어주면 되긴 하는데,, 추천하지 않는다.
'AI' 카테고리의 다른 글
RNN 첫걸음 (0) | 2022.09.28 |
---|---|
CNN 첫걸음 (0) | 2022.09.28 |
딥러닝 기초 - RNN 간단 정리 (0) | 2022.09.28 |
딥러닝 기초 - CNN 간단 정리 (1) | 2022.09.28 |
Regularization 간단 정리 (0) | 2022.09.28 |