본문 바로가기

연구실 공부

Image deblurring using DeblurGAN(conditional_gan_model, losses.py)

728x90

https://github.com/KupynOrest/DeblurGAN

 

GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks

Image Deblurring using Generative Adversarial Networks - GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks

github.com

https://arxiv.org/abs/1711.07064

 

DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks

We present DeblurGAN, an end-to-end learned method for motion deblurring. The learning is based on a conditional GAN and the content loss . DeblurGAN achieves state-of-the art performance both in the structural similarity measure and visual appearance. The

arxiv.org

토대로 공부하고 작성했습니다.

 

저번 글에서는 base_model, models, test_model에 대해서 알아봤습니다. 학습이나 test 할 때 사용할 model을 생성하는 코드였습니다. model을 생성하면 조건에 맞게 image를 불러오거나 test를 진행하는 메서드 등이 존재합니다.

이번에는 conditional_gan_model, losses에 대해서 알아보겠습니다.

먼저 conditional_gan_model.py 코드를 보겠습니다.

 

import numpy as np
import torch
import os
from collections import OrderedDict
from torch.autograd import Variable
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from .losses import init_loss

try:
   xrange          # Python2
except NameError:
   xrange = range  # Python 3

class ConditionalGAN(BaseModel):
   def name(self):
      return 'ConditionalGANModel'

   def __init__(self, opt):
      super(ConditionalGAN, self).__init__(opt)
      self.isTrain = opt.isTrain
      # define tensors
      self.input_A = self.Tensor(opt.batchSize, opt.input_nc,  opt.fineSize, opt.fineSize)
      self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)

      # load/define networks
      # Temp Fix for nn.parallel as nn.parallel crashes oc calculating gradient penalty
      use_parallel = not opt.gan_type == 'wgan-gp'
      # use_parallel에 opt.gan_type에 따라 값이 들어갑니다
      print("Use Parallel = ", "True" if use_parallel else "False")
      self.netG = networks.define_G(
         opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm,
         not opt.no_dropout, self.gpu_ids, use_parallel, opt.learn_residual
      )
      # GAN 구조에서 generator의 network를 매개변수를 사용해 생성하는 코드

      if self.isTrain:
         use_sigmoid = opt.gan_type == 'gan'
         self.netD = networks.define_D(
            opt.output_nc, opt.ndf, opt.which_model_netD,
            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids, use_parallel
         )
         # isTrain이 true일 경우 GAN 구조에서 discriminator의 network를 매개변수를 사용해 생성하는 코드

      if not self.isTrain or opt.continue_train:
         self.load_network(self.netG, 'G', opt.which_epoch)
         if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch)
      # opt.continue_train과 istrain의 값에 따라 조건문에 들어가 generator의 network(학습결과)와 discriminator의 network(학습결과)를 불러온다

      if self.isTrain:
         self.fake_AB_pool = ImagePool(opt.pool_size)
         self.old_lr = opt.lr

         # initialize optimizers
         self.optimizer_G = torch.optim.Adam( self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999) )
         self.optimizer_D = torch.optim.Adam( self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999) )
         # generator와 discriminator의 최적화 함수를 adam으로 설정

         self.criticUpdates = 5 if opt.gan_type == 'wgan-gp' else 1
         
         # define loss functions
         self.discLoss, self.contentLoss = init_loss(opt, self.Tensor)
         # init_loss 메서드를 이용해 discLoss, contentLoss를 저장
         # gan에 따라 discLoss 달라지고 opt.model에 따라 contentLoss가 달라진다

      print('---------- Networks initialized -------------')
      networks.print_network(self.netG)
      if self.isTrain:
         networks.print_network(self.netD)
      print('-----------------------------------------------')

   def set_input(self, input):
      AtoB = self.opt.which_direction == 'AtoB'
      inputA = input['A' if AtoB else 'B']
      inputB = input['B' if AtoB else 'A']
      self.input_A.resize_(inputA.size()).copy_(inputA)
      self.input_B.resize_(inputB.size()).copy_(inputB)
      self.image_paths = input['A_paths' if AtoB else 'B_paths']
      # 폴더에서 이미지를 불러오는 메서드

   def forward(self):
      self.real_A = Variable(self.input_A)
      self.fake_B = self.netG.forward(self.real_A)
      # input_A는 원본이미지가 들어가고 generator의 forward를 진행해 fake_B 생성
      self.real_B = Variable(self.input_B)

   # no backprop gradients
   def test(self):
      self.real_A = Variable(self.input_A, volatile=True)
      self.fake_B = self.netG.forward(self.real_A)
      # test하기 위해 원본 이미지 input_A를 이용해 generator의 forward를 진행해 fake_B 생성
      self.real_B = Variable(self.input_B, volatile=True)

   # get image paths
   def get_image_paths(self):
      return self.image_paths

   def backward_D(self):
      self.loss_D = self.discLoss.get_loss(self.netD, self.real_A, self.fake_B, self.real_B)
      self.loss_D.backward(retain_graph=True)
      # discriminator의 backward 진행하는 메서드

   def backward_G(self):
      self.loss_G_GAN = self.discLoss.get_g_loss(self.netD, self.real_A, self.fake_B)
      # Second, G(A) = B
      self.loss_G_Content = self.contentLoss.get_loss(self.fake_B, self.real_B) * self.opt.lambda_A

      self.loss_G = self.loss_G_GAN + self.loss_G_Content
      # self.loss_G = self.loss_G_Content
      # generator의 오차는 gan을 이용한 오차와 content의 오차를 더한 값을 사용한다.
      self.loss_G.backward()
      # generator의 backward 잰행하는 메서드

   def optimize_parameters(self):
      self.forward()
      # forward 진행 후 변수들 최적화 진행

      for iter_d in xrange(self.criticUpdates):
         self.optimizer_D.zero_grad()
         self.backward_D()
         self.optimizer_D.step()

      self.optimizer_G.zero_grad()
      self.backward_G()
      self.optimizer_G.step()
      # discriminator와 generator 최적화 진행
   
   def get_current_errors(self):
      return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                     ('G_L1', self.loss_G_Content.item()),
                     ('D_real+fake', self.loss_D.item())
                     ])

   def get_current_visuals(self):
      real_A = util.tensor2im(self.real_A.data)
      fake_B = util.tensor2im(self.fake_B.data)
      real_B = util.tensor2im(self.real_B.data)
      return OrderedDict([('Blurred_Train', real_A), ('Restored_Train', fake_B), ('Sharp_Train', real_B)])
      # tensor를 image로 변환하고 순서대로 dictionary 형태로 return

   def save(self, label):
      self.save_network(self.netG, 'G', label, self.gpu_ids)
      self.save_network(self.netD, 'D', label, self.gpu_ids)
      # 학습 결과 저장

   def update_learning_rate(self):
      lrd = self.opt.lr / self.opt.niter_decay
      lr = self.old_lr - lrd
      for param_group in self.optimizer_D.param_groups:
         param_group['lr'] = lr
      for param_group in self.optimizer_G.param_groups:
         param_group['lr'] = lr
      print('update learning rate: %f -> %f' % (self.old_lr, lr))
      self.old_lr = lr

 

