GAN(Generative Adversarial Network) 기초


최근 핫한 Generative Adversarial Network, GAN을 학습한 내용을 정리한 문서입니다

Introduction

Supervised Learning

  • Discriminative Model : Input이 들어갔을 때 Output이 나옴

Unsupervised Learning

  • Generative Model : train 데이터 분포를 학습

  • 파란색이 실제 데이터의 분포고 빨간색은 실제 데이터의 분포를 근사! 두개의 차이값을 줄이는 것이 목표

Generative Adversarial Network

  • Discriminator를 먼저 학습 : 진짜 이미지가 들어가면 진짜로 구분, 가짜 이미지는 가짜로 구분
    • input : 이미지의 고정된 벡터
    • output : 진짜 / 가짜 : 1 (sigmoid를 통해 0.5 기준으로 classification)
  • generator는 랜덤한 코드를 받아서 이미지를 생성 -> 그리고 discriminator를 속여야 함(1이 나오도록 학습)

discriminator object 함수(loss)

  • discriminator는 목적함수를 최대화하는 것이 목표
  • x~p_data(x) : 확률 밀도함수
  • z : 랜덤한 벡터(표준 정규분포나 100차원 벡터로 샘플링)을 g에게 줬을 경우 이미지를 생성
  • 처음에 형편없는 이미지를 만들기 때문에 discriminator는 가짜라고 확신을 함. 그러면 D(G(z))가 0에 가깝게 나옴(= 기울기의 절대값이 엄청 작음) 기울기를 크게 하기 위해 log(x)그래프를 max!!! => 기울기가 무한대
  • 초반에 generator가 굉장히 안좋은 상황에서discriminator가 가짜라고 확신하는 상황을 빨리 벗어나기 위해 이런 트릭을 사용

generator 목적 함수

  • generator는 우측의 식을 최소화

코드

import torch
import torch.nn. as nn

D = nn.Sequential(
	nn.Linear(784 ,128),
	nn.ReLU(),
	nn.Linear(128, 1),
	nn.Sigmoid())
	
G = nn.Sequential(
	nn.Linear(100, 128),
	nn.ReLU(),
	nn.Linear(128, 784),
	nn.Tanh()) # 생성된 값이 -1 ~ 1
	
criterion = nn.BCELoss() # Binary Cross Entropy Loss(h(x), y), Sigmoid Cross Entropy Loss 함수라고도 불림. -ylogh(x)-(1-y)log(1-h(x))

d_optimizer = torch.optim.Adam(D.parameters(), lr=0.01)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.01)
# 충돌하기에 2개의 optimizer를 설정

while True:
	# train D
	loss = criterion(D(x), 1) + criterion(D(G(z)), 0)
	loss.backward() # 모든 weight에 대해 gradient값을 계산
	d_optimizer.step()
	
	# train G
	loss = criterion(D(G(z)), 1)
	loss.backward()
	g_optimizer.step() # generator의 파라미터를 학습

왜 이론적으로 잘 되는가?

  • 최적화하는 것이 서로 다른 확률분포간의 차이를 줄여주기 때문에 실제 Generator가 실제와 가까운 이미지를 만들 수 있다!

Variants of GAN

Deep Convolutional GAN (DCGAN, 2015)

  • CNN을 사용해서 discriminator를 생성하고 deconvolutional network를 통해서 generator를 만든 모델
  • 아직까지도 가장 선호되면서 간단히 만들 수 있는 모델
  • 핵심 : Pooling Layer를 사용하지 않음! (사용하면 unpooling할 때 blocking한 이미지를 생성) 대신 Stride가 2 이상인 convolution과 deconvolution을 사용함
  • Adam Optimizer의 모멘텀 텀이 0.5, 0.999 2개가 있음. 64x64를 생성할 때 저 파라미터를 사용하면 성능이 좋아서 이렇게 사용중
  • Generator에 들어가는 latent vector를 통해 연산을 할 수 있음(Word2vec같이!) : Z vector가 선형적 관계를 가짐

Least Squares GAN (LSGAN)

  • 기존엔 discriminator를 속이기만 하면 됨.
  • discirminator를 완벽히 속인 친구들. 좋진 않음! 우리의 목적은 진짜 데이터와 가까워야 함
