https://github.com/KupynOrest/DeblurGAN
https://arxiv.org/abs/1711.07064
토대로 공부하고 작성했습니다.
저번 글에서 data 폴더에 있는 코드들을 살펴봤습니다. 해당 폴더에는 data(image)를 불러오는 클래스들을 options에 맞춰서 생성하고 이용해 data를 불러왔습니다.
이번에는 models 폴더에 존재하는 코드들 중 base_model.py, models.py, test_model.py에 대해서 살펴보겠습니다.
먼저 base_model.py를 보겠습니다.
import os
import torch
class BaseModel():
def name(self):
return 'BaseModel'
def __init__(self, opt):
# options에 맞춰서 각 변수들 설정
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
def set_input(self, input):
self.input = input
def forward(self):
pass
# used in test time, no backprop
def test(self):
pass
def get_image_paths(self):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
return self.input
def get_current_errors(self):
return {}
def save(self, label):
pass
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
# 학습 결과를 저장하는 메서드
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
network.cuda(device=gpu_ids[0])
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
# 저장된 학습 결과를 불러오는 메서드
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
network.load_state_dict(torch.load(save_path))
def update_learning_rate():
pass
위 코드는 간단하게 BaseModel을 option에 맞춰서 생성합니다. 학습 결과를 저장하고 불러오는 메서드도 존재합니다.
그다음 model.py를 보면
from .conditional_gan_model import ConditionalGAN
def create_model(opt):
model = None
if opt.model == 'test':
# TEST#assert (opt.dataset_mode == 'single')
from .test_model import TestModel
model = TestModel( opt )
else:
model = ConditionalGAN(opt)
# model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model
이와 같이 코드가 작성되어 있습니다. 주석 처리된 asssert (opt.dataset_model == 'single')은 opt.dataset_mode에서 single이 test를 위해 만들어진 부분이고, opt.model이 test일 때 single을 사용하지 않으면 에러를 출력해주는 코드입니다. opt.model이 test인지 아닌지에 따라 model의 생성 방식이 TestModel과 ConditionalGAN으로 나눠집니다.
마지막으로 test_model.py를 보겠습니다.
import torch
from torch.autograd import Variable
from collections import OrderedDict
import util.util as util
from .base_model import BaseModel
from . import networks
class TestModel(BaseModel):
def name(self):
return 'TestModel'
def __init__(self, opt):
assert(not opt.isTrain)
super(TestModel, self).__init__(opt)
self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, False,
opt.learn_residual)
# networks.define_G를 통해 netG를 생성하는데, GAN에서 generator의 network를 매개변수 값들을 통해 생성하는 코드입니다.
which_epoch = opt.which_epoch
self.load_network(self.netG, 'G', which_epoch)
# 학습한 결과를 불러오는 코드
print('---------- Networks initialized -------------')
networks.print_network(self.netG)
print('-----------------------------------------------')
def set_input(self, input):
# we need to use single_dataset mode
input_A = input['A']
temp = self.input_A.clone()
temp.resize_(input_A.size()).copy_(input_A)
self.input_A = temp
self.image_paths = input['A_paths']
def test(self):
with torch.no_grad():
self.real_A = Variable(self.input_A)
self.fake_B = self.netG.forward(self.real_A)
# 자동학습 off, real_A로 netG의 forward를 진행
# get image paths
def get_image_paths(self):
return self.image_paths
def get_current_visuals(self):
real_A = util.tensor2im(self.real_A.data)
fake_B = util.tensor2im(self.fake_B.data)
# tesor2im는 인자로 받은 tensor를 image형태로 바꿔주는 메서드
return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
test_model.py는 test할 때 사용하는 코드로 assert(not opt.isTrain)을 통해 train이 아닌 것을 확인해줍니다. GAN 구조에서 generator의 network를 생성하고 해당 학습 결과를 불러오는 코드입니다. test 메서드는 input_A를 사용해 generator의 forward를 진행해 가짜 이미지 fake_B를 생성합니다.
'연구실 공부' 카테고리의 다른 글
Image deblurring using DeblurGAN(network, train, test) (0) | 2022.03.10 |
---|---|
Image deblurring using DeblurGAN(conditional_gan_model, losses.py) (0) | 2022.03.09 |
Image deblurring using DeblurGAN(data) (0) | 2022.03.09 |
Image deblurring using DeblurGAN(options, datasets) (0) | 2022.03.09 |
Image deblurring using DeblurGAN (0) | 2022.03.07 |