[PyTorch] Deep Learning PyTorch 5. Convolutional Neural Networks I
Input layer
- The first layer that the input image data passes through
- An image is three-dimensional data with height, width, and channel dimensions
- A color image has 3 channels (RGB); a grayscale image has 1 channel (see the small shape check below)
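As a quick illustrative check of those shapes (dummy tensors, not part of the original example):

import torch

color_img = torch.rand(3, 28, 28)   # a color image: 3 channels (RGB), height 28, width 28
gray_img = torch.rand(1, 28, 28)    # a grayscale image: 1 channel
print(color_img.shape, gray_img.shape)  # torch.Size([3, 28, 28]) torch.Size([1, 28, 28])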
Convolutional layer
- Extracts features from the input data
- Uses kernels (filters) to detect features in the image
- Slides over every region of the image to extract features; the result is called a feature map
- For an RGB input, a convolution with a different set of weights is applied to each channel and the results are summed
- With multiple filters, a separate convolution is performed for each filter
Pooling layer
- Like the convolutional layer, it downsamples the feature map, reducing the amount of computation and extracting the dominant features so that learning proceeds effectively
- Max pooling: takes the maximum value within the target region
- Average pooling: returns the average of the target region (both are shown in the sketch below)
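A minimal sketch (not the book's code) applying both pooling operations to a 4x4 feature map:

import torch
import torch.nn as nn

x = torch.tensor([[[[ 1.,  2.,  3.,  4.],
                    [ 5.,  6.,  7.,  8.],
                    [ 9., 10., 11., 12.],
                    [13., 14., 15., 16.]]]])   # shape (1, 1, 4, 4)

max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
print(max_pool(x))  # [[ 6.,  8.], [14., 16.]]
print(avg_pool(x))  # [[ 3.5,  5.5], [11.5, 13.5]]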
Fully connected layer
- The feature map whose dimensions were reduced by the convolutional and pooling layers is finally passed to the fully connected layer
- The image is flattened from a three-dimensional tensor into a one-dimensional vector (see the flattening sketch below)
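A short sketch of that flattening step with view(); the shapes here are assumptions for illustration, not the book's code:

import torch

feature_map = torch.rand(64, 6, 6)           # e.g. 64 channels of 6x6 features
print(feature_map.view(-1).shape)            # torch.Size([2304]) = 64*6*6

batch = torch.rand(100, 64, 6, 6)            # with a batch dimension,
print(batch.view(batch.size(0), -1).shape)   # only the feature dims are flattened: torch.Size([100, 2304])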
Output layer
- For example, in an image classification problem the output layer applies the softmax activation function to output the probability that the image belongs to each label (a small example follows below)
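A tiny illustration (arbitrary scores, not from the book) of softmax turning raw scores into probabilities that sum to 1:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])   # raw scores for three hypothetical classes
probs = F.softmax(logits, dim=0)
print(probs)        # tensor([0.6590, 0.2424, 0.0986])
print(probs.sum())  # tensor(1.)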
5.2 A First Look at Convolutional Neural Networks
- The example below uses the FashionMNIST dataset
- Like the basic MNIST dataset, it consists of 70,000 28x28-pixel images that fall into ten classes
- The labels are integers from 0 to 9 indicating the class of the image (sneaker, shirt, etc.):
  0. T-shirt
  1. Trouser
  2. Pullover
  3. Dress
  4. Coat
  5. Sandal
  6. Shirt
  7. Sneaker
  8. Bag
  9. Ankle Boot
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms # library used for data preprocessing
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # use the GPU if one is available
print(device)
cuda:0
When using a single GPU (the typical case)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net()
model.to(device)
When using multiple GPUs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)
With nn.DataParallel, each batch is automatically split and distributed across the GPUs.
The batch size should therefore be scaled up by the number of GPUs (a sketch follows below).
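A minimal sketch of that idea; the base size of 100 matches the loaders below, but the scaling itself is illustrative, not the book's code:

import torch

num_gpus = max(torch.cuda.device_count(), 1)
batch_size = 100 * num_gpus   # nn.DataParallel splits each batch evenly across the visible GPUs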
train_dataset = torchvision.datasets.FashionMNIST("../chap05/data", download=True, transform=
transforms.Compose([transforms.ToTensor()]))
test_dataset = torchvision.datasets.FashionMNIST("../chap05/data", download=True, train=False, transform=
transforms.Compose([transforms.ToTensor()]))
# download the train and test datasets
# first argument: the directory to download the data into
# download=True: download the data into that directory
# transform: convert the images to tensors
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../chap05/data\FashionMNIST\raw\train-images-idx3-ubyte.gz
Extracting ../chap05/data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ../chap05/data\FashionMNIST\raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../chap05/data\FashionMNIST\raw\train-labels-idx1-ubyte.gz
Extracting ../chap05/data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ../chap05/data\FashionMNIST\raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../chap05/data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ../chap05/data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ../chap05/data\FashionMNIST\raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../chap05/data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../chap05/data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ../chap05/data\FashionMNIST\raw
Processing...
Done!
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=100)
test_loader = torch.utils.data.DataLoader(test_dataset,
batch_size=100)
# use a DataLoader to load the downloaded data into memory
# it groups the data into batches of batch_size samples
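A quick sanity check (not the book's code) of what one batch from the loader looks like:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([100, 1, 28, 28]) - batch, channel, height, width
print(labels.shape)  # torch.Size([100])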
labels_map = {0 : 'T-Shirt', 1 : 'Trouser', 2 : 'Pullover', 3 : 'Dress', 4 : 'Coat', 5 : 'Sandal', 6 : 'Shirt',
7 : 'Sneaker', 8 : 'Bag', 9 : 'Ankle Boot'}
fig = plt.figure(figsize=(8,8)); # width and height of the figure in inches
columns = 4;
rows = 5;
for i in range(1, columns*rows +1): # display 20 images
img_xy = np.random.randint(len(train_dataset));
img = train_dataset[img_xy][0][0,:,:]
fig.add_subplot(rows, columns, i)
plt.title(labels_map[train_dataset[img_xy][1]])
plt.axis('off')
plt.imshow(img, cmap='gray')
plt.show()
import numpy as np
print(np.random.randint(10)) # one random integer from 0 to 9
print(np.random.randint(1, 10)) # one random integer from 1 to 9
print(np.random.rand(8), end = '\n\n') # uniform random numbers in [0, 1) as a 1 x 8 array
print(np.random.rand(4, 2), end = '\n\n') # uniform random numbers in [0, 1) as a 4 x 2 array
print(np.random.randn(8), end = '\n\n') # samples from a Gaussian distribution with mean 0 and standard deviation 1
print(np.random.randn(4, 2), end = '\n\n') # samples from a Gaussian distribution with mean 0 and standard deviation 1
0
6
3
[0.31312966 0.6155271 0.57044374 0.76676257 0.78072825 0.51425842
0.49909136 0.56447091]
[[0.92966889 0.2238793 ]
[0.20219045 0.00634415]
[0.79555633 0.82103119]
[0.19145551 0.57623903]]
[-0.74944281 0.7030532 0.28885959 -0.53480008 -1.10437771 0.8932362
0.53113982 -0.9572707 ]
[[ 0.61334347 1.12622491]
[ 0.30370542 0.4576411 ]
[ 0.3587887 0.86464405]
[-0.00266375 -1.31769872]]
class FashionDNN(nn.Module): # a network without convolutional layers
    # class-style models inherit from nn.Module
    def __init__(self): # initializes the attributes of the instance
        super(FashionDNN,self).__init__() # calls the constructor of the parent class, nn.Module
self.fc1 = nn.Linear(in_features=784,out_features=256)
        self.drop = nn.Dropout2d(0.25)
        # 25% of the values are randomly zeroed out; the remaining values are scaled by 1/(1-0.25)
self.fc2 = nn.Linear(in_features=256,out_features=128)
self.fc3 = nn.Linear(in_features=128,out_features=10)
def forward(self,input_data):
out = input_data.view(-1, 784)
        out = F.relu(self.fc1(out)) # F.relu is applied inside forward()
        # (the module form nn.ReLU() would instead be defined in __init__())
out = self.drop(out)
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
import torch
import torch.nn as nn
inputs = torch.randn(64, 3, 244, 244)
conv = nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size = 3, padding = 1)
outputs = conv(inputs)
print(outputs, outputs.shape)
layers = nn.Conv2d(1, 1, 3)
print(layers)
tensor([[[[ 3.6549e-01, -5.2363e-01,  8.5037e-04,  ...,  2.0610e-01,
           -8.1525e-01,  3.6881e-01],
          ...,
          [-6.3681e-01,  3.5560e-01, -6.7130e-01,  ..., -1.6053e-01,
           -1.4944e-01,  4.0898e-01]]],
        ...]]]], grad_fn=<MkldnnConvolutionBackward>) torch.Size([64, 64, 244, 244])
Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1))
import torch.nn.functional as F
inputs = torch.randn(64, 3, 244, 244)
weight = torch.randn(64, 3, 3, 3)
bias = torch.randn(64)
outputs = F.conv2d(inputs, weight, bias, padding = 1)
print(outputs, outputs.shape)
tensor([[[[-5.6240e+00, -1.8608e-01, -3.4725e+00,  ...,  6.7715e+00,
           -2.9745e+00, -6.3199e+00],
          ...,
          [-1.7788e+00, -1.1290e+00, -2.6139e+00,  ..., -2.2110e+00,
           -2.2202e+00,  2.0146e+00]]],
        ...]]]]) torch.Size([64, 64, 244, 244])
| Aspect | nn.xx | nn.functional.xx |
|----|-----|----------------|
| Form | nn.Conv2d: a class, used by inheriting from nn.Module | nn.functional.conv2d: a pure function defined as def function(input) |
| How it is called | Pass the hyperparameters first (when the instance is created), then pass the data in the call | Pass the hyperparameters and the data together in the function call |
| Placement | Can be placed inside nn.Sequential | Cannot be placed inside nn.Sequential |
| Parameters | No need to define the parameters yourself | You must define and pass your own weights whenever they are needed |
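A small sketch of that difference (illustrative shapes, not the book's code): the class form keeps its own weights and can live inside nn.Sequential, while the functional form expects the weights in every call.

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)

# nn.Conv2d: hyperparameters at construction, data at call time, usable inside nn.Sequential
block = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
print(block(x).shape)  # torch.Size([1, 16, 8, 8])

# F.conv2d: the weights and the data are both passed explicitly in the call
weight = torch.randn(16, 3, 3, 3)
print(F.conv2d(x, weight, padding=1).shape)  # torch.Size([1, 16, 8, 8])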
# define the loss function, learning rate, and optimizer for training
learning_rate = 0.001;
model = FashionDNN();
model.to(device)
criterion = nn.CrossEntropyLoss(); # loss function used for classification problems
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate);
print(model)
FashionDNN(
(fc1): Linear(in_features=784, out_features=256, bias=True)
(drop): Dropout2d(p=0.25, inplace=False)
(fc2): Linear(in_features=256, out_features=128, bias=True)
(fc3): Linear(in_features=128, out_features=10, bias=True)
)
num_epochs = 5
count = 0
loss_list = []
iteration_list = []
accuracy_list = []
predictions_list = []
labels_list = []
for epoch in range(num_epochs): # repeat for 5 epochs
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
        train = Variable(images.view(100, 1, 28, 28)) # Variable lets autograd automatically compute
        # the gradients needed for backpropagation
labels = Variable(labels)
        outputs = model(train) # feed the training data to the model
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
count += 1
if not (count % 50):
total = 0
correct = 0
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
labels_list.append(labels)
test = Variable(images.view(100, 1, 28, 28))
outputs = model(test)
predictions = torch.max(outputs, 1)[1].to(device)
predictions_list.append(predictions)
correct += (predictions == labels).sum()
total += len(labels)
accuracy = correct * 100 / total
loss_list.append(loss.data)
iteration_list.append(count)
accuracy_list.append(accuracy)
if not (count % 500):
print("Iteration: {}, Loss: {}, Accuracy: {}%".format(count, loss.data, accuracy))
Iteration: 500, Loss: 0.5853009223937988, Accuracy: 83.22000122070312%
Iteration: 1000, Loss: 0.4784613847732544, Accuracy: 84.54999542236328%
Iteration: 1500, Loss: 0.33210045099258423, Accuracy: 84.55999755859375%
Iteration: 2000, Loss: 0.3649474084377289, Accuracy: 85.38999938964844%
Iteration: 2500, Loss: 0.2886713445186615, Accuracy: 86.43999481201172%
Iteration: 3000, Loss: 0.3703139126300812, Accuracy: 86.16999816894531%
- Suppose the accuracy is above 80%. That number alone does not tell us whether every class was classified equally well or whether a few specific classes account for most of the correct predictions.
- Suppose the accuracy is above 90%. If 90 of 100 samples belong to a single class, 90% accuracy is not impressive: predicting that class for every sample already reaches 90%.
In other words, accuracy should always be interpreted with the class distribution of the data in mind (a per-class accuracy sketch follows below).
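A minimal sketch (not the book's code) of measuring per-class accuracy with the model, loaders, and labels_map defined above, so that a high overall accuracy cannot hide poor performance on individual classes:

class_correct = [0] * 10
class_total = [0] * 10
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predictions = torch.max(model(images), 1)[1]
        for label, prediction in zip(labels, predictions):
            class_total[label.item()] += 1
            class_correct[label.item()] += int(prediction.item() == label.item())

for idx in range(10):
    print("{:>10}: {:.1f}%".format(labels_map[idx], 100 * class_correct[idx] / max(class_total[idx], 1)))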
# build the convolutional neural network
class FashionCNN(nn.Module):
def __init__(self):
super(FashionCNN, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.layer2 = nn.Sequential(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)
self.drop = nn.Dropout2d(0.25)
self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=10) # out_features of the final layer equals the number of classes
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.view(out.size(0), -1)
out = self.fc1(out)
out = self.drop(out)
out = self.fc2(out)
out = self.fc3(out)
return out
learning_rate = 0.001;
model = FashionCNN();
model.to(device)
criterion = nn.CrossEntropyLoss();
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate);
print(model)
FashionCNN(
(layer1): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(layer2): Sequential(
(0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(fc1): Linear(in_features=2304, out_features=600, bias=True)
(drop): Dropout2d(p=0.25, inplace=False)
(fc2): Linear(in_features=600, out_features=120, bias=True)
(fc3): Linear(in_features=120, out_features=10, bias=True)
)
num_epochs = 5
count = 0
loss_list = []
iteration_list = []
accuracy_list = []
predictions_list = []
labels_list = []
for epoch in range(num_epochs):
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
train = Variable(images.view(100, 1, 28, 28))
labels = Variable(labels)
outputs = model(train)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
count += 1
if not (count % 50):
total = 0
correct = 0
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
labels_list.append(labels)
test = Variable(images.view(100, 1, 28, 28))
outputs = model(test)
predictions = torch.max(outputs, 1)[1].to(device)
predictions_list.append(predictions)
correct += (predictions == labels).sum()
total += len(labels)
accuracy = correct * 100 / total
loss_list.append(loss.data)
iteration_list.append(count)
accuracy_list.append(accuracy)
if not (count % 500):
print("Iteration: {}, Loss: {}, Accuracy: {}%".format(count, loss.data, accuracy))
Iteration: 500, Loss: 0.4643004238605499, Accuracy: 87.91999816894531%
Iteration: 1000, Loss: 0.32723572850227356, Accuracy: 88.68999481201172%
Iteration: 1500, Loss: 0.3142659366130829, Accuracy: 87.81999969482422%
Iteration: 2000, Loss: 0.24007703363895416, Accuracy: 89.75999450683594%
Iteration: 2500, Loss: 0.13590288162231445, Accuracy: 90.31999969482422%
Iteration: 3000, Loss: 0.18765489757061005, Accuracy: 90.18999481201172%
Compared with the deep (fully connected) network, the accuracy is slightly higher. Since the difference is small, it may look as though the simpler network is good enough, but in practice, as the amount of image data grows, a plain fully connected network cannot extract and classify features accurately, so it is important to learn how to build convolutional neural networks.
5.3 Transfer Learning
- In general, properly training a CNN-based deep learning model requires a large amount of data
- In practice it is often difficult to gather that much data, so transfer learning is used instead
- Transfer learning: take the weights of a model trained on a very large dataset such as ImageNet and adapt them to the task at hand
- Two common transfer-learning approaches are feature extraction and fine-tuning
Feature extraction
- Take a model pre-trained on the ImageNet dataset and replace only the final fully connected layer
- During training, only that fully connected layer (the part that decides the image category) is updated; the remaining layers are frozen
- For image classification, feature extraction therefore consists of two parts:
- Convolutional base: the convolutional and pooling layers
- Classifier (fully connected layers): takes the extracted features as input and produces the final class prediction
import os
import time
import copy
import glob
import cv2
import shutil
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms # package used for data preprocessing
import torchvision.models as models # package that provides various predefined network architectures
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
data_path = './data/catanddog/train' # path to the training image data
transform = transforms.Compose(
    # transforms that convert the images into inputs the model (network) can consume
    [
        transforms.Resize([256, 256]),
        # Resize: resize the image
        transforms.RandomResizedCrop(224),
        # crop the image to a random size and aspect ratio
        transforms.RandomHorizontalFlip(),
        # randomly flip the image horizontally
        transforms.ToTensor(),
        # convert the image data to a tensor
    ])
train_dataset = torchvision.datasets.ImageFolder(
    # defines what the data loader will load and how
    data_path,
    transform = transform
)
train_loader = torch.utils.data.DataLoader(
    # wrap the ImageFolder dataset in a DataLoader
    # batch_size: number of samples per batch
    # num_workers: number of worker subprocesses used for loading
    # shuffle: whether to shuffle the data each epoch
    train_dataset,
    batch_size = 32,
    num_workers = 8,
    shuffle = True
)
print(len(train_dataset))
385
samples, labels = next(iter(train_loader)) # iter() builds an iterator over the loader and next() fetches the next batch
classes = {0:'cat', 1:'dog'}
fig = plt.figure(figsize=(16,24))
for i in range(24):
a = fig.add_subplot(4,6,i+1)
a.set_title(classes[labels[i].item()])
a.axis('off')
print(samples[i].shape)
a.imshow(np.transpose(samples[i].numpy(), (1,2,0)))
    # np.transpose rearranges the axes from (C, H, W) to (H, W, C) so the image can be displayed
plt.subplots_adjust(bottom=0.2, top=0.6, hspace=0)
torch.Size([3, 224, 224])
(the same shape is printed once for each of the 24 displayed images)
resnet18 = models.resnet18(pretrained=True)
# load the pretrained model
def set_parameter_requires_grad(model, feature_extracting=True):
if feature_extracting:
for param in model.parameters():
param.requires_grad = False
# this function freezes the parameters so they are not updated during backpropagation
set_parameter_requires_grad(resnet18)
resnet18.fc = nn.Linear(512, 2)
# replace the fully connected layer of the pretrained model (2 output classes)
for name, param in resnet18.named_parameters():
if param.requires_grad:
print(name, param.data)
# in the output, only fc.weight and fc.bias are printed
fc.weight tensor([[-0.0250, -0.0163, -0.0392, ..., 0.0188, -0.0078, 0.0188],
[ 0.0369, -0.0262, -0.0118, ..., -0.0407, -0.0129, -0.0134]])
fc.bias tensor([-0.0376, -0.0180])
model = models.resnet18(pretrained = True)
# create the model
for param in model.parameters():
param.requires_grad = False
# freeze: these parameters are not trained
model.fc = torch.nn.Linear(512, 2)
for param in model.fc.parameters():
param.requires_grad = True
# train only the model's fc layer
optimizer = torch.optim.Adam(model.fc.parameters())
cost = torch.nn.CrossEntropyLoss()
print(model)
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=2, bias=True)
)
def train_model(model, dataloaders, criterion, optimizer, device, num_epochs=13, is_train=True):
since = time.time()
acc_history = []
loss_history = []
best_acc = 0.0
    for epoch in range(num_epochs): # repeat for 13 epochs
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
running_loss = 0.0
running_corrects = 0
for inputs, labels in dataloaders:
inputs = inputs.to(device)
labels = labels.to(device)
model.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
_, preds = torch.max(outputs, 1)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
# 출력 결과와 레이블의 오차를 계산한 결과를 누적해 저장
running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / len(dataloaders.dataset)
epoch_acc = running_corrects.double() / len(dataloaders.dataset)
print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
if epoch_acc > best_acc:
best_acc = epoch_acc
acc_history.append(epoch_acc.item())
loss_history.append(epoch_loss)
torch.save(model.state_dict(), os.path.join('./data/catanddog/', '{0:0=2d}.pth'.format(epoch)))
print()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best Acc: {:4f}'.format(best_acc))
return acc_history, loss_history
params_to_update = []
for name,param in resnet18.named_parameters():
if param.requires_grad == True:
params_to_update.append(param)
print("\t",name)
optimizer = optim.Adam(params_to_update)
# pass only the parameters being updated (requires_grad=True) to the optimizer
fc.weight
fc.bias
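- For reference, only fc.weight and fc.bias appear above because every other parameter was frozen earlier for feature extraction. Below is a minimal sketch of that freezing pattern, assuming a torchvision resnet18 (the variable name backbone, the pretrained flag, and the 2-class head are illustrative):
import torch.nn as nn
import torchvision.models as models

# sketch: freeze the pretrained backbone, then attach a fresh 2-class head
backbone = models.resnet18(pretrained=True)
for param in backbone.parameters():
    param.requires_grad = False      # convolutional weights stay fixed during training
backbone.fc = nn.Linear(512, 2)      # the new layer's parameters default to requires_grad=True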
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
train_acc_hist, train_loss_hist = train_model(resnet18, train_loader, criterion, optimizer, device)
Epoch 0/12
----------
Loss: 0.5420 Acc: 0.7506
Epoch 1/12
----------
Loss: 0.3865 Acc: 0.8338
Epoch 2/12
----------
Loss: 0.3024 Acc: 0.8753
Epoch 3/12
----------
Loss: 0.2726 Acc: 0.9091
Epoch 4/12
----------
Loss: 0.2534 Acc: 0.9195
Epoch 5/12
----------
Loss: 0.2825 Acc: 0.8831
Epoch 6/12
----------
Loss: 0.2167 Acc: 0.9221
Epoch 7/12
----------
Loss: 0.2389 Acc: 0.9065
Epoch 8/12
----------
Loss: 0.1874 Acc: 0.9325
Epoch 9/12
----------
Loss: 0.2228 Acc: 0.9143
Epoch 10/12
----------
Loss: 0.2781 Acc: 0.8701
Epoch 11/12
----------
Loss: 0.3092 Acc: 0.8649
Epoch 12/12
----------
Loss: 0.3099 Acc: 0.8597
Training complete in 1m 8s
Best Acc: 0.932468
# test dataset 생성
test_path = './data/catanddog/test'
transform = transforms.Compose(
[
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
])
test_dataset = torchvision.datasets.ImageFolder(
root=test_path,
transform=transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=32,
num_workers=1,
shuffle=True
)
print(len(test_dataset))
98
def eval_model(model, dataloaders, device):
since = time.time()
acc_history = []
best_acc = 0.0
saved_models = glob.glob('./data/catanddog/' + '*.pth')
saved_models.sort()
print('saved_model', saved_models) # 저장된 pth 파일들 출력
for model_path in saved_models:
print('Loading model', model_path)
model.load_state_dict(torch.load(model_path))
model.eval()
model.to(device)
running_corrects = 0
for inputs, labels in dataloaders: # test 반복
inputs = inputs.to(device)
labels = labels.to(device)
with torch.no_grad(): # 학습 x
outputs = model(inputs)
_, preds = torch.max(outputs.data, 1)
# preds from torch.max already holds class indices (0 or 1), so no extra thresholding is needed
running_corrects += preds.eq(labels.to(device)).int().sum()
epoch_acc = running_corrects.double() / len(dataloaders.dataset)
print('Acc: {:.4f}'.format(epoch_acc))
if epoch_acc > best_acc:
best_acc = epoch_acc
acc_history.append(epoch_acc.item())
print()
time_elapsed = time.time() - since
print('Validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best Acc: {:4f}'.format(best_acc))
return acc_history
val_acc_hist = eval_model(resnet18, test_loader, device)
saved_model ['./data/catanddog\\00.pth', './data/catanddog\\01.pth', './data/catanddog\\02.pth', './data/catanddog\\03.pth', './data/catanddog\\04.pth', './data/catanddog\\05.pth', './data/catanddog\\06.pth', './data/catanddog\\07.pth', './data/catanddog\\08.pth', './data/catanddog\\09.pth', './data/catanddog\\10.pth', './data/catanddog\\11.pth', './data/catanddog\\12.pth']
Loading model ./data/catanddog\00.pth
Acc: 0.9082
Loading model ./data/catanddog\01.pth
Acc: 0.9286
Loading model ./data/catanddog\02.pth
Acc: 0.9694
Loading model ./data/catanddog\03.pth
Acc: 0.9490
Loading model ./data/catanddog\04.pth
Acc: 0.9490
Loading model ./data/catanddog\05.pth
Acc: 0.9388
Loading model ./data/catanddog\06.pth
Acc: 0.9694
Loading model ./data/catanddog\07.pth
Acc: 0.9184
Loading model ./data/catanddog\08.pth
Acc: 0.9694
Loading model ./data/catanddog\09.pth
Acc: 0.9592
Loading model ./data/catanddog\10.pth
Acc: 0.9184
Loading model ./data/catanddog\11.pth
Acc: 0.9592
Loading model ./data/catanddog\12.pth
Acc: 0.9184
Validation complete in 0m 20s
Best Acc: 0.969388
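- From the per-checkpoint accuracies above, the best checkpoint can also be picked programmatically from the returned history (a small sketch using the val_acc_hist list returned by eval_model):
import numpy as np

# index of the saved checkpoint with the highest test accuracy
best_idx = int(np.argmax(val_acc_hist))
print('best checkpoint: {:02d}.pth, acc = {:.4f}'.format(best_idx, val_acc_hist[best_idx]))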
plt.plot(train_acc_hist)
plt.plot(val_acc_hist)
plt.show()
plt.plot(train_loss_hist)
plt.show()
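- The two accuracy curves are easier to compare with axis labels and a legend; an optional variant of the plot above:
plt.plot(train_acc_hist, label='train accuracy (per epoch)')
plt.plot(val_acc_hist, label='test accuracy (per saved checkpoint)')
plt.xlabel('epoch / checkpoint index')
plt.ylabel('accuracy')
plt.legend()
plt.show()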
# helper that converts a tensor back into a displayable numpy image
def im_convert(tensor):
image=tensor.clone().detach().numpy() # clone() copies the tensor, detach() removes it from the computation graph
image=image.transpose(1,2,0) # (C, H, W) -> (H, W, C) for matplotlib
# the test transform above only applies ToTensor, so values are already in [0, 1];
# if Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) had been used, it would be undone here with image*0.5 + 0.5
image=image.clip(0,1) # clip limits values that fall outside the given range
return image
| Method | Memory | Stays in the computation graph |
|--------|--------|--------------------------------|
| tensor.clone() | newly allocated | yes, remains in the graph |
| tensor.detach() | shares the original storage | no |
| tensor.clone().detach() | newly allocated | no |
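- A quick check of the memory behavior summarized in the table (a small self-contained sketch; the values are illustrative):
import torch

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = a.detach()          # shares storage with a, removed from the graph
c = a.clone().detach()  # new storage, removed from the graph

b[0] = 100.0            # modifying b also changes a (shared memory)
print(a)                # tensor([100.,   2.,   3.], requires_grad=True)
print(c)                # tensor([1., 2., 3.])  -> clone() copied the data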
import numpy as np
exam = np.array([-1.8, -1.2, -0.7, 0.0, 0.8, 1.4, 1.9])
print(exam)
print(np.clip(exam, -0.5, 0.5))
[-1.8 -1.2 -0.7 0. 0.8 1.4 1.9]
[-0.5 -0.5 -0.5 0. 0.5 0.5 0.5]
classes = {0:'cat', 1:'dog'}
dataiter = iter(test_loader)          # load one batch from the test dataset
images, labels = next(dataiter)       # fetch the images and labels (dataiter.next() is deprecated)
output = resnet18(images.to(device))  # run the fine-tuned model; inputs must be on the same device as the model
_, preds = torch.max(output, 1)
preds = preds.cpu()                   # move predictions back to the CPU for comparison and plotting
fig = plt.figure(figsize=(25, 4))
for idx in np.arange(20):
    ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # title is written as "prediction(label)", green when correct and red when wrong
    ax.set_title("{}({})".format(str(classes[preds[idx].item()]), str(classes[labels[idx].item()])), color=("green" if preds[idx]==labels[idx] else "red"))
plt.subplots_adjust(bottom=0.2, top=0.6, hspace=0)
plt.show()
# the results are not great; increasing the training data and the number of epochs would help
5.3.2 미세 조정(fine-tuning) 기법¶
- Goes beyond feature extraction: the weights of the pretrained model's convolutional layers and the classifier are both updated during training
- Feature extraction performs well only on the assumption that the pretrained network already extracts the target features well
- Fine-tuning retrains the pretrained model for the new task, or retrains only part of its learned weights
- Large dataset, low similarity to the pretrained model
    - Retrain the whole model; the dataset is large enough to support full retraining
- Large dataset, high similarity to the pretrained model
    - Train the later convolutional layers and the classifier; because the data is similar, retraining only the later layers works better than retraining everything
- Small dataset, low similarity to the pretrained model
    - Train part of the convolutional layers and the classifier; with little data, fine-tuning only a few layers may have little effect, so decide carefully how far back into the convolutional layers to retrain
- Small dataset, high similarity to the pretrained model
    - Train only the classifier; with little data, fine-tuning many layers risks overfitting, so apply fine-tuning only to the final fully connected layer
- The four cases are sketched in code after this list
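- A minimal sketch of these regimes with a torchvision resnet18 (the layer names layer4/fc follow torchvision's ResNet; the 2-class head and the pretrained flag are illustrative):
import torch.nn as nn
import torchvision.models as models

model_ft = models.resnet18(pretrained=True)
model_ft.fc = nn.Linear(512, 2)   # new classifier head for 2 classes

# small dataset, high similarity: train only the classifier
for name, param in model_ft.named_parameters():
    param.requires_grad = name.startswith('fc')

# large dataset, high similarity (or small dataset, low similarity): also unfreeze the later conv blocks
for name, param in model_ft.named_parameters():
    if name.startswith('layer4') or name.startswith('fc'):
        param.requires_grad = True

# large dataset, low similarity: retrain everything
for param in model_ft.parameters():
    param.requires_grad = True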
5.4 설명 가능한 CNN(explainable CNN)¶
- 딥러닝 처리 결과를 사람이 이해할 수 있는 방식으로 제시하는 기술
- CNN은 블랙박스와 같아 내부에서 어떻게 동작하는지 설명하기 어렵다.
- CNN을 구성하는 각 중간 계층부터 최종 분류까지 input image의 feature가 어떻게 추출되고 학습하는지를 시각적으로 설명해야 신뢰가 올라감
- filter에 대한 시각화, CNN의 시각화가 존재
5.4.1 특성 맵 시각화¶
- 특성 맵(feature map)은 input image 또는 다른 feature map처럼 filter를 input에 적용한 결과
- 특정 input image에 대한 feature map을 시각화한다는 의미는 feature map에서 input feature를 감지하는 방법을 이해할 수 있도록 돕는 것
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms import ToTensor
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
# 13개의 합성곱층과 두 개의 fully connected layer로 구성된 network 생성
class XAI(torch.nn.Module):
def __init__(self, num_classes=2):
super(XAI, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Dropout(0.3),
nn.Conv2d(64, 64, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Dropout(0.4),
nn.Conv2d(128, 128, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(128, 256, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Dropout(0.4),
nn.Conv2d(256, 256, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Dropout(0.4),
nn.Conv2d(256, 256, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(256, 512, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Dropout(0.4),
nn.Conv2d(512, 512, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Dropout(0.4),
nn.Conv2d(512, 512, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(512, 512, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Dropout(0.4),
nn.Conv2d(512, 512, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Dropout(0.4),
nn.Conv2d(512, 512, kernel_size=3, padding = 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(512, 512, bias=False),
nn.Dropout(0.5),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(512, num_classes)
)
def forward(self, x):
x = self.features(x)
x = x.view(-1, 512)
x = self.classifier(x)
return F.log_softmax(x, -1) # log_softmax applies log to the softmax output
# computing the two together is numerically more stable than softmax followed by log, easing underflow and vanishing-gradient issues
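- For reference, F.log_softmax paired with nn.NLLLoss is equivalent to nn.CrossEntropyLoss applied to raw logits (a small self-contained check; the numbers are random):
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 2)
targets = torch.tensor([0, 1, 1, 0])

loss_a = nn.NLLLoss()(F.log_softmax(logits, dim=-1), targets)
loss_b = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(loss_a, loss_b))   # True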
model=XAI()
model.to(device)
model.eval()
XAI(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Dropout(p=0.3, inplace=False)
(4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): ReLU(inplace=True)
(11): Dropout(p=0.4, inplace=False)
(12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(13): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(14): ReLU(inplace=True)
(15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(16): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(17): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(18): ReLU(inplace=True)
(19): Dropout(p=0.4, inplace=False)
(20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(22): ReLU(inplace=True)
(23): Dropout(p=0.4, inplace=False)
(24): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(25): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(26): ReLU(inplace=True)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(29): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(30): ReLU(inplace=True)
(31): Dropout(p=0.4, inplace=False)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(33): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(34): ReLU(inplace=True)
(35): Dropout(p=0.4, inplace=False)
(36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(38): ReLU(inplace=True)
(39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(42): ReLU(inplace=True)
(43): Dropout(p=0.4, inplace=False)
(44): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(45): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(46): ReLU(inplace=True)
(47): Dropout(p=0.4, inplace=False)
(48): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(49): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(50): ReLU(inplace=True)
(51): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=512, out_features=512, bias=False)
(1): Dropout(p=0.5, inplace=False)
(2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
(4): Dropout(p=0.5, inplace=False)
(5): Linear(in_features=512, out_features=2, bias=True)
)
)
# class used to inspect a layer's feature maps
# PyTorch provides hooks, so even without adding print statements to each layer
# you can capture each layer's activations (outputs) and gradients
class LayerActivations:
features=[]
def __init__(self, model, layer_num):
self.hook = model[layer_num].register_forward_hook(self.hook_fn)
# register_forward_hook => forward 중에 각 network module의 input 및 output을 가져옴
def hook_fn(self, module, input, output):
self.features = output.detach().numpy()
def remove(self):
self.hook.remove()
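- A typical usage pattern for this wrapper (it mirrors the cells further below; img stands for any input tensor of the right shape, and remove() detaches the hook once it is no longer needed):
result = LayerActivations(model.features, 0)  # register a forward hook on layer 0
_ = model(img)                                # the forward pass fills result.features
activations = result.features                 # numpy array with that layer's output
result.remove()                               # detach the hook afterwards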
import torch
x = torch.Tensor([0, 1, 2, 3]).requires_grad_()
y = torch.Tensor([4, 5, 6, 7]).requires_grad_()
w = torch.Tensor([1, 2, 3, 4]).requires_grad_()
z = x + y
o = w.matmul(z)
print(o)
o.backward()
print(x.grad, y.grad, z.grad, w.grad, o.grad)
tensor(80., grad_fn=<DotBackward>)
tensor([1., 2., 3., 4.]) tensor([1., 2., 3., 4.]) None tensor([ 4., 6., 8., 10.]) None
C:\Users\user\AppData\Local\Temp\ipykernel_11252\377187508.py:10: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
print(x.grad, y.grad, z.grad, w.grad, o.grad)
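- As the warning suggests, retain_grad() makes autograd also populate the gradient of a non-leaf tensor such as z (a small self-contained rerun of the example above):
import torch

x = torch.tensor([0., 1., 2., 3.], requires_grad=True)
y = torch.tensor([4., 5., 6., 7.], requires_grad=True)
w = torch.tensor([1., 2., 3., 4.], requires_grad=True)

z = x + y
z.retain_grad()      # ask autograd to keep the gradient of this non-leaf tensor
o = w.matmul(z)
o.backward()
print(z.grad)        # tensor([1., 2., 3., 4.])  (equal to w)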
img = cv2.imread("./data/cat.jpg")           # OpenCV loads images in BGR channel order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)   # convert to RGB so matplotlib displays the colors correctly
plt.imshow(img)
img = cv2.resize(img, (100, 100), interpolation=cv2.INTER_LINEAR) # interpolation method used when resizing
img = ToTensor()(img).unsqueeze(0)
print(img.shape)
torch.Size([1, 3, 100, 100])
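- PIL was imported above but not used; an equivalent preprocessing path with PIL and torchvision transforms would be (the file path is the same one used above):
from PIL import Image
from torchvision import transforms

pil_img = Image.open("./data/cat.jpg").convert("RGB")
tensor_img = transforms.Compose([
    transforms.Resize((100, 100)),   # same target size as the cv2.resize call above
    transforms.ToTensor(),
])(pil_img).unsqueeze(0)             # add a batch dimension -> [1, 3, 100, 100]
print(tensor_img.shape)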
result = LayerActivations(model.features, 0)
img = img.to('cpu')
model.to('cpu')
model(img)
activations = result.features
# input layer와 근처 Conv2d 계층 결과
# input image의 형태를 많이 유지하는 모습
fig, axes = plt.subplots(4, 4, figsize=(12, 8))  # create the grid and the figure in one call
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for row in range(4):
for column in range(4):
axis = axes[row][column]
axis.get_xaxis().set_ticks([])
axis.get_yaxis().set_ticks([])
axis.imshow(activations[0][row*10+column])
plt.show()
result = LayerActivations(model.features, 20)
model(img)
activations = result.features
# 20번째 Conv2d feature map 확인
# input layer의 근처보다 input image의 모습을 잘 유지하지 못함
fig, axes = plt.subplots(4, 4, figsize=(12, 8))  # create the grid and the figure in one call
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for row in range(4):
for column in range(4):
axis = axes[row][column]
axis.get_xaxis().set_ticks([])
axis.get_yaxis().set_ticks([])
axis.imshow(activations[0][row*10+column])
plt.show()
result = LayerActivations(model.features, 40)
model(img)
activations = result.features
# 40번째 layer확인
# 더 input image의 모습이 사라짐
fig, axes = plt.subplots(4, 4, figsize=(12, 8))  # create the grid and the figure in one call
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for row in range(4):
for column in range(4):
axis = axes[row][column]
axis.get_xaxis().set_ticks([])
axis.get_yaxis().set_ticks([])
axis.imshow(activations[0][row*10+column])
plt.show()