Training a Custom Image Super-Resolution Model with BasicSR on Windows

My undergraduate thesis was on image super-resolution, and while working on it I wrote a training framework modeled after EDSR. I later learned about this codebase but never found time to try it out. Now that my defense is over, I spent a few days experimenting with the library, and this post records the process. The system is Windows 11.
项目GitHub地址:https://github.com/XPixelGroup/BasicSR

Preparation

  1. Clone the project to your local machine
git clone https://github.com/xinntao/BasicSR.git

Note: I did not install via pip here; instead I cloned the repository directly and make all modifications inside the project.

  2. Install the dependencies
    I recommend creating a conda virtual environment. I don't recommend installing from the project's requirements.txt: I tried it, and pip installed the CPU build of PyTorch. Instead, install the following dependencies yourself:
    python>=3.7 (recommended by the basicsr project); I used 3.7.16
    pytorch>=1.7 (recommended by the basicsr project); I used PyTorch 1.11.0 (make sure to install the GPU build)
pip install addict future lmdb numpy opencv-python Pillow pyyaml requests scikit-image scipy tb-nightly tqdm yapf

Note: the basicsr project recommends numpy>=1.17; the latest version pip gave me was 1.21.6.
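After installing, it's worth a one-off check that the GPU build of PyTorch actually got installed (a CPU-only wheel is the most common pitfall here):

import torch

print(torch.__version__)          # a CUDA build prints something like 1.11.0+cu113
print(torch.cuda.is_available())  # should print True for GPU training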

Datasets

I use the DIV2K dataset and build the training set as lmdb (reportedly this speeds up reads); Set5 serves as the validation set.

Training Set

DIV2K download: https://data.vision.ee.ethz.ch/cvl/DIV2K/
You need the HR images plus the X2/X3/X4 bicubic-downsampled LR images.
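The official LR images are produced with MATLAB's bicubic imresize, so downloading them is the safest choice. If you only grabbed the HR archive, here is a rough sketch that generates approximate bicubic LR images with OpenCV; note that cv2's bicubic is not bit-identical to MATLAB's, which can shift PSNR slightly, and the paths below are just my own setup:

import os
import cv2

# assumed local paths; adjust to your setup
hr_dir = 'D:/Datasets/SISR/DIV2K/DIV2K_train_HR'
lr_dir = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X4'
scale = 4

os.makedirs(lr_dir, exist_ok=True)
for name in os.listdir(hr_dir):
    img = cv2.imread(os.path.join(hr_dir, name), cv2.IMREAD_COLOR)
    h, w = img.shape[:2]
    img = img[:h - h % scale, :w - w % scale]  # make sides divisible by scale
    lr = cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale),
                    interpolation=cv2.INTER_CUBIC)
    base, ext = os.path.splitext(name)
    # DIV2K LR naming convention: 0001.png -> 0001x4.png
    cv2.imwrite(os.path.join(lr_dir, f'{base}x{scale}{ext}'), lr)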

  1. Crop the dataset
    For the training set, the project requires first cropping the images into overlapping sub-images (crop with overlap). Open scripts/data_preparation/extract_subimages.py and edit the paths for the HR images, the LR images, and the generated sub-images. For example:
# HR images
opt['input_folder'] = 'D:/Datasets/SISR/DIV2K/DIV2K_train_HR'
opt['save_folder'] = 'D:/Datasets/SISR/DIV2K/DIV2K_train_HR_sub'
opt['crop_size'] = 480
opt['step'] = 240
opt['thresh_size'] = 0
extract_subimages(opt)

# LRx2 images
opt['input_folder'] = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X2'
opt['save_folder'] = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X2_sub'
opt['crop_size'] = 240
opt['step'] = 120
opt['thresh_size'] = 0
extract_subimages(opt)

# LRx3 images
opt['input_folder'] = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X3'
opt['save_folder'] = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X3_sub'
opt['crop_size'] = 160
opt['step'] = 80
opt['thresh_size'] = 0
extract_subimages(opt)

# LRx4 images
opt['input_folder'] = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X4'
opt['save_folder'] = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
opt['crop_size'] = 120
opt['step'] = 60
opt['thresh_size'] = 0
extract_subimages(opt)

The meaning of each parameter is documented in the code.

After editing the paths, run extract_subimages.py directly; it will generate a series of sub-images at the corresponding paths.
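To make crop_size, step, and thresh_size concrete: the script slides a crop_size window with stride step along each axis, and appends one extra overlapping crop at the border when the leftover strip exceeds thresh_size. A simplified sketch of the position computation (the real logic lives in extract_subimages.py; images are assumed larger than crop_size):

import numpy as np

def crop_positions(length, crop_size, step, thresh_size):
    # sliding-window start positions along one axis
    pos = list(np.arange(0, length - crop_size + 1, step))
    # keep one extra, overlapping crop if the leftover border is big enough
    if length - (pos[-1] + crop_size) > thresh_size:
        pos.append(length - crop_size)
    return pos

# e.g. a 2040-pixel side with crop 480 / step 240:
print(crop_positions(2040, 480, 240, 0))  # [0, 240, ..., 1440, 1560]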

  2. Build the lmdb dataset
    Open scripts/data_preparation/create_lmdb.py and edit the sub-image paths and the output lmdb paths inside the create_lmdb_for_div2k function. For example:
def create_lmdb_for_div2k():
    """Create lmdb files for DIV2K dataset.

    Usage:
        Before run this script, please run `extract_subimages.py`.
        Typically, there are four folders to be processed for DIV2K dataset.

            * DIV2K_train_HR_sub
            * DIV2K_train_LR_bicubic/X2_sub
            * DIV2K_train_LR_bicubic/X3_sub
            * DIV2K_train_LR_bicubic/X4_sub

        Remember to modify opt configurations according to your settings.
    """
    # HR images
    folder_path = 'D:/Datasets/SISR/DIV2K/DIV2K_train_HR_sub'
    lmdb_path = 'D:/Datasets/SISR/DIV2K/DIV2K_train_HR_sub.lmdb'
    img_path_list, keys = prepare_keys_div2k(folder_path)
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    # LRx2 images
    folder_path = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X2_sub'
    lmdb_path = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic_X2_sub.lmdb'
    img_path_list, keys = prepare_keys_div2k(folder_path)
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    # LRx3 images
    folder_path = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X3_sub'
    lmdb_path = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic_X3_sub.lmdb'
    img_path_list, keys = prepare_keys_div2k(folder_path)
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    # LRx4 images
    folder_path = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
    lmdb_path = 'D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb'
    img_path_list, keys = prepare_keys_div2k(folder_path)
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
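After the script finishes, a quick read-back is a good sanity check. BasicSR stores each sub-image as an encoded image under a key like '0001_s001' (the full key list is in meta_info.txt inside the .lmdb folder); a minimal sketch:

import cv2
import lmdb
import numpy as np

env = lmdb.open('D:/Datasets/SISR/DIV2K/DIV2K_train_HR_sub.lmdb',
                readonly=True, lock=False)
with env.begin() as txn:
    buf = txn.get('0001_s001'.encode())  # example key; see meta_info.txt
img = cv2.imdecode(np.frombuffer(buf, np.uint8), cv2.IMREAD_UNCHANGED)
print(img.shape)  # expect (480, 480, 3) for the HR sub-images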

Validation Set

Validation sets are usually small, so storing them as plain image files is fine; I use Set5 as the validation set here.
Download:
Link: https://pan.baidu.com/s/1mU6uoGais7r_CAkrr3VGGA?pwd=um12
Extraction code: um12
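If the network-drive link doesn't work for you, the GTmod12/LRbicx4 layout can be reproduced from the original Set5 images: GTmod12 is the GT cropped so both sides are divisible by 12 (hence valid for x2/x3/x4), and LRbicx4 is its bicubic x4 downsample. A rough sketch (cv2's bicubic again only approximates the MATLAB resize used for the official files; the source path is an assumption):

import os
import cv2

src = 'D:/Datasets/SISR/Set5/original'  # assumed path to the raw Set5 images
gt_dir = 'D:/Datasets/SISR/Set5/GTmod12'
lr_dir = 'D:/Datasets/SISR/Set5/LRbicx4'
os.makedirs(gt_dir, exist_ok=True)
os.makedirs(lr_dir, exist_ok=True)

for name in os.listdir(src):
    img = cv2.imread(os.path.join(src, name), cv2.IMREAD_COLOR)
    h, w = img.shape[:2]
    img = img[:h - h % 12, :w - w % 12]  # crop to a multiple of 12
    cv2.imwrite(os.path.join(gt_dir, name), img)
    lr = cv2.resize(img, (img.shape[1] // 4, img.shape[0] // 4),
                    interpolation=cv2.INTER_CUBIC)
    cv2.imwrite(os.path.join(lr_dir, name), lr)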

Custom Model

Since my thesis topic is lightweight super-resolution, I use IMDN x4 as the example here; other models can largely follow the same steps.
IMDN project: https://github.com/Zheng222/IMDN

Defining the Model Architecture

This mainly follows the architecture.py and block.py files under /model in the IMDN project: architecture.py defines the IMDN network, and block.py defines its building blocks.
basicsr/archs/ holds the network definition files, so we create our model file (plus whatever it depends on) in this directory. To make sure the files can be automatically scanned and imported by basicsr, the naming style should follow the basicsr convention. Here I created two files, imdn_arch.py and imdn_utils.py, implementing the contents of architecture.py and block.py respectively. For example,
imdn_arch.py contains the following:

import torch.nn as nn
from basicsr.archs import imdn_utils as B
import torch

from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class IMDN(nn.Module):
    def __init__(self, in_nc=3, nf=64, num_modules=6, out_nc=3, upscale=4):
        super(IMDN, self).__init__()

        self.fea_conv = B.conv_layer(in_nc, nf, kernel_size=3)

        # IMDBs
        self.IMDB1 = B.IMDModule(in_channels=nf)
        self.IMDB2 = B.IMDModule(in_channels=nf)
        self.IMDB3 = B.IMDModule(in_channels=nf)
        self.IMDB4 = B.IMDModule(in_channels=nf)
        self.IMDB5 = B.IMDModule(in_channels=nf)
        self.IMDB6 = B.IMDModule(in_channels=nf)
        self.c = B.conv_block(nf * num_modules, nf, kernel_size=1, act_type='lrelu')

        self.LR_conv = B.conv_layer(nf, nf, kernel_size=3)

        upsample_block = B.pixelshuffle_block
        self.upsampler = upsample_block(nf, out_nc, upscale_factor=upscale)


    def forward(self, input):
        out_fea = self.fea_conv(input)
        out_B1 = self.IMDB1(out_fea)
        out_B2 = self.IMDB2(out_B1)
        out_B3 = self.IMDB3(out_B2)
        out_B4 = self.IMDB4(out_B3)
        out_B5 = self.IMDB5(out_B4)
        out_B6 = self.IMDB6(out_B5)

        out_B = self.c(torch.cat([out_B1, out_B2, out_B3, out_B4, out_B5, out_B6], dim=1))
        out_lr = self.LR_conv(out_B) + out_fea
        output = self.upsampler(out_lr)
        return output

Since this is only a demo, I ported just one of the models from architecture.py. A few things to note: first, you must import the registry via from basicsr.utils.registry import ARCH_REGISTRY; second, you must put @ARCH_REGISTRY.register() above the model class to register it (see the BasicSR docs for the details of the registry mechanism). Also note the class name and the constructor arguments here; we will need both when writing the yml file later.

The content of imdn_utils.py is as follows; it mainly defines the building blocks used by IMDN.

import torch.nn as nn
from collections import OrderedDict
import torch


def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
    padding = int((kernel_size - 1) / 2) * dilation
    return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias, dilation=dilation,
                     groups=groups)


def norm(norm_type, nc):
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer


def pad(pad_type, padding):
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu'):
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                  dilation=dilation, bias=bias, groups=groups)
    a = activation(act_type) if act_type else None
    n = norm(norm_type, out_nc) if norm_type else None
    return sequential(p, c, n, a)


def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'lrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer


class ShortcutBlock(nn.Module):
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output


def mean_channels(F):
    assert (F.dim() == 4)
    spatial_sum = F.sum(3, keepdim=True).sum(2, keepdim=True)
    return spatial_sum / (F.size(2) * F.size(3))


def stdv_channels(F):
    assert (F.dim() == 4)
    F_mean = mean_channels(F)
    F_variance = (F - F_mean).pow(2).sum(3, keepdim=True).sum(2, keepdim=True) / (F.size(2) * F.size(3))
    return F_variance.pow(0.5)


def sequential(*args):
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


# contrast-aware channel attention module
class CCALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CCALayer, self).__init__()

        self.contrast = stdv_channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.contrast(x) + self.avg_pool(x)
        y = self.conv_du(y)
        return x * y

class IMDModule(nn.Module):
    def __init__(self, in_channels, distillation_rate=0.25):
        super(IMDModule, self).__init__()
        self.distilled_channels = int(in_channels * distillation_rate)
        self.remaining_channels = int(in_channels - self.distilled_channels)
        self.c1 = conv_layer(in_channels, in_channels, 3)
        self.c2 = conv_layer(self.remaining_channels, in_channels, 3)
        self.c3 = conv_layer(self.remaining_channels, in_channels, 3)
        self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
        self.act = activation('lrelu', neg_slope=0.05)
        self.c5 = conv_layer(in_channels, in_channels, 1)
        self.cca = CCALayer(self.distilled_channels * 4)

    def forward(self, input):
        out_c1 = self.act(self.c1(input))
        distilled_c1, remaining_c1 = torch.split(out_c1, (self.distilled_channels, self.remaining_channels), dim=1)
        out_c2 = self.act(self.c2(remaining_c1))
        distilled_c2, remaining_c2 = torch.split(out_c2, (self.distilled_channels, self.remaining_channels), dim=1)
        out_c3 = self.act(self.c3(remaining_c2))
        distilled_c3, remaining_c3 = torch.split(out_c3, (self.distilled_channels, self.remaining_channels), dim=1)
        out_c4 = self.c4(remaining_c3)
        out = torch.cat([distilled_c1, distilled_c2, distilled_c3, out_c4], dim=1)
        out_fused = self.c5(self.cca(out)) + input
        return out_fused

class IMDModule_speed(nn.Module):
    def __init__(self, in_channels, distillation_rate=0.25):
        super(IMDModule_speed, self).__init__()
        self.distilled_channels = int(in_channels * distillation_rate)
        self.remaining_channels = int(in_channels - self.distilled_channels)
        self.c1 = conv_layer(in_channels, in_channels, 3)
        self.c2 = conv_layer(self.remaining_channels, in_channels, 3)
        self.c3 = conv_layer(self.remaining_channels, in_channels, 3)
        self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
        self.act = activation('lrelu', neg_slope=0.05)
        self.c5 = conv_layer(self.distilled_channels * 4, in_channels, 1)

    def forward(self, input):
        out_c1 = self.act(self.c1(input))
        distilled_c1, remaining_c1 = torch.split(out_c1, (self.distilled_channels, self.remaining_channels), dim=1)
        out_c2 = self.act(self.c2(remaining_c1))
        distilled_c2, remaining_c2 = torch.split(out_c2, (self.distilled_channels, self.remaining_channels), dim=1)
        out_c3 = self.act(self.c3(remaining_c2))
        distilled_c3, remaining_c3 = torch.split(out_c3, (self.distilled_channels, self.remaining_channels), dim=1)
        out_c4 = self.c4(remaining_c3)

        out = torch.cat([distilled_c1, distilled_c2, distilled_c3, out_c4], dim=1)
        out_fused = self.c5(out) + input
        return out_fused

class IMDModule_Large(nn.Module):
    def __init__(self, in_channels, distillation_rate=1 / 4):
        super(IMDModule_Large, self).__init__()
        self.distilled_channels = int(in_channels * distillation_rate)  # 6
        self.remaining_channels = int(in_channels - self.distilled_channels)  # 18
        self.c1 = conv_layer(in_channels, in_channels, 3, bias=False)  # 24 --> 24
        self.c2 = conv_layer(self.remaining_channels, in_channels, 3, bias=False)  # 18 --> 24
        self.c3 = conv_layer(self.remaining_channels, in_channels, 3, bias=False)  # 18 --> 24
        self.c4 = conv_layer(self.remaining_channels, self.remaining_channels, 3, bias=False)  # 15 --> 15
        self.c5 = conv_layer(self.remaining_channels - self.distilled_channels,
                             self.remaining_channels - self.distilled_channels, 3, bias=False)  # 10 --> 10
        self.c6 = conv_layer(self.distilled_channels, self.distilled_channels, 3, bias=False)  # 5 --> 5
        self.act = activation('relu')
        self.c7 = conv_layer(self.distilled_channels * 6, in_channels, 1, bias=False)

    def forward(self, input):
        out_c1 = self.act(self.c1(input))  # 24 --> 24
        distilled_c1, remaining_c1 = torch.split(out_c1, (self.distilled_channels, self.remaining_channels),
                                                 dim=1)  # 6, 18
        out_c2 = self.act(self.c2(remaining_c1))  # 18 --> 24
        distilled_c2, remaining_c2 = torch.split(out_c2, (self.distilled_channels, self.remaining_channels),
                                                 dim=1)  # 6, 18
        out_c3 = self.act(self.c3(remaining_c2))  # 18 --> 24
        distilled_c3, remaining_c3 = torch.split(out_c3, (self.distilled_channels, self.remaining_channels),
                                                 dim=1)  # 6, 18
        out_c4 = self.act(self.c4(remaining_c3))  # 18 --> 18
        distilled_c4, remaining_c4 = torch.split(
            out_c4, (self.distilled_channels, self.remaining_channels - self.distilled_channels), dim=1)  # 6, 12
        out_c5 = self.act(self.c5(remaining_c4))  # 12 --> 12
        distilled_c5, remaining_c5 = torch.split(
            out_c5, (self.distilled_channels, self.remaining_channels - self.distilled_channels * 2), dim=1)  # 6, 6
        out_c6 = self.act(self.c6(remaining_c5))  # 6 --> 6

        out = torch.cat([distilled_c1, distilled_c2, distilled_c3, distilled_c4, distilled_c5, out_c6], dim=1)
        out_fused = self.c7(out) + input
        return out_fused


def pixelshuffle_block(in_channels, out_channels, upscale_factor=2, kernel_size=3, stride=1):
    conv = conv_layer(in_channels, out_channels * (upscale_factor ** 2), kernel_size, stride)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    return sequential(conv, pixel_shuffle)

In fact, the content of imdn_utils.py is no different from the original block.py; you could just as well paste the building-block code directly into imdn_arch.py and adapt it there. I won't demonstrate that here.
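Before writing the config, it's worth a quick smoke test that the registered class builds and really upsamples by 4x. A minimal sketch, run from the BasicSR project root so the basicsr package is importable:

import torch
from basicsr.archs.imdn_arch import IMDN

model = IMDN(in_nc=3, nf=64, num_modules=6, out_nc=3, upscale=4)
x = torch.randn(1, 3, 48, 48)  # dummy LR input
with torch.no_grad():
    y = model(x)
print(y.shape)  # expect torch.Size([1, 3, 192, 192])
# parameter count; the IMDN paper reports roughly 715K for x4
print(sum(p.numel() for p in model.parameters()))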

Defining the Model yml File

Since this is a SISR task, I mainly referenced the yml files under options/train/EDSR/ in the basicsr project. Here I created a new yml file at options/train/IMDN/train_IMDN_x4.yml with the following content:

# general settings
name: IMDN_x4
model_type: SRModel
scale: 4
num_gpu: 1  # set num_gpu: 0 for cpu mode
manual_seed: 10

# dataset and data loader settings
datasets:
  train:
    name: DIV2K
    type: PairedImageDataset
#    dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
#    dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub
    # (for lmdb)
    dataroot_gt: D:/Datasets/SISR/DIV2K/DIV2K_train_HR_sub.lmdb
    dataroot_lq: D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
    filename_tmpl: '{}'
    io_backend:
#      type: disk
      # (for lmdb)
      type: lmdb

    gt_size: 192
    use_hflip: true
    use_rot: true

    # data loader
    num_worker_per_gpu: 6
    batch_size_per_gpu: 16
    dataset_enlarge_ratio: 100
    prefetch_mode: ~

  val:
    name: DIV2K
    type: PairedImageDataset
    dataroot_gt: D:/Datasets/SISR/Set5/GTmod12
    dataroot_lq: D:/Datasets/SISR/Set5/LRbicx4
    io_backend:
      type: disk

# network structures
network_g:
  type: IMDN
  in_nc: 3
  nf: 64
  num_modules: 6
  out_nc: 3
  upscale: 4

# path
path:
  pretrain_network_g: ~
  strict_load_g: false
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: MultiStepLR
    milestones: [200000]
    gamma: 0.5

  total_iter: 300000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean

# validation settings
val:
  val_freq: !!float 5e3
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 4
      test_y_channel: false

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500

The main parameters to change are the following:

  • name: IMDN_x4 is the task name and can be anything you like.
  • model_type: SRModel needs no change. SRModel is the base class for image super-resolution models and already implements everything a basic single-image SR model needs: automatically instantiating the network class from the config, loading pretrained weights, setting up training (optimizer, loss, learning rate, etc.), iterating, optimizing parameters, and saving checkpoints. It can be used as-is.
  • Training-set paths: dataroot_gt: D:/Datasets/SISR/DIV2K/DIV2K_train_HR_sub.lmdb and dataroot_lq: D:/Datasets/SISR/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb.
  • The training set's io_backend is type: lmdb, since the training set is stored in lmdb format.
  • Validation-set paths: dataroot_gt: D:/Datasets/SISR/Set5/GTmod12 and dataroot_lq: D:/Datasets/SISR/Set5/LRbicx4. The validation set is stored as plain images, so the default type: disk is fine.
  • For the network-structure parameters, recall the imdn_arch.py file defined earlier: the registered class is named IMDN, hence type: IMDN; the remaining keys are exactly the constructor arguments and are all configured here (excerpt below; a sketch of the underlying mechanism follows this list).
network_g:
  type: IMDN
  in_nc: 3
  nf: 64
  num_modules: 6
  out_nc: 3
  upscale: 4
  • Since the network is trained from scratch, the pretrained-model path is pretrain_network_g: ~.
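For reference, what happens under the hood is roughly this: SRModel hands the network_g block to BasicSR's build_network, which pops type, looks it up in ARCH_REGISTRY, and calls the registered class's constructor with the remaining keys, which is why they must match IMDN.__init__. A simplified sketch:

from basicsr.archs import build_network

opt = dict(type='IMDN', in_nc=3, nf=64, num_modules=6, out_nc=3, upscale=4)
net_g = build_network(opt)  # pops 'type', instantiates IMDN(**remaining keys)
print(net_g.__class__.__name__)  # IMDN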

Training

A few extra tweaks are needed before training. Probably because I did not install basicsr with python setup.py develop or similar (the official installation guide ends with that step, which also generates the version file), my clone is missing the version-related file, so the code referencing it has to be commented out in two places:

  1. In basicsr/__init__.py, comment out the last line:
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .metrics import *
from .models import *
from .ops import *
from .test import *
from .train import *
from .utils import *
# from .version import __gitsha__, __version__
  2. In basicsr/utils/logger.py, comment out lines 196 and 210 (the two lines referencing basicsr.version, shown already commented out below):
def get_env_info():
    """Get environment information.

    Currently, only log the software version.
    """
    import torch
    import torchvision

    # from basicsr.version import __version__
    msg = r"""
                ____                _       _____  ____
               / __ ) ____ _ _____ (_)_____/ ___/ / __ \
              / __  |/ __ `// ___// // ___/\__ \ / /_/ /
             / /_/ // /_/ /(__  )/ // /__ ___/ // _, _/
            /_____/ \__,_//____//_/ \___//____//_/ |_|
     ______                   __   __                 __      __
    / ____/____   ____   ____/ /  / /   __  __ _____ / /__   / /
   / / __ / __ \ / __ \ / __  /  / /   / / / // ___// //_/  / /
  / /_/ // /_/ // /_/ // /_/ /  / /___/ /_/ // /__ / /<    /_/
  \____/ \____/ \____/ \____/  /_____/\____/ \___//_/|_|  (_)
    """
    msg += ('\nVersion Information: '
            # f'\n\tBasicSR: {__version__}'
            f'\n\tPyTorch: {torch.__version__}'
            f'\n\tTorchVision: {torchvision.__version__}')
    return msg

Before running the training command, first cd into the BasicSR directory so that the current path is the BasicSR root. For example, mine is:

D:\Project\Python\BasicSR>

Since we are on Windows, to keep Python from failing to find the package when the script runs, we also make a small change to basicsr/train.py, adding code so that imports resolve against the BasicSR root:

import datetime
import logging
import math
import time
import torch
from os import path as osp

# for windows:
import sys
sys.path.extend(['D:/Project/Python/BasicSR'])

from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
                           init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
from basicsr.utils.options import copy_opt_file, dict2str, parse_options
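Alternatively (an assumption on my part, not something I tested here), setting the PYTHONPATH environment variable to the BasicSR root before launching, e.g. set PYTHONPATH=D:\Project\Python\BasicSR in cmd, should achieve the same effect without editing train.py.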

Finally, run python ./basicsr/train.py -opt options/train/IMDN/train_IMDN_x4.yml in the terminal to start training.
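Checkpoints and logs then land under experiments/IMDN_x4/ (models, training states, and the log file); since use_tb_logger: true is set, the loss and PSNR curves can be watched with tensorboard --logdir tb_logger.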


Finally, thanks to the authors for their contributions.
For testing, see the companion post "Testing a Custom Image Super-Resolution Model with BasicSR on Windows". I'll also link two tips here, mainly for my own reference:
How to make PyCharm's terminal automatically activate the virtual environment when opened
[Tools] How to elegantly monitor GPU usage


Postscript:

  • If you hit the error "No object named 'XXX' found in 'arch' registry!", see https://github.com/XPixelGroup/BasicSR/issues/506. In my case the environment also had the basicsr package installed via pip, so the fix was simply to uninstall it: pip uninstall basicsr.