생성자 메서드에 generator와 discriminator의 network를 생성하고 최적화 함수도 adam으로 설정합니다. 또한 생성자 메서드 마지막에 loss도 설정하는데 이는 gan의 종류에 따라, opt.model(pix2pix, content_gan)에 따라 contentloss와 discloss의 값이 정해집니다. forward, discriminator backward, generator backward 등 여러 기능을 가진 메서드들도 선언했습니다.

 ※ content_gan인 경우 PerceptualLoss 클래스로 넘어가 contentloss의 값을 생성하고, pix2pix인 경우, L1loss를 이용해 오차를 구합니다.

이번에는 losses.py를 살펴보겠습니다.

 

import torch
import torch.nn as nn
from torch.nn import init
import functools
import torch.autograd as autograd
import numpy as np
import torchvision.models as models
import util.util as util
from util.image_pool import ImagePool
from torch.autograd import Variable
###############################################################################
# Functions
###############################################################################

class ContentLoss:
   def __init__(self, loss):
      self.criterion = loss
         
   def get_loss(self, fakeIm, realIm):
      return self.criterion(fakeIm, realIm)
   # opt.model이 pix2pix일 때 ContentLoss 클래스 진행

class PerceptualLoss():
   # opt.model이 contett_gan일 때 PerceptualLoss 진행

   def contentFunc(self):
      conv_3_3_layer = 14
      cnn = models.vgg19(pretrained=True).features
      # vgg19는 19개의 layer를 갖고 3x3 크기의 필터를 모든 conv 레이어에 사용하는 network
      # cnn에는 vgg19의 network의 feature를 저장합니다.

      cnn = cnn.cuda()
      model = nn.Sequential()
      model = model.cuda()
      for i,layer in enumerate(list(cnn)):
         model.add_module(str(i),layer)
         if i == conv_3_3_layer:
            break
      return model
      # vgg19 중 14층 까지만 model에 저장(conv2d(256, 256, kernel_sizse=(3,3), stride = (1,1), padding = (1,1)))
      
   def __init__(self, loss):
      self.criterion = loss
      self.contentFunc = self.contentFunc()
         
   def get_loss(self, fakeIm, realIm):
      f_fake = self.contentFunc.forward(fakeIm)
      f_real = self.contentFunc.forward(realIm)
      f_real_no_grad = f_real.detach()
      loss = self.criterion(f_fake, f_real_no_grad)
      return loss
      # fakeIm, realIm을 이용해 오차를 구한다. 오차 함수는 인자로 받은 L1loss 또는 MSELoss
      
