
Image deblurring using DeblurGAN (network, train, test)


https://github.com/KupynOrest/DeblurGAN

 


https://arxiv.org/abs/1711.07064

 

DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks

 

This post was written while studying the repository and paper above.

 

In the previous post we looked at conditional_gan_model and losses. conditional_gan_model created the generator and discriminator networks to match opt.model and the chosen GAN type and declared their optimizers, while losses defined the appropriate loss for each case.

This time we will look at the network, train, and test code. Let's start with the network.

 

import torch
import torch.nn as nn
# from torch.nn import init
import functools
# from torch.autograd import Variable
import numpy as np


###############################################################################
# Functions
###############################################################################


def weights_init(m):
    # custom initialization: Conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02)
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    # returns a constructor for the normalization layer chosen by norm_type
    return norm_layer
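# e.g. (sketch): norm_layer = get_norm_layer('instance') returns a functools.partial,
# so norm_layer(64) builds nn.InstanceNorm2d(64, affine=False, track_running_stats=True)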


def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], use_parallel=True,
             learn_residual=False):
    # builds the generator network specified by the arguments
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,
                               gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,
                               gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if len(gpu_ids) > 0:
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    # the architecture of the returned netG depends on which_model_netG
    return netG


def define_D(input_nc, ndf, which_model_netD, n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[],
             use_parallel=True):
    # builds the discriminator network specified by the arguments
    netD = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids, use_parallel=use_parallel)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids, use_parallel=use_parallel)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)
    if use_gpu:
        netD.cuda(gpu_ids[0])
    netD.apply(weights_init)
    # the architecture of the returned netD depends on which_model_netD
    return netD


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


##############################################################################
# Classes
##############################################################################


# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    # used when which_model_netG is 'resnet_9blocks' or 'resnet_6blocks'
    def __init__(
            self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
            n_blocks=6, gpu_ids=[], use_parallel=True, learn_residual=False, padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids
        self.use_parallel = use_parallel
        self.learn_residual = learn_residual

        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [
            nn.ReflectionPad2d(3),
            # pad all borders by 3; the padded values mirror the image content at the border
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True)
        ]

        # downsampling: two stride-2 convs, 64 -> 128 -> 256
        # (the hard-coded channel counts below assume the default ngf = 64)

        model += [
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias),
            norm_layer(128),
            nn.ReLU(True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias),
            norm_layer(256),
            nn.ReLU(True)
        ]

        # middle residual blocks
        for i in range(n_blocks):
            model += [
                ResnetBlock(256, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)
            ]
            # appends one ResnetBlock (defined below) per iteration:
            # 9 times for resnet_9blocks, 6 times for resnet_6blocks

        # upsampling: two stride-2 transposed convs, 256 -> 128 -> 64
        model += [
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
            norm_layer(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
            norm_layer(64),
            nn.ReLU(True),
        ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)
        # the layers above are stacked into a single nn.Sequential model

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel:
            output = nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            output = self.model(input)
        if self.learn_residual:
            # predict only the residual: add the input back and clamp to [-1, 1]
            output = torch.clamp(input + output, min=-1, max=1)
        return output
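# e.g. (sketch): a 1x3x256x256 input is downsampled to 1x256x64x64, kept at that
# resolution by the ResnetBlocks, upsampled back to 1x64x256x256, and projected to
# 1x3x256x256 by the final ReflectionPad + 7x7 conv + Tanh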


# Define a resnet block
class ResnetBlock(nn.Module):

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()

        def pad_and_conv():
            # build a fresh padding + conv pair on each call, so the two convs
            # in this block are separate modules with their own weights
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, bias=use_bias)]
            elif padding_type == 'replicate':
                return [nn.ReplicationPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, bias=use_bias)]
            elif padding_type == 'zero':
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        # conv -> norm -> ReLU -> (optional dropout) -> conv -> norm
        # (default) padding_type = 'reflect'; the parentheses around the dropout
        # conditional matter: in the original one-line expression, operator
        # precedence silently dropped the second conv whenever use_dropout was False
        blocks = pad_and_conv() + [
            norm_layer(dim),
            nn.ReLU(True)
        ] + ([nn.Dropout(0.5)] if use_dropout else []) + pad_and_conv() + [
            norm_layer(dim)
        ]

        self.conv_block = nn.Sequential(*blocks)


    def forward(self, x):
        # identity skip connection: the block only learns a residual added to x
        out = x + self.conv_block(x)
        return out


# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    # used when which_model_netG is 'unet_128' or 'unet_256'
    def __init__(
            self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d,
            use_dropout=False, gpu_ids=[], use_parallel=True, learn_residual=False):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids
        self.use_parallel = use_parallel
        self.learn_residual = learn_residual
        # currently support only input_nc == output_nc
        assert (input_nc == output_nc)

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer,
                                                 use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)
        # each call wraps the previous block, nesting them into the full U-Net
        self.model = unet_block

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel:
            output = nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            output = self.model(input)
        if self.learn_residual:
            # predict only the residual: add the input back and clamp to [-1, 1]
            output = input + output
            output = torch.clamp(output, min=-1, max=1)
        return output
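# e.g. (sketch): 'unet_256' builds num_downs = 8, so a 256x256 input is halved eight
# times and reaches 1x1 at the bottleneck (256 / 2**8 = 1); 'unet_128' uses
# num_downs = 7 for 128x128 inputs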


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    # building block called by UnetGenerator; each instance wraps the previous one as its submodule
    def __init__(
            self, outer_nc, inner_nc, submodule=None,
            outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        dConv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        dRelu = nn.LeakyReLU(0.2, True)
        dNorm = norm_layer(inner_nc)
        uRelu = nn.ReLU(True)
        uNorm = norm_layer(outer_nc)
        # the individual down/up operations are defined once here and assembled below

        if outermost:
            uConv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            dModel = [dConv]
            uModel = [uRelu, uConv, nn.Tanh()]
            # flatten into one list of modules so nn.Sequential accepts them
            model = dModel + [submodule] + uModel
        elif innermost:
            uConv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            dModel = [dRelu, dConv]
            uModel = [uRelu, uConv, uNorm]
            model = dModel + uModel
        else:
            uConv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            dModel = [dRelu, dConv, dNorm]
            uModel = [uRelu, uConv, uNorm]

            model = dModel + [submodule] + uModel
            model += [nn.Dropout(0.5)] if use_dropout else []


        self.model = nn.Sequential(*model)
        # the branch above fills model with the chosen layers; except for the first
        # (innermost) call, submodule carries everything built so far, so the
        # operations accumulate into one large nested network

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([self.model(x), x], 1)
            # U-Net skip connection: concatenate the block's output with its input along the channel axis
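# e.g. (sketch): because every non-outermost block returns cat([self.model(x), x], 1),
# a parent block receives inner_nc + inner_nc channels from its submodule, which is
# why its ConvTranspose2d takes inner_nc * 2 input channels; only the innermost
# block, which wraps no submodule, upsamples from plain inner_nc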


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    # PatchGAN discriminator; define_D builds it for both 'basic' and 'n_layers'
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[],
                 use_parallel=True):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        self.use_parallel = use_parallel

        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        # np.ceil((kw - 1) / 2) is 2 under Python 3 division (1 under Python 2's
        # integer division), which changes the size of the output score map
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        
        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)
        # the assembled layers are stored as the discriminator model
        
    def forward(self, input):
        # run the input through the discriminator, optionally data-parallel across GPUs
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel:
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)

 

This code defines the generator and discriminator networks in detail, instantiating the right one according to which_model_netG and which_model_netD. The generator comes in two families, ResNet and U-Net, with four variants: resnet_9blocks, resnet_6blocks, unet_128, and unet_256.
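To see how the two factory functions fit together, here is a minimal sketch (my own, not from the repository). It assumes networks.py is importable as models.networks and that the discriminator sees plain 3-channel images; note that the PatchGAN discriminator outputs a grid of per-patch scores rather than a single scalar.

import torch
from models import networks  # assumed import path

# build a ResNet generator and a PatchGAN discriminator on the CPU (gpu_ids=[])
netG = networks.define_G(3, 3, 64, 'resnet_9blocks', norm='instance', learn_residual=True)
netD = networks.define_D(3, 64, 'basic', norm='instance')

x = torch.randn(1, 3, 256, 256)  # stand-in for a blurred image batch
print(netG(x).shape)   # torch.Size([1, 3, 256, 256]) -- same size as the input
print(netD(x).shape)   # a 1 x 1 x H x W grid of per-patch scores
networks.print_network(netG)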

 

Next, let's look at the train code.

 

import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from util.metrics import PSNR, SSIM
from multiprocessing import freeze_support

def train(opt, data_loader, model, visualizer):
   dataset = data_loader.load_data()
   dataset_size = len(data_loader)
   print('#training images = %d' % dataset_size)
   total_steps = 0
   for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
      epoch_start_time = time.time()
      epoch_iter = 0
      for i, data in enumerate(dataset):
         iter_start_time = time.time()
         total_steps += opt.batchSize
         epoch_iter += opt.batchSize
         model.set_input(data)
         model.optimize_parameters()
      
         if total_steps % opt.display_freq == 0:
            results = model.get_current_visuals()
            psnrMetric = PSNR(results['Restored_Train'], results['Sharp_Train'])
            print('PSNR on Train = %f' % psnrMetric)
            visualizer.display_current_results(results, epoch)
            # accuracy is reported as PSNR; a higher PSNR means training is going well

         if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            if opt.display_id > 0:
                visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

         if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.save('latest')

      if epoch % opt.save_epoch_freq == 0:
         print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
         model.save('latest')
         model.save(epoch)
         # save the latest and per-epoch checkpoints
      
      print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
      # report the time taken per epoch
      if epoch > opt.niter:
         model.update_learning_rate()


if __name__ == '__main__':
   freeze_support()

   # python train.py --dataroot /.path_to_your_data --learn_residual --resize_or_crop crop --fineSize CROP_SIZE (we used 256)

   opt = TrainOptions().parse()
   opt.dataroot = r'D:\Photos\TrainingData\BlurredSharp\combined'  # raw string so the backslashes are not escape sequences
   opt.learn_residual = True
   opt.resize_or_crop = "crop"
   opt.fineSize = 256
   opt.gan_type = "gan"
   opt.which_model_netG = "unet_256"
   # vanilla GAN loss with the unet_256 generator
   
   # save_latest_freq defaults to 5000; checkpoint and log far more often here
   opt.save_latest_freq = 100
   opt.print_freq = 20

   data_loader = CreateDataLoader(opt)          # build the data loader
   model = create_model(opt)                    # build the model according to the options
   visualizer = Visualizer(opt)                 # build the visualizer
   train(opt, data_loader, model, visualizer)   # run training

 

The train code is simple. It creates the data_loader, model, and visualizer according to the options and then trains. The code above uses unet_256, but unet_128 or resnet_9blocks work just as well. Results are reported as PSNR, where a higher value indicates higher restoration accuracy.
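For reference, PSNR compares the restored image with the ground-truth sharp image through their mean squared error. Below is a minimal sketch of the idea (the repo's util.metrics.PSNR may differ in details such as the assumed peak value):

import numpy as np

def psnr(restored, sharp, peak=255.0):
    # PSNR = 10 * log10(peak^2 / MSE); lower MSE means higher PSNR
    restored = np.asarray(restored, dtype=np.float64)
    sharp = np.asarray(sharp, dtype=np.float64)
    mse = np.mean((restored - sharp) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10(peak ** 2 / mse)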

 

Finally, let's look at the test code.

 

import time
import os
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from pdb import set_trace as st
from util import html
from util.metrics import PSNR
from ssim import SSIM
from PIL import Image

from multiprocessing import freeze_support

if __name__ == '__main__':
  freeze_support()
  
  opt = TestOptions().parse()
  opt.nThreads = 1   # test code only supports nThreads = 1
  opt.batchSize = 1  # test code only supports batchSize = 1
  opt.serial_batches = True  # no shuffle
  opt.no_flip = True  # no flip

  data_loader = CreateDataLoader(opt)
  dataset = data_loader.load_data()
  model = create_model(opt)
  visualizer = Visualizer(opt)
  # build the data_loader, model, and visualizer according to the options
  
# create website
  web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
  webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
# test
  avgPSNR = 0.0
  avgSSIM = 0.0
  counter = 0

  for i, data in enumerate(dataset):
     if i >= opt.how_many:
        break
     counter = i
     model.set_input(data)
     model.test()

     visuals = model.get_current_visuals()
     pilFake = Image.fromarray(visuals['fake_B'])
     pilReal = Image.fromarray(visuals['real_A'])
     # avgPSNR += PSNR(visuals['fake_B'],visuals['real_A'])
     # pilFake = Image.fromarray(visuals['fake_B'])
     # pilReal = Image.fromarray(visuals['real_B'])
     # avgSSIM += SSIM(pilFake).cw_ssim_value(pilReal)
     img_path = model.get_image_paths()
     print('process image... %s' % img_path)
     visualizer.save_images(webpage, visuals, img_path)

  #avgPSNR /= counter
  #avgSSIM /= counter
  #print('PSNR = %f, SSIM = %f' %
  #               (avgPSNR, avgSSIM))

  webpage.save()

 

The model, dataloader, and visualizer are created according to the options. The important point is that netG and netD must be configured exactly as they were during training; only then can the saved weights be loaded and the accuracy of the trained model measured.
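For example (a sketch reusing the settings from train.py above), the generator options at test time have to mirror the training run before create_model loads the checkpoint:

opt = TestOptions().parse()
# these must match the values used during training, or the saved
# weights will not fit the network that create_model builds
opt.which_model_netG = "unet_256"   # same generator architecture as in training
opt.learn_residual = True           # same residual setting as in training
model = create_model(opt)           # loads the checkpoint saved by train.py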