본문 바로가기

연구실 공부

Image deblurring using DeblurGAN(base_model.py, models.py, test_models.py)

728x90

https://github.com/KupynOrest/DeblurGAN

 

GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks

Image Deblurring using Generative Adversarial Networks - GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks

github.com

https://arxiv.org/abs/1711.07064

 

DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks

We present DeblurGAN, an end-to-end learned method for motion deblurring. The learning is based on a conditional GAN and the content loss . DeblurGAN achieves state-of-the art performance both in the structural similarity measure and visual appearance. The

arxiv.org

토대로 공부하고 작성했습니다.

 

저번 글에서 data 폴더에 있는 코드들을 살펴봤습니다. 해당 폴더에는 data(image)를 불러오는 클래스들을 options에 맞춰서 생성하고 이용해 data를 불러왔습니다.

이번에는 models 폴더에 존재하는 코드들 중 base_model.py, models.py, test_model.py에 대해서 살펴보겠습니다.

먼저 base_model.py를 보겠습니다.

 

import os
import torch


class BaseModel():
    def name(self):
        return 'BaseModel'

    def __init__(self, opt):
        # options에 맞춰서 각 변수들 설정
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        # 학습 결과를 저장하는 메서드
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(device=gpu_ids[0])
        

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        # 저장된 학습 결과를 불러오는 메서드
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate():
        pass

 

위 코드는 간단하게 BaseModel을 option에 맞춰서 생성합니다. 학습 결과를 저장하고 불러오는 메서드도 존재합니다.

그다음 model.py를 보면

 

from .conditional_gan_model import ConditionalGAN

def create_model(opt):
   model = None
   if opt.model == 'test':
      # TEST#assert (opt.dataset_mode == 'single')
      from .test_model import TestModel
      model = TestModel( opt )
   else:
      model = ConditionalGAN(opt)
   # model.initialize(opt)
   print("model [%s] was created" % (model.name()))
   return model

 

이와 같이 코드가 작성되어 있습니다. 주석 처리된 asssert (opt.dataset_model == 'single')은 opt.dataset_mode에서 single이 test를 위해 만들어진 부분이고, opt.model이 test일 때 single을 사용하지 않으면 에러를 출력해주는 코드입니다. opt.model이 test인지 아닌지에 따라 model의 생성 방식이 TestModel과 ConditionalGAN으로 나눠집니다.

마지막으로 test_model.py를 보겠습니다.

 

import torch
from torch.autograd import Variable
from collections import OrderedDict
import util.util as util
from .base_model import BaseModel
from . import networks


class TestModel(BaseModel):
    def name(self):
        return 'TestModel'

    def __init__(self, opt):
        assert(not opt.isTrain)
        super(TestModel, self).__init__(opt)
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)

        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, False,
                                      opt.learn_residual)
        # networks.define_G를 통해 netG를 생성하는데, GAN에서 generator의 network를 매개변수 값들을 통해 생성하는 코드입니다.
        which_epoch = opt.which_epoch
        self.load_network(self.netG, 'G', which_epoch)
        # 학습한 결과를 불러오는 코드

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        print('-----------------------------------------------')
 
    def set_input(self, input):
        # we need to use single_dataset mode
        input_A = input['A']
        temp = self.input_A.clone()
        temp.resize_(input_A.size()).copy_(input_A)
        self.input_A = temp
        self.image_paths = input['A_paths']

    

    def test(self):
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG.forward(self.real_A)
            # 자동학습 off, real_A로 netG의 forward를 진행

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        # tesor2im는 인자로 받은 tensor를 image형태로 바꿔주는 메서드
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])

 

test_model.py는 test할 때 사용하는 코드로 assert(not opt.isTrain)을 통해 train이 아닌 것을 확인해줍니다. GAN 구조에서 generator의 network를 생성하고 해당 학습 결과를 불러오는 코드입니다. test 메서드는 input_A를 사용해 generator의 forward를 진행해 가짜 이미지 fake_B를 생성합니다.