728x90
저번에는 UNet++ 논문을 보고 공부를 했습니다. 이번에는 논문 내용을 토대로 UNet++를 구현하고 실행해보겠습니다.
논문에서는 총 4가지의 데이터 셋(cell nuclei, colon polyp, liver, lung nodule)으로 실험을 진행했습니다. 논문을 바탕으로 작성한 UNet++ model을 GAN 구조에서 generator의 network로 사용해보겠습니다.
먼저 UNet++ model을 구현해보겠습니다.
class conv_block_nested(nn.Module):
def __init__(self, in_ch, mid_ch, out_ch):
super(conv_block_nested, self).__init__()
self.activation = nn.LeakyReLU(negative_slope = 0.15, inplace=True)
self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(mid_ch)
self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_ch)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.bn2(x)
output = self.activation(x)
return output
class Nested_UNet(nn.Module):
def __init__(self, in_ch=3, out_ch=3):
super(Nested_UNet, self).__init__()
n1 = 32
filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16, n1 * 32]
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
# 컬러 이미지가 들어가 in_ch = 3, 첫 convolution을 통해 3 -> 32
self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])
self.conv5_0 = conv_block_nested(filters[4], filters[5], filters[5])
# 여기서부터 UNet++에서 dense shortcut 구현
self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])
self.conv4_1 = conv_block_nested(filters[4] + filters[5], filters[4], filters[4])
self.conv0_2 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0])
self.conv1_2 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1])
self.conv2_2 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2])
self.conv3_2 = conv_block_nested(filters[3] * 2 + filters[4], filters[3], filters[3])
self.conv0_3 = conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0])
self.conv1_3 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1])
self.conv2_3 = conv_block_nested(filters[2] * 3 + filters[3], filters[2], filters[2])
self.conv0_4 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0])
self.conv1_4 = conv_block_nested(filters[1] * 4 + filters[2], filters[1], filters[1])
self.conv0_5 = conv_block_nested(filters[0] * 5 + filters[1], filters[0], filters[0])
self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)
def forward(self, x):
x0_0 = self.conv0_0(x)
x1_0 = self.conv1_0(self.pool(x0_0))
x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))
x2_0 = self.conv2_0(self.pool(x1_0))
x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))
x3_0 = self.conv3_0(self.pool(x2_0))
x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))
x4_0 = self.conv4_0(self.pool(x3_0))
x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))
x5_0 = self.conv5_0(self.pool(x4_0))
x4_1 = self.conv4_1(torch.cat([x4_0, self.Up(x5_0)], 1))
x3_2 = self.conv3_2(torch.cat([x3_0, x3_1, self.Up(x4_1)], 1))
x2_3 = self.conv2_3(torch.cat([x2_0, x2_1, x2_2, self.Up(x3_2)], 1))
x1_4 = self.conv1_4(torch.cat([x1_0, x1_1, x1_2, x1_3, self.Up(x2_3)], 1))
x0_5 = self.conv0_5(torch.cat([x0_0, x0_1, x0_2, x0_3, x0_4, self.Up(x1_4)], 1))
# x0_0부터 x0_4 사이의 값들과 x1_4 의 값을 붙여 conv진행해 x0_5 얻음
output = self.final(x0_5)
return output
이와 같이 UNet++구조를 구현할 수 있습니다. 컬러 이미지를 가지고 model을 진행할 것이고 결과값도 컬러 이미지로 만들기 때문에 in_ch와 out_ch 둘 다 3으로 정의를 했습니다. UNet++는 dense short cut 구조를 가지고 있기 때문에 각 층에서 전에 존재했던 모든 결과와 바로 아래층에서 받은 입력을 붙여 convolution을 진행합니다. 예를 들어 마지막 x0_5를 구하는 코드를 보면 x0_0에서 x0_4의 값들을 붙이고 x1_4의 값을 붙여서 convolution을 진행해 마지막 final에 들어가는 값을 구합니다.
이제 이 model을 사용해 전에 구현했던 GAN 구조에 적용해 이미지를 생성해보겠습니다.
왼쪽과 같이 blur 처리된 이미지를 통해 어느정도 복원된 이미지를 얻은 모습을 볼 수 있습니다. 이와 같이 UNet++ model을 이용하는 모습을 볼 수 있습니다.
728x90
'연구실 공부' 카테고리의 다른 글
[논문] Improved Training of Wasserstein GANs(WGAN-gp) (0) | 2022.04.04 |
---|---|
[논문] Image-to-Image Translation with Conditional Adversarial Networks (0) | 2022.03.31 |
[논문] Unet++ : A Nested U-Net Architecture for Medical Image Segmentation (0) | 2022.03.28 |
DenseNet 코드 (0) | 2022.03.24 |
[논문] Densely Connected Convolutional Networks (0) | 2022.03.23 |