끵뀐꿩긘의 여러가지

GAN - 논문 리뷰 본문

Naver boostcamp -ai tech/Paper review

GAN - 논문 리뷰

끵뀐꿩긘 2022. 10. 19. 08:24

Contents

    논문

    Generative Adversarial Nets(2014, Ian J. Goodfellow)

    https://arxiv.org/abs/1406.2661

     

    Generative Adversarial Networks

    We propose a new framework for estimating generative models via an adversarial process, in which we simultaneously train two models: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that

    arxiv.org


    0. Abstract

    적대적 과정(adversarial process)을 통해 generative model을 추정하는 새로운 프레임 워크 제안

    두가지 G(generative) model과 D(discriminative) model 의 훈련

    D(discriminative) model:

    데이터 분포로 부터 현재 들어온 샘플 데이터가 실제 데이터인지 아니면 G로부터 생성된 데이터인지 판단

     

    G(generative) model:

    D model이 실수를 하도록(진짜 데이터와 생성된 데이터를 구분하지 못하도록) 데이터 생성,

    G model이 완벽히 학습되면(training data의 분포를 완벽히 학습하면) D모델은 실제 데이터와 생성된 데이터를 구분하지 못하므로 1/2의 확률로 데이터를 판별한다.

     

    --> minimax two player game

    (minimax two player game은 서로가 반대되는 목표를 가지고 경쟁하는 구조라고 받아들면 된다.)

     

    *minimax 알고리즘

    마치 1대1 턴제 게임처럼 상대방이 최적의 판단을 한다고 가정하고 계산하는 알고리즘

     

    G와 D가 multilayer perceptron으로 구성딘다면, 역전파를 통해 학습된다

    => 이전의 generative 분야에서 사용한 Markov chains 또는 unrolled approximate inference networks를 사용할 필요없음

     

    1. Introduction

    딥러닝이 발전하면서 모든 종류의 데이터(음성, 그림 등)으니 확률 분포를 나타내는 모델을 만들 수 있게 된다

    특히, 그러한 과정에서  고차원의 방대한 sensory data 입력을 class label로 맵핑하는 discriminative models이 크게 발전하였다. (well-behaved gradient, 역전파와 dropout등 여러 알고리즘 기반)

     

    * 이미지뿐만 아니라 다양한 종류의 데이터는 다차원 특징 공간의 한 점으로 표현 될 수 있다.

    => ex. 이는 이미지의 분포를 근사하는 모델을 학습할 수 있고, 통계적인 평균치가 존재할 수 있다는 의미이다.

    이미지의 다변수 확률 분포 표현

     

    반면, Deep generative models은 잘 발전하지 못함

    - maximum likelihood estimation(최대 가능도 방법)으로부터 생기는 다루기 힘든 확률적 계산 근사의 어려움

    - generative 영역에서의 piecewise linear units(DL unit들) 사용의 어려움

     

    => 새로 제안하는 generative 모델이 이런한 어려움들을 해결

     

    GAN의 핵심 컨셉:

    G model은 지폐를 복사하려는 counterfeiter,

    D model은 지폐가 진짜이지 위조된 것인지 알아내려는 police

    학습이 진행될수록 counterfeiter는 진짜 같은 지폐를 만들어내고, police는 더 완벽히 판별하게 된다.

    이러한 경쟁하는 과정의 반복은 counterfeiter가 판별할 수 없을만큼 완벽한 위조 지폐를 만들어낼 떄까지 이어진다.

     

    이 프레임워크은 다양한 학습 알고리즘들과 optimization을 사용할 수 있지만

    이 논문에서는 G,D model 모두 multilayer perceptron로 구성되어 있고,  G 모델은 random noise로부터 데이터를 생성한다 

     

    3.  Adversarial nets

    $P_{data}$: 실제 데이터의 분포

    $P_z$: noise 데이터의 분포

    $G(z;\theta _g)$: parameters $θ_g$로 이루어진 multilayer perceptron,$P_z$를 data space에 mapping

    $D(x; \theta _d)$: parameters $θ_d$로 이루어진 multilayer perceptron, 생성된 데이터인지 아닌지 판별

    G는 V(D,G)가 작아지도록 학습되고, D는 V(D,G)가 커지도록 학습된다.

     

    1. D는 V(D,G)가 커지도록 학습:

     

    - D가 원본 데이터에 대해서 1(원본 데이터)라고 판단하는 경우

    $$\mathbb{E}_{x \sim P_{data}(x)}[logD(x)] = \mathbb{E}_{x \sim P_{data}(x)}[log(1)] = 0$$

    위의 식과 같이 최대값 0을 내보내지만,

     

    - D가 원본 데이터에 대해서 0(가짜 데이터)라고 판단하는 경우

    $$\mathbb{E}_{x \sim P_{data}(x)}[logD(x)] = \mathbb{E}_{x \sim P_{data}(x)}[log(0)] = -\infty$$

    위의 식과 같이 $-\infty$가 나오게 되므로,

    ==> D는 원본 데이터에 대해 1(원본 데이터)라고 판단하도록 학습된다.

     

    - D가 fake 데이터에 대해서 1(원본 데이터)라고 판단하는 경우

    $$\mathbb{E}_{z \sim P_{data}(z)}[log(1-D(G(z)))] = \mathbb{E}_{z \sim P_{data}(z)}[log(1- 1)] = -\infty$$

    위의 식과 같이 $-\infty$가 나오게 되고,

     

    - D가 fake 데이터에 대해서 0(가짜 데이터)라고 판단하는 경우

    $$\mathbb{E}_{z \sim P_{data}(z)}[log(1-D(G(z)))] = \mathbb{E}_{z \sim P_{data}(z)}[log(1 -0)] = 0$$

    위의 식과 같이 최대값 0을 내보낸다.

    ==> D는 가짜 데이터에 대해 0(가짜 데이터)라고 판단하도록 학습된다.

     

    2. G는 V(D,G)가 작아지도록 학습:

    V(D,G)에서 G에 관련된 식은

    $$\mathbb{E}_{z \sim P_{data}(z)}[log(1-D(G(z)))]$$

    위의 식뿐이며, 이를 작아지게 하기 위해서는 D가 G(z)(생성된 데이터)를 진짜라고 생각하게 만들어야한다.

    ==> G는 D(G(z))가 1(진짜 데이터)로 판단하도록 학습된다.

     

     

    학습중에 inner loop에서 D를 최적화하는 것은 많은 계산을 필요로하고 overfittin을 초래할 수 있으므로,

    k step만큼 D를 최적화하고, 1 step만큼 G를 최적화하는 방식을 사용한다.

    => D가 optimal solution 근처에 있게하면서, G가 적당한 속도로 바뀌게 도와준다.

     

    학습 초기에는 G의 성능이 형편없으므로, D가 G가 만든 데이터들을 매우 잘 분리해낸다.

    이로 인해, $log(1-D(G(z)))$가 포화되어버린다.

    (포화된 상태(값이 극단으로 몰리는 상태)에서는 $log(1-D(G(z)))$의 값이 매우 작으므로 gradient가 잘 흐르지 않는다)

    G가 $log(1-D(G(z)))$을 minimize하는 것을  대신하여 G가 D(G(z))를 maximize하는 것으로 대체하면, 

    G의 성능이 좋지 않은 초반에 학습을 빠르게 할 수 있다.

     

    1. iteration마다 k step의 D 최적화(ascending gradient)를 진행하고, 1 step의 G 최적화(descending gradient)를 진행한다.

    2 . noise sample의 배치 m개와 실제 sample의 배치 m개로 gradient를 계산하여 D를 최적화한다

    3. noise sample의 배치 m개로 gradient를 구하여 G를 최적화한다.

    전체적으로는 이런 느낌

    GAN은 G와 D model이 동시에 학습된다.

    파란 선은 D model의 분포, 초록선은 G 모델의 분포,

    점선은 $P_{data}$ 실제 데이터의 분포, z는 random noise(여기서는 uniform distribution)를 의미한다.

     

    (a): 학습초기에는 G model에서 생성한 분포와 실제 데이터의 분포가 다르고, D model의 분류 성능도 나쁨

    (b): D model의 분류 성능이 안정화되어 real과 fake를 잘 구분한다.

    (c): 어느정도 D의 학습이 이루어지면 , D가 real과 fake를 잘 구분하므로 G(z)를 실제 분포와 비슷하게 학습되도록 도와주는 이정표가 되어준다.

    (d): 이 과정의 반복으로, G는 실제 분포와 완전히 비슷한 분포를 가지게 되고, 이젠 D가 이 둘을 구분할 수 없게 되어 확률이 1/2이 된다.

     

    4. Theoretical Results

    앞서 제시되었던 GAN의 minmax problem이 제대로 working을 한다면,

    minmax problem이 global minimum에서 unique solution을 가지고 그 solution으로 수렴한다는 사실이 증명되어야 한다.

     

    1. Global Optimality가 유일하게 존재하는가

     

    proposition(명제 -1): $D^*_G(x) = \frac{P_{data}(x)}{P_{data}(x) + P_g(x)}$

    => v(G,D)가 최대가 되는 $D_G(x)$ 값 : $D^*_G(x) = \frac{P_{data}(x)}{P_{data}(x) + P_g(x)}$

     

    proposition(명제 -2): Global optimum point is $p_g = p_{data}$

    ==> $D_G(x)$가 $D^*_G(x)$일 때, v(G,D)가 최소가 되는 G의 분포:  $p_g = p_{data}$

     

     

    2. 알고리즘이 수렴하는가

    수학적 개념을 아직 이해해지 못하겠다 sup(상한), subderivatives(하위미분) 등..

     

    => V(G,D)가 $p_g$에 대해 convex하므로 $p_g$를 조금씩 그러나 충분할 정도로 업데이트 하다보면 $p_g$는 $p_x$에 수렴하게 된다는 뜻이란다..

    + 현재는 옳지 않은 것으로 증명됨 (deep learning에서는 임계점이 여러개이다)

     

    위의 증명들과 같이 GAN은 이론적으로는 수렴하지만, 실제적으로는 $P_g$를 최적화하는 것이 아니라 $\theta _g$를 최적화하는 과정을 거치는 것이므로, 제한된 $P_g$ 분포군으로 인한 오류(학습이 안되는 경우가 생김)가 발생한다.

     

    5. Experiments

    MNIST, Toronto Face Database(TFD), CIFAR-10에 대해 학습 진행

    G는 rectifier linear activations, sigmoid 혼합하여 사용, D는 maxout activation, Dropout사용

    * maxout activation: ReLU의 일반화 활성화함수 각 뉴런에 최적화된 활성 함수를 학습을 통해 찾아낸다.

    generator 첫 layer에서만 noise를 사용

    G로 생성된 sample에 Gaussian Parzen window맞추고, 해당 분포에 따른 log-likelihood를 알려줌으로써 Pg에 따른 test set data 추정

     

    => G가 생성해낸 sample이 이전 방법보다 월등히 좋다고 할 수는 없지만, 다른 모델들과 겨룰만하다.(잠재력이 있다고 강조)

     

    6.  Advantages and disadvantages

    단점:

    • $P_g(x)$가 명시적으로 존재하지 않는다
    • D와 G가 균형을 잘 맞춰서 학습되어야한다.(D가 업데이트 되기 이전에 G가 많이 학습되어버리면 G가 z의 데이터를 너무 많이 붕괴시켜버린다
    • == 데이터 공간의 작은 영역에 너무 많은 G 확률 밀도가 배치된다. == model collapse(모델 붕괴)
    • == 데이터가 D를 속이기 위해서 비슷한 이미지만 생성한다. ex. 6이랑 비슷한 이미지만 생성)

    장점:

    • Markov chains이 전혀 필요 없고 gradients를 얻기 위해 back-propagation만이 사용됨
    • 학습 중 어떠한 inference가 필요 없음
    • 다양한 함수들이 모델이 접목될 수 있음
    • Markov chains을 쓸 때보다 훨씬 선명한 이미지를 얻을 수 있음
     
    맨 오른쪽 열 빼고는 모두 GAN으로 생성한 이미지이다.

    7.  Conclusions and future work

    • conditional generative model로 발전시킬 수 있음 (CGAN)
    • Learned approximate inference는 주어진 x를 예측하여 수행될 수 있음
    • parameters를 공유하는 conditionals model를 학습함으로써 다른 conditionals models을 근사적으로 모델링할 수 있음. 특히 MP-DBM의 stochastic extension의 구현 에 대부분의 네트워크를 사용할 수 있음
    • Semi-supervised learning: 제한된 레이블이 있는 데이터 사용할 수 있을 때, classifiers의 성능 향상시킬 수 있음
    • 효율성 개선: G,D를 조정하는 더 나은 방법이나 학습하는 동안 sample z에 대한 더 나은 분포를 결정함으로써 학습의 속도 높일 수 있음

     
     
     

    모델 구현 - MNIST

    import torch
    import torch.nn as nn
    import torchvision
    from torchvision import datasets, transforms
    from torchvision.utils import save_image

    모델 정의

    # Define Generator
    latent_dim = 100 # noise 차원
    
    class Generator(nn.Module):
      def __init__(self):
        super(Generator, self).__init__()
    
        # 하나의 블록 정의
        def block(input_dim, output_dim, normalize = True):
          layers = []
          layers.append(nn.Linear(input_dim, output_dim))
    
          if normalize:
            # 배치 정규화
            layers.append(nn.BatchNorm1d(output_dim, 0.8))
          layers.append(nn.LeakyReLU(0.2, inplace = True))
          return layers
    
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize = False), # 처음엔 왜 BN을 적용하지 않았나? 
            *block(128, 256),
            *block(256, 518),
            *block(518, 1024),
            nn.Linear(1024,1*28*28),
            nn.Tanh() # -1~1 사이의 값을 출력해준다, normalize 처리한 이미지 상정
        )
    
      def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1,28,28)
        return img
    # Defien Discriminator
    class Discriminator(nn.Module):
      def __init__(self):
        super(Discriminator, self).__init__()
    
        self.model = nn.Sequential(
            nn.Linear(784,512), # 입력은 이미지 크기
            nn.LeakyReLU(0.2, inplace = True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Linear(256,1),
            nn.Sigmoid() # 시그모이드로 0,1만 판별
        )
    
      def forward(self, img):
        flatten = img.view(img.size(0),-1)
        output = self.model(flatten)
        return output

    데이터셋 & 하이퍼 파라미터 설정

    # 데아터셋 불러오기
    transforms_train = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize([0.5],[0.5])
    ])
    
    train_dataset = datasets.MNIST(root = "./dataset", train = True, download = True, transform = transforms_train)
    dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = 128, shuffle = True, num_workers = 4)
    # 모델 초기화 & 하이퍼 파라미터 설정
    generator = Generator().cuda()
    discriminator = Discriminator().cuda()
    
    adversarial_loss = nn.BCELoss().cuda() # 바이너리 Cross-Entropy() # 0인지 1인지만 판단
    
    lr = 0.0002
    
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999)) # G에 대한 optimizer
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999)) # D에 대한 optimizer
    # 텐서보드
    from torch.utils.tensorboard import SummaryWriter
    
    # 기본 `log_dir` 은 "runs"이며, 여기서는 더 구체적으로 지정하였습니다
    writer = SummaryWriter() # writer 객체 생성
    %load_ext tensorboard
    %tensorboard --logdir ./runs

    훈련 및 평가

    import matplotlib.pyplot as plt
    # 모델 학습
    import time
    
    n_epochs = 100 # epoch 설정
    sample_interval = 2000 # 몇번의 배치마다 결과를 출력한 것인지 설정
    start_time = time.time()
    
    for epoch in range(n_epochs):
      for i, (imgs, _) in enumerate(dataloader):
    
        # 정답 레이블 생성
        real = torch.cuda.FloatTensor(imgs.size(0), 1).fill_(1.0) # 진짜(real): 1
        fake = torch.cuda.FloatTensor(imgs.size(0), 1).fill_(0.0) # 가짜(fake): 0
    
        real_imgs = imgs.cuda()
    
        # 판별자 D 학습 - k번 반복
        for _ in range(2):
          z = torch.normal(mean=0, std=1, size=(imgs.shape[0], latent_dim)).cuda() # latent
    
          # 이미지 생성
          generated_imgs = generator(z)
          optimizer_D.zero_grad()
    
          # loss 계산
          real_loss = adversarial_loss(discriminator(real_imgs), real)
          fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
          d_loss = (real_loss + fake_loss) / 2
    
          # 판별자(discriminator) 업데이트
          d_loss.backward()
          optimizer_D.step()
    
        # 생성자 G 학습
        optimizer_G.zero_grad()
    
        z = torch.normal(mean=0, std=1, size=(imgs.shape[0], latent_dim)).cuda() # latent
        # 이미지 생성
        generated_imgs = generator(z)
    
        # 생성자(generator)의 손실(loss) 값 계산
        g_loss = adversarial_loss(discriminator(generated_imgs), real)
    
        # 생성자(generator) 업데이트
        g_loss.backward()
        optimizer_G.step()
    
        done = epoch * len(dataloader) + i
        if done % sample_interval == 0:
          imgs = generated_imgs.data[:25]
          save_image(imgs, f"{done}.png", nrow=5, normalize=True)
          writer.add_scalar("D loss", d_loss.item(), done)
          writer.add_scalar("G loss", g_loss.item(), done)
          img_grid = torchvision.utils.make_grid(imgs, normalize = True, nrow = 5)
          writer.add_image('generate img',img_grid, done)
    
          print(f"[Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}] [Elapsed time: {time.time() - start_time:.2f}s]")
    
    writer.close()

    학습: 결과

    [Epoch 0/100] [D loss: 0.307546] [G loss: 1.636394] [Elapsed time: 0.36s]
    [Epoch 4/100] [D loss: 0.274413] [G loss: 1.734445] [Elapsed time: 66.59s]
    [Epoch 8/100] [D loss: 0.369117] [G loss: 1.932671] [Elapsed time: 133.34s]
    [Epoch 12/100] [D loss: 0.324094] [G loss: 1.615475] [Elapsed time: 198.76s]
    [Epoch 17/100] [D loss: 0.370909] [G loss: 1.750932] [Elapsed time: 265.40s]
    [Epoch 21/100] [D loss: 0.340882] [G loss: 1.804203] [Elapsed time: 330.37s]
    [Epoch 25/100] [D loss: 0.357644] [G loss: 1.634121] [Elapsed time: 396.51s]
    [Epoch 29/100] [D loss: 0.370007] [G loss: 1.578714] [Elapsed time: 461.47s]
    [Epoch 34/100] [D loss: 0.355696] [G loss: 2.019880] [Elapsed time: 527.81s]
    [Epoch 38/100] [D loss: 0.332385] [G loss: 1.880666] [Elapsed time: 592.50s]
    [Epoch 42/100] [D loss: 0.326968] [G loss: 1.743924] [Elapsed time: 657.95s]
    [Epoch 46/100] [D loss: 0.329264] [G loss: 1.478282] [Elapsed time: 722.18s]
    [Epoch 51/100] [D loss: 0.295945] [G loss: 1.718297] [Elapsed time: 787.91s]
    [Epoch 55/100] [D loss: 0.337251] [G loss: 1.600780] [Elapsed time: 853.72s]
    [Epoch 59/100] [D loss: 0.317490] [G loss: 1.776057] [Elapsed time: 918.59s]
    [Epoch 63/100] [D loss: 0.306110] [G loss: 1.730978] [Elapsed time: 983.94s]
    [Epoch 68/100] [D loss: 0.313961] [G loss: 1.644931] [Elapsed time: 1048.82s]
    [Epoch 72/100] [D loss: 0.291106] [G loss: 1.776134] [Elapsed time: 1113.78s]
    [Epoch 76/100] [D loss: 0.289232] [G loss: 1.865839] [Elapsed time: 1177.25s]
    [Epoch 81/100] [D loss: 0.299694] [G loss: 1.979094] [Elapsed time: 1242.24s]
    [Epoch 85/100] [D loss: 0.319803] [G loss: 2.103588] [Elapsed time: 1305.89s]
    [Epoch 89/100] [D loss: 0.310971] [G loss: 2.008342] [Elapsed time: 1370.43s]
    [Epoch 93/100] [D loss: 0.351985] [G loss: 1.899914] [Elapsed time: 1434.00s]
    [Epoch 98/100] [D loss: 0.322940] [G loss: 2.039061] [Elapsed time: 1498.74s]

    생성된 데이터 예시

     

    Comments