https://github.com/KupynOrest/DeblurGAN
GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks
Image Deblurring using Generative Adversarial Networks - GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks
github.com
https://arxiv.org/abs/1711.07064
DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks
We present DeblurGAN, an end-to-end learned method for motion deblurring. The learning is based on a conditional GAN and the content loss . DeblurGAN achieves state-of-the art performance both in the structural similarity measure and visual appearance. The
arxiv.org
토대로 공부하고 작성했습니다.
저번 글에서 train과 test까지 모든 코드를 살펴봤습니다. 이번에는 image 데이터를 직접 학습시켜 실행 결과를 확인해보겠습니다. 그리고 unet을 새로 만들어 적용해 학습 결과도 살펴보겠습니다.
학습을 위해 train과 test에서 파일을 불러오는 주소나 다른 조건들을 변경하고 진행하겠습니다. 이미지는 200장만 사용하겠습니다. 저는 gaussian blur 처리된 이미지를 통해 학습을 시킨 후 원본 이미지와 generator가 생성한 이미지의 PSNR을 측정하는 방식으로 정확도를 측정했습니다.
netG를 resnset_9blocks으로 설정하고 content loss는 L2 loss를 사용해 학습을 한 결과
이와 같이 PSNR을 얻을 수 있습니다. 이를 토대로 test 한 결과
이와 같은 PSNR을 얻습니다. 이미지 데이터셋이 200장밖에 사용하지 않았고 gaussian blur 처리가 너무 심하게 되어 생각보다 낮은 정확도를 보입니다. GAN 모델의 종류와 content loss의 종류에 따라 PSNR의 결과도 다르게 나옵니다.
이번에는 직접 간단한 unet을 생성해 적용해보겠습니다.
def UpConv(input_nc, output_nc, func):
model = nn.Sequential(
nn.ConvTranspose2d(input_nc, output_nc, kernel_size = 3, stride = 2, padding=1, output_padding = 1),
nn.BatchNorm2d(output_nc),
func,
)
return model
def Maxpool():
pool = nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0)
return pool
def Conv(input_nc, output_nc, func):
model = nn.Sequential(
nn.Conv2d(input_nc, output_nc, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(output_nc),
func,
nn.Conv2d(output_nc, output_nc, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(output_nc),
func,
)
return model
class UNet(nn.Module):
def __init__(self, input_nc, output_nc, num_filter):
super(UNet, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.num_filter = num_filter
func = nn.LeakyReLU(0.2, inplace=True)
self.down_1 = Conv(self.input_nc, self.num_filter, func)
self.pool_1 = Maxpool()
self.down_2 = Conv(self.num_filter, self.num_filter * 2, func)
self.pool_2 = Maxpool()
self.down_3 = Conv(self.num_filter * 2, self.num_filter * 4, func)
self.pool_3 = Maxpool()
self.down_4 = Conv(self.num_filter * 4, self.num_filter * 8, func)
self.pool_4 = Maxpool()
self.bridge = Conv(self.num_filter * 8, self.num_filter * 16, func)
self.trans_1 = UpConv(self.num_filter * 16, self.num_filter * 8, func)
self.up_1 = Conv(self.num_filter * 16, self.num_filter * 8, func)
self.trans_2 = UpConv(self.num_filter * 8, self.num_filter * 4, func)
self.up_2 = Conv(self.num_filter * 8, self.num_filter * 4, func)
self.trans_3 = UpConv(self.num_filter * 4, self.num_filter * 2, func)
self.up_3 = Conv(self.num_filter * 4, self.num_filter * 2, func)
self.trans_4 = UpConv(self.num_filter * 2, self.num_filter, func)
self.up_4 = Conv(self.num_filter * 2, self.num_filter, func)
self.out = nn.Sequential(
nn.Conv2d(self.num_filter, self.output_nc, kernel_size = 3, stride = 1, padding = 1),
func,
)
def forward(self, input):
down_1 = self.down_1(input)
pool_1 = self.pool_1(down_1)
down_2 = self.down_2(pool_1)
pool_2 = self.pool_2(down_2)
down_3 = self.down_3(pool_2)
pool_3 = self.pool_3(down_3)
down_4 = self.down_4(pool_3)
pool_4 = self.pool_4(down_4)
bridge = self.bridge(pool_4)
trans_1 = self.trans_1(bridge)
concat_1 = torch.cat([trans_1, down_4], dim=1)
up_1 = self.up_1(concat_1)
trans_2 = self.trans_2(up_1)
concat_2 = torch.cat([trans_2, down_3], dim=1)
up_2 = self.up_2(concat_2)
trans_3 = self.trans_3(up_2)
concat_3 = torch.cat([trans_3, down_2], dim=1)
up_3 = self.up_3(concat_3)
trans_4 = self.trans_4(up_3)
concat_4 = torch.cat([trans_4, down_1], dim=1)
up_4 = self.up_4(concat_4)
out = self.out(up_4)
return out
이와 같이 기본 unet을 작성했습니다. 이 코드를 network.py에 올바른 위치에 추가해주고 조건문에도 추가를 해주면 올바르게 학습할 수 있습니다.
'연구실 공부' 카테고리의 다른 글
[논문]Improving neural networks by preventing co-adaptation of feature detectors (0) | 2022.03.21 |
---|---|
정규화(Normalization) (0) | 2022.03.14 |
Image deblurring using DeblurGAN(network, train, test) (0) | 2022.03.10 |
Image deblurring using DeblurGAN(conditional_gan_model, losses.py) (0) | 2022.03.09 |
Image deblurring using DeblurGAN(base_model.py, models.py, test_models.py) (0) | 2022.03.09 |