본문 바로가기

Data/Data Science

[PyTorch] Autoencoder Base Code

목차

    반응형

    import torch
    import torchvision
    import torch.nn.functional as F
    from torch import nn, optim
    from torchvision import transforms, datasets
    
    class Autoencoder(nn.Module):
        def __init__(self):
            super(Autoencoder, self).__init__()
            
            self.encoder = nn.Sequential(
                nn.Linear(28*28, 128),
                nn.ReLU(),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 12),
                nn.ReLU(),
                nn.Linear(12, 3),
                        )
            
            self.decoder = nn.Sequential(
                nn.Linear(3, 12),
                nn.ReLU(),
                nn.Linear(12, 64),
                nn.ReLU(),
                nn.Linear(64, 128),
                nn.ReLU(),
                nn.Linear(128, 28*28),
                nn.Sigmoid(),
            )
            
        def forward(self, x):
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return encoded, decoded
    
    # Reconstruction objective: mean squared error between the model's output
    # and the target, which for an autoencoder is the input itself.
    criterion = nn.MSELoss()

    autoencoder = Autoencoder()
    # Adam over every trainable parameter of both the encoder and the decoder.
    optimizer = optim.Adam(autoencoder.parameters(), lr=0.005)

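    학습 시에는 입력(input)과 목표값(target)에 동일한 데이터를 넣어, 모델이 입력을 그대로 복원하도록 학습한다. 이때 28x28 이미지는 `x = x.view(-1, 28*28)`처럼 784차원 벡터로 펼친 뒤 모델에 넣고, `encoded, decoded = autoencoder(x)`로 얻은 출력에 대해 `loss = criterion(decoded, x)`로 손실을 계산한다.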

    반응형