BasicSR Beginner Tutorial
1. Setting up the environment
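Since an existing environment (here mmediting) already has PyTorch installed, the simplest way to create the new environment is to clone that one and then activate it:
# Clone the existing environment
conda create --name my-basicsr --clone mmediting
conda activate my-basicsr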
Clone the project
git clone https://github.com/XPixelGroup/BasicSR.git
Install the dependencies
cd BasicSR
pip install -r requirements.txt
Install BasicSR in development mode from the BasicSR root directory
python setup.py develop
Verify the installation. The following import should run without errors in Python:
import basicsr
When BasicSR is installed from a local clone, pip list shows the basicsr entry together with the path of the clone:
pip list
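The short script below complements pip list. It reports where Python would import torch and basicsr from, so you can confirm that torch comes from the cloned conda environment and basicsr comes from the local clone. This is a minimal sketch. The helper name package_location is hypothetical, and the script relies only on the standard-library importlib.util.find_spec.

```python
import importlib.util
from typing import Optional


def package_location(name: str) -> Optional[str]:
    """Return the file a top-level package would be imported from, or None if it is not importable."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    return spec.origin


if __name__ == "__main__":
    # torch should come from the cloned conda environment;
    # basicsr should point into the local BasicSR clone (installed with `python setup.py develop`).
    for pkg in ("torch", "basicsr"):
        location = package_location(pkg)
        print(f"{pkg}: {location if location else 'NOT FOUND'}")
```

find_spec does not execute the package. The check therefore stays cheap even for large packages such as torch.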
2. Preparing the datasets
Commonly used image super-resolution datasets:
Category | Dataset | Description | Download |
---|---|---|---|
2K Resolution | DIV2K | proposed in NTIRE17 (800 train and 100 validation) | official website |
Classical SR Testing | Set5 | Set5 test dataset | Google Drive / Baidu Drive |
Classical SR Testing | Set14 | Set14 test dataset | Google Drive / Baidu Drive |
DIV2K download: https://data.vision.ee.ethz.ch/cvl/DIV2K/
Set5 download: https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u
Set14 download: https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u
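DIV2K images are 2K resolution (about 2040 pixels along at least one axis), but training only needs much smaller patches (typically 128×128 or 192×192). We therefore first crop each 2K image into overlapping 480×480 sub-images. This spares the dataloader from reading a full 2K image for every patch. The dataloader then randomly crops the 128×128 or 192×192 training patches from these 480×480 sub-images. The cropping is done with the script extract_subimages.py: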
cd BasicSR
python scripts/data_preparation/extract_subimages.py
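The two-stage cropping described above can be sketched in plain Python. This is an illustrative re-implementation, not the code of extract_subimages.py or of BasicSR's dataloader. The 240-pixel step between windows, the 2040×1356 example size and the helper names (window_starts, sub_image_boxes, paired_crop_coords) are all assumptions.

```python
import random
from typing import List, Tuple


def window_starts(length: int, crop: int = 480, step: int = 240) -> List[int]:
    """Start offsets of overlapping crop windows along one axis.

    A last window flush with the border is appended so the right/bottom edge
    is also covered (assumed behaviour; crop and step values are assumptions).
    """
    if length < crop:
        return []
    starts = list(range(0, length - crop + 1, step))
    if starts[-1] + crop < length:
        starts.append(length - crop)
    return starts


def sub_image_boxes(height: int, width: int, crop: int = 480,
                    step: int = 240) -> List[Tuple[int, int, int, int]]:
    """(top, left, bottom, right) boxes of all sub-images cut from one image."""
    return [(y, x, y + crop, x + crop)
            for y in window_starts(height, crop, step)
            for x in window_starts(width, crop, step)]


def paired_crop_coords(lq_h: int, lq_w: int, gt_size: int = 128, scale: int = 4, rng=random):
    """Pick aligned corners for a random LQ patch and its GT patch.

    The LQ patch is gt_size // scale pixels wide. The GT corner is the LQ corner
    times scale, so both patches cover the same region of the scene.
    """
    lq_size = gt_size // scale
    top = rng.randint(0, lq_h - lq_size)
    left = rng.randint(0, lq_w - lq_size)
    return (top, left, lq_size), (top * scale, left * scale, gt_size)
```

For a 2040×1356 image this gives 8 × 5 = 40 overlapping sub-images. The LQ patch is cut from the x4-downscaled sub-image (120×120 for a 480×480 GT sub-image). That is why the GT corner is the LQ corner multiplied by the scale.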
To use LMDB as the storage backend (as the training config below does), first build the LMDB files by running:
python scripts/data_preparation/create_lmdb.py --dataset div2k
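An LMDB folder built this way is assumed to contain a meta_info.txt with one line per image, of the form `0001_s001.png (480,480,3) 1`: the key plus extension, the (h,w,c) shape and the PNG compression level. That format is an assumption here. The parser below, including the names LmdbEntry and parse_meta_info, is a hypothetical helper for checking what ended up in the database.

```python
import re
from typing import Dict, NamedTuple


class LmdbEntry(NamedTuple):
    height: int
    width: int
    channels: int
    compress_level: int


# Assumed line format: "<key>.<ext> (<h>,<w>,<c>) <compress_level>"
_LINE = re.compile(
    r"^(?P<key>[^.\s]+)\.\w+\s+\((?P<h>\d+),(?P<w>\d+),(?P<c>\d+)\)\s+(?P<lvl>\d+)$"
)


def parse_meta_info(text: str) -> Dict[str, LmdbEntry]:
    """Map each LMDB key to its image shape and compression level."""
    entries: Dict[str, LmdbEntry] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue  # skip blank lines
        m = _LINE.match(line)
        if m is None:
            raise ValueError(f"unexpected meta_info line: {line!r}")
        entries[m["key"]] = LmdbEntry(int(m["h"]), int(m["w"]), int(m["c"]), int(m["lvl"]))
    return entries
```

Usage: read datasets/DIV2K/DIV2K_train_HR_sub.lmdb/meta_info.txt and pass its contents to parse_meta_info. Then check that the number of entries and the shapes match the sub-images you extracted.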
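Based on the paths used in the configuration files below, the datasets folder (under the BasicSR root) is laid out as follows:
datasets/
    DIV2K/
        DIV2K_train_HR_sub.lmdb
        DIV2K_train_LR_bicubic_X4_sub.lmdb
        DIV2K_valid_HR
        DIV2K_valid_LR_bicubic/X4
    Set5/
        GTmod12
        LRbicx4
    Set14/
        GTmod12
        LRbicx4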
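Before training, it helps to check that every dataset directory referenced by the training and testing configs of this tutorial actually exists. The sketch below hard-codes those paths, copied from the configs, relative to the BasicSR root. The names EXPECTED_PATHS and missing_paths are hypothetical.

```python
from pathlib import Path
from typing import List, Sequence

# Dataset paths used by the training/testing configs in this tutorial,
# relative to the BasicSR root directory (LMDB databases are directories too).
EXPECTED_PATHS = (
    "datasets/DIV2K/DIV2K_train_HR_sub.lmdb",
    "datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb",
    "datasets/DIV2K/DIV2K_valid_HR",
    "datasets/DIV2K/DIV2K_valid_LR_bicubic/X4",
    "datasets/Set5/GTmod12",
    "datasets/Set5/LRbicx4",
    "datasets/Set14/GTmod12",
    "datasets/Set14/LRbicx4",
)


def missing_paths(root: str = ".", expected: Sequence[str] = EXPECTED_PATHS) -> List[str]:
    """Return the expected dataset directories that do not exist under root."""
    base = Path(root)
    return [p for p in expected if not (base / p).is_dir()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing dataset directories:")
        for p in missing:
            print("  " + p)
    else:
        print("All dataset directories found.")
```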
3. Modifying the configuration files
Create a new training configuration file options/train/SRResNet_SRGAN/my_train_MSRResNet_x4.yml with the following content:
# Modified SRResNet w/o BN from:
# Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

# ----------- Commands for running
# ----------- Single GPU with auto_resume
# PYTHONPATH="./:${PYTHONPATH}" CUDA_VISIBLE_DEVICES=0 python basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml --auto_resume

# general settings
name: 001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst
model_type: SRModel
scale: 4
num_gpu: 1  # set num_gpu: 0 for cpu mode
manual_seed: 0

# dataset and data loader settings
datasets:
  train:
    name: DIV2K
    type: PairedImageDataset
    # dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub
    # dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub
    # meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
    # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
    # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub
    # meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
    # (for lmdb)
    dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
    dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
    filename_tmpl: '{}'
    io_backend:
      # type: disk
      # (for lmdb)
      type: lmdb

    gt_size: 128
    use_hflip: true
    use_rot: true

    # data loader
    num_worker_per_gpu: 6
    batch_size_per_gpu: 16
    dataset_enlarge_ratio: 100
    prefetch_mode: ~

  val:
    name: Set5
    type: PairedImageDataset
    dataroot_gt: datasets/Set5/GTmod12
    dataroot_lq: datasets/Set5/LRbicx4
    io_backend:
      type: disk

  val_2:
    name: Set14
    type: PairedImageDataset
    dataroot_gt: datasets/Set14/GTmod12
    dataroot_lq: datasets/Set14/LRbicx4
    io_backend:
      type: disk

# network structures
network_g:
  type: MSRResNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 16
  upscale: 4

# path
path:
  pretrain_network_g: ~
  param_key_g: params
  strict_load_g: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 2e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: CosineAnnealingRestartLR
    periods: [250000, 250000, 250000, 250000]
    restart_weights: [1, 1, 1, 1]
    eta_min: !!float 1e-7

  # total_iter: 1000000
  total_iter: 10000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean

# validation settings
val:
  val_freq: !!float 5e3
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 4
      test_y_channel: false
      better: higher  # the higher, the better. Default: higher
    niqe:
      type: calculate_niqe
      crop_border: 4
      better: lower  # the lower, the better

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500
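The training config keeps the scheduler from the original 1,000,000-iteration recipe (four cosine periods of 250000 iterations) but stops at total_iter: 10000. The plain-Python sketch below shows what that means for the learning rate. The formula is an assumed reading of cosine annealing with restarts as configured by CosineAnnealingRestartLR, not BasicSR's own code. The function name cosine_restart_lr is hypothetical.

```python
import math
from typing import Sequence


def cosine_restart_lr(
    it: int,
    base_lr: float = 2e-4,
    periods: Sequence[int] = (250000, 250000, 250000, 250000),
    restart_weights: Sequence[float] = (1, 1, 1, 1),
    eta_min: float = 1e-7,
) -> float:
    """Learning rate at iteration `it` under cosine annealing with restarts (assumed formula).

    Within each period: lr = eta_min + w * 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / period)),
    where t counts iterations since the period began and w is that period's restart weight.
    """
    start = 0
    for period, weight in zip(periods, restart_weights):
        if it <= start + period:
            t = it - start
            return eta_min + weight * 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / period))
        start += period
    # Past the last period: hold the minimum learning rate (a choice made for this sketch).
    return eta_min


if __name__ == "__main__":
    for it in (0, 5000, 10000, 250000):
        print(f"iter {it:>6}: lr = {cosine_restart_lr(it):.4e}")
```

With these settings the learning rate at iteration 10000 is still about 1.992e-4, so it drops by well under 1% during the whole run. If a short run should actually decay, one option is to shorten the schedule to match total_iter, for example periods: [10000] with restart_weights: [1].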
Now start training:
python basicsr/train.py -opt options/train/SRResNet_SRGAN/my_train_MSRResNet_x4.yml
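After training, the checkpoints (models/), training states and logs are saved in experiments/001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst. The test config below loads its model from that folder.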
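BasicSR saves generator checkpoints in experiments/<name>/models/ with names such as net_g_10000.pth. The test config's pretrain_network_g points to one of them. The hypothetical helper below picks the checkpoint with the highest iteration number, which you can then paste into pretrain_network_g. It compares iterations numerically because a plain string sort would rank net_g_5000.pth above net_g_10000.pth. Files whose suffix is not a number, such as a possible net_g_latest.pth, are ignored.

```python
import re
from pathlib import Path
from typing import Optional

# Matches e.g. "net_g_10000.pth" and captures the iteration number.
_CKPT = re.compile(r"^net_g_(\d+)\.pth$")


def latest_checkpoint(models_dir: str) -> Optional[Path]:
    """Return the net_g_<iter>.pth file with the largest iteration, or None if there is none."""
    best_iter, best_path = -1, None
    for path in Path(models_dir).glob("net_g_*.pth"):
        m = _CKPT.match(path.name)
        if m and int(m.group(1)) > best_iter:
            best_iter, best_path = int(m.group(1)), path
    return best_path
```

Usage: latest_checkpoint("experiments/001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst/models").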
Create a new test configuration file options/test/SRResNet_SRGAN/my_test_MSRResNet_x4.yml with the following content:
# ----------- Commands for running
# ----------- Single GPU
# PYTHONPATH="./:${PYTHONPATH}" CUDA_VISIBLE_DEVICES=0 python basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml

# general settings
name: 001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst
model_type: SRModel
scale: 4
num_gpu: 1  # set num_gpu: 0 for cpu mode
manual_seed: 0

# test dataset settings
datasets:
  test_1:  # the 1st test dataset
    name: Set5
    type: PairedImageDataset
    dataroot_gt: datasets/Set5/GTmod12
    dataroot_lq: datasets/Set5/LRbicx4
    io_backend:
      type: disk
  test_2:  # the 2nd test dataset
    name: Set14
    type: PairedImageDataset
    dataroot_gt: datasets/Set14/GTmod12
    dataroot_lq: datasets/Set14/LRbicx4
    io_backend:
      type: disk
  test_3:  # the 3rd test dataset
    name: DIV2K100
    type: PairedImageDataset
    dataroot_gt: datasets/DIV2K/DIV2K_valid_HR
    dataroot_lq: datasets/DIV2K/DIV2K_valid_LR_bicubic/X4
    filename_tmpl: '{}x4'
    io_backend:
      type: disk

# network structures
network_g:
  type: MSRResNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 16
  upscale: 4

# path
path:
  pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst/models/net_g_10000.pth
  param_key_g: params
  strict_load_g: true

# validation settings
val:
  save_img: true
  suffix: ~  # add suffix to saved images, if None, use exp name

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 4
      test_y_channel: false
      better: higher  # the higher, the better. Default: higher
    ssim:
      type: calculate_ssim
      crop_border: 4
      test_y_channel: false
      better: higher
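Run the test with:
python basicsr/test.py -opt options/test/SRResNet_SRGAN/my_test_MSRResNet_x4.yml
After testing, the results are saved in results/001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst.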
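The test config reports PSNR with crop_border: 4, meaning 4 pixels are dropped from every edge before comparing. The underlying formula is PSNR = 10·log10(255² / MSE). The sketch below works on a single channel stored as nested lists to keep it dependency-free. It illustrates the metric and is not BasicSR's calculate_psnr, which operates on full images and can optionally use the Y channel. The function name psnr is hypothetical.

```python
import math
from typing import List

Image = List[List[float]]  # one channel, rows of pixel values in [0, 255]


def psnr(img: Image, ref: Image, crop_border: int = 4, max_val: float = 255.0) -> float:
    """PSNR in dB after removing crop_border pixels from each side of both images."""
    if crop_border > 0:
        img = [row[crop_border:-crop_border] for row in img[crop_border:-crop_border]]
        ref = [row[crop_border:-crop_border] for row in ref[crop_border:-crop_border]]
    sq_errors = [(a - b) ** 2 for ra, rb in zip(img, ref) for a, b in zip(ra, rb)]
    if not sq_errors:
        raise ValueError("image is empty after cropping the border")
    mse = sum(sq_errors) / len(sq_errors)
    if mse == 0:
        return float("inf")  # identical images
    return 10 * math.log10(max_val ** 2 / mse)
```

Cropping the border follows the usual convention for x4 super-resolution: edge pixels are hard to reconstruct, and excluding them keeps scores comparable across methods.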
4. Visualizing training with TensorBoard
Enable TensorBoard in the training YAML config:
# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true  # set to true
  wandb:
    project: ~
    resume_id: ~
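Then run the following command on the server, from the BasicSR root directory (where the tb_logger folder is created), to serve TensorBoard so it can be opened in a browser: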
tensorboard --logdir tb_logger --port 5500 --bind_all
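TensorBoard is easy to use on a local machine, but on a remote server you need to forward the port first.
On Windows, install Xshell and go to File -> Properties -> SSH -> Tunneling -> Add. Create a rule with type Local, source host 127.0.0.1 (the local machine) and a listening port of your choice, such as 12345. Set the destination host to the server and the destination port to the TensorBoard port, normally 5500. If port 5500 is already taken on the server, start TensorBoard with a different --port and use that port here instead.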
Then open 127.0.0.1:12345 in your local browser.
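Two things often go wrong here: port 5500 is already taken on the server, or the tunnel is not up. A quick TCP connection check helps tell these apart. This is a small sketch using only the standard socket module. The helper name port_is_open is hypothetical, and the ports are the example values from the steps above.

```python
import socket


def port_is_open(port: int, host: str = "127.0.0.1", timeout: float = 1.0) -> bool:
    """True if something accepts TCP connections on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # connection refused, timed out, unreachable, ...
        return False


if __name__ == "__main__":
    # On the server, before starting TensorBoard: is 5500 already in use?
    print("port 5500 in use:", port_is_open(5500))
    # On the local machine, after setting up the tunnel: does 12345 answer?
    print("tunnel port 12345 open:", port_is_open(12345))
```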
Finally, thanks for reading and learning along!