class GANLoss(nn.Module):
   def __init__(
         self, use_l1=True, target_real_label=1.0,
         target_fake_label=0.0, tensor=torch.FloatTensor):
      super(GANLoss, self).__init__()
      self.real_label = target_real_label
      self.fake_label = target_fake_label
      self.real_label_var = None
      self.fake_label_var = None
      self.Tensor = tensor
      if use_l1:
         self.loss = nn.L1Loss()
         # Mean Absolute Error
      else:
         self.loss = nn.BCELoss()
         # Binary Cross Entropy Loss

   def get_target_tensor(self, input, target_is_real):
      target_tensor = None
      if target_is_real:
         create_label = ((self.real_label_var is None) or
                     (self.real_label_var.numel() != input.numel()))
         if create_label:
            real_tensor = self.Tensor(input.size()).fill_(self.real_label)
            # input.size만큼 tensor를 생성하고 real_label의 값으로 채운다
            self.real_label_var = Variable(real_tensor, requires_grad=False)
         target_tensor = self.real_label_var
      else:
         create_label = ((self.fake_label_var is None) or
                     (self.fake_label_var.numel() != input.numel()))
         if create_label:
            fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
            self.fake_label_var = Variable(fake_tensor, requires_grad=False)
         target_tensor = self.fake_label_var
      return target_tensor

   def __call__(self, input, target_is_real):
      target_tensor = self.get_target_tensor(input, target_is_real)
      return self.loss(input, target_tensor)
      # 위에서 정의한 loss에 input과 target_tensor를 넣어 오차를 return

class DiscLoss:
   def name(self):
      return 'DiscLoss'

   def __init__(self, opt, tensor):
      self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)
      # use_l1 = False이면 GANLoss에서 BCELoss 사용
      self.fake_AB_pool = ImagePool(opt.pool_size)
      
   def get_g_loss(self,net, realA, fakeB):
      # First, G(A) should fake the discriminator
      pred_fake = net.forward(fakeB)
      return self.criterionGAN(pred_fake, 1)
      # discriminator의 network를 이용해 fakeB에 대한 예측을 구한다
      # fakeB로 discriminator가 predict 했으니 0이 나와여 좋고 1이랑 비교.
      
   def get_loss(self, net, realA, fakeB, realB):
      # Fake
      # stop backprop to the generator by detaching fake_B
      # Generated Image Disc Output should be close to zero
      self.pred_fake = net.forward(fakeB.detach())
      self.loss_D_fake = self.criterionGAN(self.pred_fake, 0)
      # fakeB라는 가짜 이미지를 discriminator에 넣어 0과 비교해 오차를 구한다.
      # 가짜 이미지가 들어간 경우여서 0으로 predict하는 것이 좋다.

      # Real
      self.pred_real = net.forward(realB)
      self.loss_D_real = self.criterionGAN(self.pred_real, 1)
      # realB라는 real image를 가지고 discriminator가 예측을 진행
      # 예측한 값과 1을 가지고 오차를 구한다. 1에 가까울 수록 판단을 잘 한 disciriminator

      # Combined loss
      self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
      return self.loss_D
      # 가짜 이미지와 진짜 이미지에 대한 오차의 평균을 return
      
