본문 바로가기

연구실 공부

Image deblurring using DeblurGAN(실행 결과, 새로운 unet 작성해보기)

728x90

https://github.com/KupynOrest/DeblurGAN

 

GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks

Image Deblurring using Generative Adversarial Networks - GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks

github.com

https://arxiv.org/abs/1711.07064

 

DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks

We present DeblurGAN, an end-to-end learned method for motion deblurring. The learning is based on a conditional GAN and the content loss . DeblurGAN achieves state-of-the art performance both in the structural similarity measure and visual appearance. The

arxiv.org

토대로 공부하고 작성했습니다.

 

저번 글에서 train과 test까지 모든 코드를 살펴봤습니다. 이번에는 image 데이터를 직접 학습시켜 실행 결과를 확인해보겠습니다. 그리고 unet을 새로 만들어 적용해 학습 결과도 살펴보겠습니다.

학습을 위해 train과 test에서 파일을 불러오는 주소나 다른 조건들을 변경하고 진행하겠습니다. 이미지는 200장만 사용하겠습니다. 저는 gaussian blur 처리된 이미지를 통해 학습을 시킨 후 원본 이미지와 generator가 생성한 이미지의 PSNR을 측정하는 방식으로 정확도를 측정했습니다.

 

netG를 resnset_9blocks으로 설정하고 content loss는 L2 loss를 사용해 학습을 한 결과

 

이와 같이 PSNR을 얻을 수 있습니다. 이를 토대로 test 한 결과

 

이와 같은 PSNR을 얻습니다. 이미지 데이터셋이 200장밖에 사용하지 않았고 gaussian blur 처리가 너무 심하게 되어 생각보다 낮은 정확도를 보입니다. GAN 모델의 종류와 content loss의 종류에 따라 PSNR의 결과도 다르게 나옵니다.

 

이번에는 직접 간단한 unet을 생성해 적용해보겠습니다.

 

def UpConv(input_nc, output_nc, func):
    model = nn.Sequential(
        nn.ConvTranspose2d(input_nc, output_nc, kernel_size = 3, stride = 2, padding=1, output_padding = 1),
        nn.BatchNorm2d(output_nc),
        func,
    )
    return model

def Maxpool():
    pool = nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0)
    return pool

def Conv(input_nc, output_nc, func):
    model = nn.Sequential(
        nn.Conv2d(input_nc, output_nc, kernel_size = 3, stride = 1, padding = 1),
        nn.BatchNorm2d(output_nc),
        func,
        nn.Conv2d(output_nc, output_nc, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(output_nc),
        func,
    )
    return model

class UNet(nn.Module):

    def __init__(self, input_nc, output_nc, num_filter):
        super(UNet, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.num_filter = num_filter
        func = nn.LeakyReLU(0.2, inplace=True)

        self.down_1 = Conv(self.input_nc, self.num_filter, func)
        self.pool_1 = Maxpool()

        self.down_2 = Conv(self.num_filter, self.num_filter * 2, func)
        self.pool_2 = Maxpool()

        self.down_3 = Conv(self.num_filter * 2, self.num_filter * 4, func)
        self.pool_3 = Maxpool()

        self.down_4 = Conv(self.num_filter * 4, self.num_filter * 8, func)
        self.pool_4 = Maxpool()

        self.bridge = Conv(self.num_filter * 8, self.num_filter * 16, func)

        self.trans_1 = UpConv(self.num_filter * 16, self.num_filter * 8, func)
        self.up_1 = Conv(self.num_filter * 16, self.num_filter * 8, func)

        self.trans_2 = UpConv(self.num_filter * 8, self.num_filter * 4, func)
        self.up_2 = Conv(self.num_filter * 8, self.num_filter * 4, func)

        self.trans_3 = UpConv(self.num_filter * 4, self.num_filter * 2, func)
        self.up_3 = Conv(self.num_filter * 4, self.num_filter * 2, func)

        self.trans_4 = UpConv(self.num_filter * 2, self.num_filter, func)
        self.up_4 = Conv(self.num_filter * 2, self.num_filter, func)

        self.out = nn.Sequential(
            nn.Conv2d(self.num_filter, self.output_nc, kernel_size = 3, stride = 1, padding = 1),
            func,
        )

    def forward(self, input):
        down_1 = self.down_1(input)
        pool_1 = self.pool_1(down_1)

        down_2 = self.down_2(pool_1)
        pool_2 = self.pool_2(down_2)

        down_3 = self.down_3(pool_2)
        pool_3 = self.pool_3(down_3)

        down_4 = self.down_4(pool_3)
        pool_4 = self.pool_4(down_4)

        bridge = self.bridge(pool_4)

        trans_1 = self.trans_1(bridge)
        concat_1 = torch.cat([trans_1, down_4], dim=1)
        up_1 = self.up_1(concat_1)

        trans_2 = self.trans_2(up_1)
        concat_2 = torch.cat([trans_2, down_3], dim=1)
        up_2 = self.up_2(concat_2)

        trans_3 = self.trans_3(up_2)
        concat_3 = torch.cat([trans_3, down_2], dim=1)
        up_3 = self.up_3(concat_3)

        trans_4 = self.trans_4(up_3)
        concat_4 = torch.cat([trans_4, down_1], dim=1)
        up_4 = self.up_4(concat_4)

        out = self.out(up_4)
        return out

 

이와 같이 기본 unet을 작성했습니다. 이 코드를 network.py에 올바른 위치에 추가해주고 조건문에도 추가를 해주면 올바르게 학습할 수 있습니다.