D = nn.Sequential(
	nn.Linear(784 ,128),
	nn.ReLU(),
	nn.Linear(128, 1)) # sigmoid를 없앰
	
G = nn.Sequential(
	nn.Linear(100, 128),
	nn.ReLU(),
	nn.Linear(128, 784),
	nn.Tanh()) 
	
D_loss = torch.mean((D(x) -1)**2) + torch.mean(D(G(z))**2)

G_loss = torch.mean((D(G(z))-1)**2)
# cross entropy loss와 차이 : decision boundary가 1에 가깝도록 만듬	

Semi-Supervised GAN

  • discriminator가 진짜, 가짜를 구분하지 않고 클래스를 구분하게 됨. 기존 10개의 클래스 + fake
  • 위쪽은 discriminator쪽은 Supervised Learning, generator는 Unsupervised Learning

Auxiliary Classifier GAN(ACGAN, 2016)

  • discriminator가 하는 일이 2가지
    • 진짜를 구분 (sigmoid)
    • 진짜든 가짜든 클래스를 구분 (softmax)
    • Multi-task learning
  • generator도 하는 일이 2가지
    • 진짜, 가짜를 구분
    • 클래스 구분
  • 여태는 generator에 집중했던 GAN들이 많다면 discriminator에 집중!

Extensions of GAN

CycleGAN : Unpaired Image to Image Translation

  • 얼룩말을 말로 바꾸고 여름을 겨울로 바꾸는 이미지의 스타일을 바꾸는 것
  • pair example이 없이 unsupervised learning을 통해 모델을 학습
  • generator는 latent vector를 받지 않고 이미지를 input으로 받음. 인코더-디코더같은 느낌
  • discriminator는 얼룩말 이미지를 말 이미지로 바꾸고 싶음. 말 이미지를 주고 이 이미지가 진짜다라고 학습을 하고 generator는 얼룩말을 받고 말로 바꿈
  • 얼룩말 사진을 주고 말을 뛰는 모습을 생성하면 속을 수 있음! style transfer는 모양을 유지
  • 얼룩말 이미지를 다시 말로 바꿈!!

StackGAN

  • 텍스트를 주고 텍스트에 해당하는 이미지를 만듬
  • Task에 집중하지 않고
  • 한번에 고해상도를 학습할 경우 힘들다는 한계가 존재함. 64x64를 먼저 만들고 업샘플링 과정을 겪음

Visual Attribute Transfer

  • CycleGAN처럼 Input이 왼쪽, 오른쪽 2개로 들어감

User-Interactive Image Colorization

  • 흑백 사진을 칼라로 변환해줌(사용자가 원하는 색으로 변경 가능)

Future of GAN

Boundary Equilibrium GAN (BEGAN)

  • 해당 로스함수가 줄어들 때 학습이 잘되더라(하지만 이것은 휴리스틱하게 나온 결과)
  • discriminator auto encoder 구조라 복잡한 편

Reconstruction Loss

  • weight normalization을 사용
  • 단점 : 로스를 얻기 위해 z값을 학습해야함

Deconvolution Checkboard Artifacts

  • 좋은 upsampling을 찾아야 하지 않을까-
  • Deconvolution을 많이 사용했었는데 이건 output을 불균형하게 생김(체크보드 패턴처럼 이미지가 생김)
  • Resize-Convolution을 사용하면 up sampling은 룰베이스 방식으로 한 후, convolution stride=1을 하고 필터링을 여러 레이어로 쌓게 됨! - 골고루 고려할 수 있음
  • BEGAN에선 resize-convlution을 사용했음

Machine Translation (Seq2Seq)

  • Supervised learning은 영어로 주면 한글의 한 문장을 나옴. 이게 Supervised learning의 한계라 생각하고 GAN을 활용

Reference


카일스쿨 유튜브 채널을 만들었습니다. 데이터 사이언스, 성장, 리더십, BigQuery 등을 이야기할 예정이니, 관심 있으시면 구독 부탁드립니다 :)

PM을 위한 데이터 리터러시 강의를 만들었습니다. 문제 정의, 지표, 실험 설계, 문화 만들기, 로그 설계, 회고 등을 담은 강의입니다

이 글이 도움이 되셨거나 다양한 의견이 있다면 댓글 부탁드립니다 :)

Buy me a coffeeBuy me a coffee





© 2017. by Seongyun Byeon

Powered by zzsza