반응형
파이토치를 처음으로 써서 여러가지 문제를 겪었는데
다음에 또 같은실수를 반복하지 않기위해 정리해본다
[저처럼 처음이신분들만 이해가능한 오류]
제일먼저 CrossEntropy같은경우에는 마지막 레이어 노드수가 2개 이상이여야 한다.
1개일 경우에는 사용이 안됨.
마지막 레이어가 노드수가 1개일 경우에는 보통 Binary Classification을 할때 사용될수가 있는데
이럴경우 BCELoss를 사용할때가 있다.
BCELoss함수를 사용할 경우에는 먼저 마지막 레이어의 값이 0~1로 조정을 해줘야하기 때문에
단순하게 마지막 레이어를 nn.Linear로 할때 out of range 에러가 뜬다.
따라서 BCELoss함수를 쓸땐 마지막 레이어를 시그모이드함수를 적용시켜줘야 한다.
두가지 경우의 예제를 간단하게 올려보면:
1. CrossEntropy 사용시[여기서는 MultiLabelSoftMarginLoss사용했지만 CrossEntropy도 동일한 방법으로 쓰면 될꺼같다, 다만 마지막 레이어의 softmax는 빼줘야할것]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = nn.Linear(101, 400)
self.dout = nn.Dropout(0.2)
self.l2 = nn.Linear(400, 400)
self.dout1 = nn.Dropout(0.2)
self.l3 = nn.Linear(400, 200)
self.l4 = nn.Linear(200, 2)
def forward(self, x):
x = F.relu(self.l1(x))
x = self.dout(x)
x = F.relu(self.l2(x))
x = self.dout1(x)
x = F.relu(self.l3(x))
return F.softmax(self.l4(x))
model = Net()
criterion = nn.MultiLabelSoftMarginLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
2. BCELoss 사용시
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(4909, 3000)
self.dout = nn.Dropout(0.2)
self.fc2 = nn.Linear(3000, 1500)
self.dout2 = nn.Dropout(0.2)
self.fc3 = nn.Linear(1500, 700)
self.dout3 = nn.Dropout(0.2)
self.fc4 = nn.Linear(700, 300)
self.fc5 = nn.Linear(300, 1)
def forward(self, input_):
a1 = F.relu(self.fc1(input_))
dout1 = self.dout(a1)
a2 = F.relu(self.fc2(dout1))
dout2 = self.dout2(a2)
a3 = F.relu(self.fc3(dout2))
dout3 = self.dout3(a3)
a4 = F.relu(self.fc4(dout3))
return torch.sigmoid(self.fc5(a4))
model = Net()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
반응형
'Data > Data Science' 카테고리의 다른 글
[Pyspark] Pyspark dataframe isin 과 is not in 방법 (0) | 2019.02.18 |
---|---|
[SQL] Coalesce 함수를 이용한 NULL값 처리 (0) | 2019.01.21 |
[Pytorch] MNIST CNN 코드 작성 & 공부 (0) | 2018.10.08 |
[Pytorch] MNIST DNN 코드 작성 & 공부 (0) | 2018.10.04 |
[RNN]Recurrent Neural Networks (0) | 2018.03.08 |