본문 바로가기

연구실 공부

Image deblurring using DeblurGAN(data)

728x90

https://github.com/KupynOrest/DeblurGAN

 

GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks

Image Deblurring using Generative Adversarial Networks - GitHub - KupynOrest/DeblurGAN: Image Deblurring using Generative Adversarial Networks

github.com

https://arxiv.org/abs/1711.07064

 

DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks

We present DeblurGAN, an end-to-end learned method for motion deblurring. The learning is based on a conditional GAN and the content loss . DeblurGAN achieves state-of-the art performance both in the structural similarity measure and visual appearance. The

arxiv.org

토대로 공부하고 작성했습니다.

 

저번 글에서 datasets의 combine_A_and_B.py와 options에 대해서 살펴봤습니다. 이번에는 data 폴더에 있는 코드들을 살펴보겠습니다.

 

먼저 data_loader.py를 보면

 

def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader(opt)
    print(data_loader.name())
    # data_loader.initialize(opt)
    return data_loader

 

이와 같이 작성되어 있습니다. 해당 코드에 존재하는 CustomDatasetDataLoader 클래스에 opt라는 options 정보(base_options, test_option 또는 train_options 등 사용자가 정의한 option)가 들어갑니다. 그리고 CustomDatasetDataLoader 클래스는 custom_dataset_data_loader.py에 존재합니다.

custom_dataset_data_loader.py를 살펴보면

 

import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    dataset = None
    if opt.dataset_mode == 'aligned': 
        from data.aligned_dataset import AlignedDataset
        dataset = AlignedDataset(opt)
    elif opt.dataset_mode == 'unaligned':
        from data.unaligned_dataset import UnalignedDataset
        dataset = UnalignedDataset()
    elif opt.dataset_mode == 'single':
        from data.single_dataset import SingleDataset
        dataset = SingleDataset()
        dataset.initialize(opt)
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    # opt.dataset_name에 따라 dataset 생성, opt.dataset_name이 올바르지 않을 경우 error 출력

    print("dataset [%s] was created" % (dataset.name()))
    # dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def __init__(self, opt):
        super(CustomDatasetDataLoader,self).initialize(opt)
        print("Opt.nThreads = ", opt.nThreads)
        self.dataset = CreateDataset(opt) # 위에 존재하는 class 생성
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads)
        )
        # options에 맞는 설정으로 data 불러온다

    def load_data(self):
        return self.dataloader

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

 

이와 같이 작성되어 있습니다. CumstomDatasetDataLoader 클래스는 CreateDataset 클래스를 이용해 options에 맞춰 dataset을 생성합니다. 그다음 opt에 맞춰서 데이터를 불러옵니다. opt.dataset.mode가 aligned, unaligned, single 세 가지가 존재합니다. aligned는 combine_A_and_B를 통해 blur 처리된 이미지와 원본 이미지를 붙인 데이터를 불러옵니다. unaligned는 train 할 때 A와 B가 다른 폴더에 있는 경우 사용됩니다. 마지막 single은 test를 위해 존재하는데 이는 blur 처리된 이미지 A만 사용할 경우 사용됩니다.

마지막으로 base_dataset.py 코드를 보겠습니다. 이 코드는 single_dataset.py와 unaligned_dataset.py에서 호출됩니다.

 

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    # def initialize(self, opt):
    #     pass

def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSizeX, opt.loadSizeY]
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        # BICUBIC interpolation(다항식 보간법)을 사용해 resize 적용 후 transform_list에 붙인다
        transform_list.append(transforms.RandomCrop(opt.fineSize))
        # opt.fineSize만큼 랜덤으로 잘라 transforms_list에 붙이는 코드
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSizeX)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    # opt.resize_or_crop 에 따라 각 조건문 진행
    # opt.resize_or_crop에 맞춰 transform_list에 내용 저장

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def __scale_width(img, target_width):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)

 

위 코드를 이용해 single, unaligned 상황에 맞게 데이터를 불러옵니다.