class DiscLossLS(DiscLoss):
   def name(self):
      return 'DiscLossLS'

   def __init__(self, opt, tensor):
      super(DiscLoss, self).__init__(opt, tensor)
      # DiscLoss.initialize(self, opt, tensor)
      self.criterionGAN = GANLoss(use_l1=True, tensor=tensor)
      
   def get_g_loss(self,net, realA, fakeB):
      return DiscLoss.get_g_loss(self,net, realA, fakeB)
      # 위에서 정의한 get_g_loss를 사용
      
   def get_loss(self, net, realA, fakeB, realB):
      return DiscLoss.get_loss(self, net, realA, fakeB, realB)
      # 위에서 정의한 get_loss를 사용
      
class DiscLossWGANGP(DiscLossLS):
   # 논문에서는 WGAN-GP를 사용

   def name(self):
      return 'DiscLossWGAN-GP'

   def __init__(self, opt, tensor):
      super(DiscLossWGANGP, self).__init__(opt, tensor)
      # DiscLossLS.initialize(self, opt, tensor)
      self.LAMBDA = 10
      
   def get_g_loss(self, net, realA, fakeB):
      # First, G(A) should fake the discriminator
      self.D_fake = net.forward(fakeB)
      return -self.D_fake.mean()
      # fakeB를 가지고 discriminator의 forward를 진행하고 나온 값의 평균에 -1을 곱해 return
      
   def calc_gradient_penalty(self, netD, real_data, fake_data):
      alpha = torch.rand(1, 1)
      # random한 값 하나를 alpha에 저장
      alpha = alpha.expand(real_data.size())
      # random하게 받은 값을 real_data의 크기만큼 확장
      alpha = alpha.cuda()

      interpolates = alpha * real_data + ((1 - alpha) * fake_data)

      interpolates = interpolates.cuda()
      interpolates = Variable(interpolates, requires_grad=True)
      
      disc_interpolates = netD.forward(interpolates)

      gradients = autograd.grad(
         outputs=disc_interpolates, inputs=interpolates, grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
         create_graph=True, retain_graph=True, only_inputs=True
      )[0]
      # disc_interpolates에 미분 결과 저장, interpolates는 미분될 내용
      gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
      return gradient_penalty
      
   def get_loss(self, net, realA, fakeB, realB):
      self.D_fake = net.forward(fakeB.detach())
      self.D_fake = self.D_fake.mean()
      
      # Real
      self.D_real = net.forward(realB)
      self.D_real = self.D_real.mean()
      # Combined loss
      self.loss_D = self.D_fake - self.D_real
      gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data)
      return self.loss_D + gradient_penalty


def init_loss(opt, tensor):
   # disc_loss = None
   # content_loss = None
   
   if opt.model == 'content_gan':
      content_loss = PerceptualLoss(nn.MSELoss())
      # content_loss.initialize(nn.MSELoss())
   elif opt.model == 'pix2pix':
      content_loss = ContentLoss(nn.L1Loss())
      # content_loss.initialize(nn.L1Loss())
   else:
      raise ValueError("Model [%s] not recognized." % opt.model)
   
   if opt.gan_type == 'wgan-gp':
      disc_loss = DiscLossWGANGP(opt, tensor)
   elif opt.gan_type == 'lsgan':
      disc_loss = DiscLossLS(opt, tensor)
   elif opt.gan_type == 'gan':
      disc_loss = DiscLoss(opt, tensor)
   else:
      raise ValueError("GAN [%s] not recognized." % opt.gan_type)
   # disc_loss.initialize(opt, tensor)
   return disc_loss, content_loss
   # opt.model과 opt.gan_type에 따라 disc_loss와 content_loss를 정의해 return한다

 

이와 같이 loss를 정의하는 코드입니다. opt.model과 opt.gan_type에 따라 loss가 다르게 정의되며 논문에서는 wgan_gp를 사용했으므로 DiscLossWGANGP 클래스를 사용해 disc_loss를 정의합니다.