본문 바로가기

연구실 공부

UNet++ 코드

728x90

저번에는 UNet++ 논문을 보고 공부를 했습니다. 이번에는 논문 내용을 토대로 UNet++를 구현하고 실행해보겠습니다.

논문에서는 총 4가지의 데이터 셋(cell nuclei, colon polyp, liver, lung nodule)으로 실험을 진행했습니다. 논문을 바탕으로 작성한 UNet++ model을 GAN 구조에서 generator의 network로 사용해보겠습니다.

먼저 UNet++ model을 구현해보겠습니다.

 

class conv_block_nested(nn.Module):

    def __init__(self, in_ch, mid_ch, out_ch):
        super(conv_block_nested, self).__init__()
        self.activation = nn.LeakyReLU(negative_slope = 0.15, inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        output = self.activation(x)

        return output

class Nested_UNet(nn.Module):

    def __init__(self, in_ch=3, out_ch=3):
        super(Nested_UNet, self).__init__()

        n1 = 32
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16, n1 * 32]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        # 컬러 이미지가 들어가 in_ch = 3, 첫 convolution을 통해 3 -> 32
        
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])
        self.conv5_0 = conv_block_nested(filters[4], filters[5], filters[5])

        # 여기서부터 UNet++에서 dense shortcut 구현
        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])
        self.conv4_1 = conv_block_nested(filters[4] + filters[5], filters[4], filters[4])

        self.conv0_2 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1])
        self.conv2_2 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2])
        self.conv3_2 = conv_block_nested(filters[3] * 2 + filters[4], filters[3], filters[3])

        self.conv0_3 = conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1])
        self.conv2_3 = conv_block_nested(filters[2] * 3 + filters[3], filters[2], filters[2])

        self.conv0_4 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0])
        self.conv1_4 = conv_block_nested(filters[1] * 4 + filters[2], filters[1], filters[1])

        self.conv0_5 = conv_block_nested(filters[0] * 5 + filters[1], filters[0], filters[0])

        self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)

    def forward(self, x):

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))

        x5_0 = self.conv5_0(self.pool(x4_0))
        x4_1 = self.conv4_1(torch.cat([x4_0, self.Up(x5_0)], 1))
        x3_2 = self.conv3_2(torch.cat([x3_0, x3_1, self.Up(x4_1)], 1))
        x2_3 = self.conv2_3(torch.cat([x2_0, x2_1, x2_2, self.Up(x3_2)], 1))
        x1_4 = self.conv1_4(torch.cat([x1_0, x1_1, x1_2, x1_3, self.Up(x2_3)], 1))
        x0_5 = self.conv0_5(torch.cat([x0_0, x0_1, x0_2, x0_3, x0_4, self.Up(x1_4)], 1))
        # x0_0부터 x0_4 사이의 값들과 x1_4 의 값을 붙여 conv진행해 x0_5 얻음

        output = self.final(x0_5)
        return output

 

이와 같이 UNet++구조를 구현할 수 있습니다. 컬러 이미지를 가지고 model을 진행할 것이고 결과값도 컬러 이미지로 만들기 때문에 in_ch와 out_ch 둘 다 3으로 정의를 했습니다. UNet++는 dense short cut 구조를 가지고 있기 때문에 각 층에서 전에 존재했던 모든 결과와 바로 아래층에서 받은 입력을 붙여 convolution을 진행합니다. 예를 들어 마지막 x0_5를 구하는 코드를 보면 x0_0에서 x0_4의 값들을 붙이고 x1_4의 값을 붙여서 convolution을 진행해 마지막 final에 들어가는 값을 구합니다.

이제 이 model을 사용해 전에 구현했던 GAN 구조에 적용해 이미지를 생성해보겠습니다.

 

왼쪽과 같이 blur 처리된 이미지를 통해 어느정도 복원된 이미지를 얻은 모습을 볼 수 있습니다. 이와 같이 UNet++ model을 이용하는 모습을 볼 수 있습니다.