GAN(Generative Adversarial Network) 기초
in Data on Deep-Learning
최근 핫한 Generative Adversarial Network, GAN을 학습한 내용을 정리한 문서입니다
Introduction
Supervised Learning
- Discriminative Model : Input이 들어갔을 때 Output이 나옴
Unsupervised Learning
- Generative Model : train 데이터 분포를 학습
- 파란색이 실제 데이터의 분포고 빨간색은 실제 데이터의 분포를 근사! 두개의 차이값을 줄이는 것이 목표
Generative Adversarial Network
- Discriminator를 먼저 학습 : 진짜 이미지가 들어가면 진짜로 구분, 가짜 이미지는 가짜로 구분
- input : 이미지의 고정된 벡터
- output : 진짜 / 가짜 : 1 (sigmoid를 통해 0.5 기준으로 classification)
- generator는 랜덤한 코드를 받아서 이미지를 생성 -> 그리고 discriminator를 속여야 함(1이 나오도록 학습)
discriminator object 함수(loss)
- discriminator는 목적함수를 최대화하는 것이 목표
x~p_data(x)
: 확률 밀도함수z
: 랜덤한 벡터(표준 정규분포나 100차원 벡터로 샘플링)을 g에게 줬을 경우 이미지를 생성- 처음에 형편없는 이미지를 만들기 때문에 discriminator는 가짜라고 확신을 함. 그러면 D(G(z))가 0에 가깝게 나옴(= 기울기의 절대값이 엄청 작음) 기울기를 크게 하기 위해 log(x)그래프를 max!!! => 기울기가 무한대
- 초반에 generator가 굉장히 안좋은 상황에서discriminator가 가짜라고 확신하는 상황을 빨리 벗어나기 위해 이런 트릭을 사용
generator 목적 함수
- generator는 우측의 식을 최소화
코드
import torch
import torch.nn. as nn
D = nn.Sequential(
nn.Linear(784 ,128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid())
G = nn.Sequential(
nn.Linear(100, 128),
nn.ReLU(),
nn.Linear(128, 784),
nn.Tanh()) # 생성된 값이 -1 ~ 1
criterion = nn.BCELoss() # Binary Cross Entropy Loss(h(x), y), Sigmoid Cross Entropy Loss 함수라고도 불림. -ylogh(x)-(1-y)log(1-h(x))
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.01)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.01)
# 충돌하기에 2개의 optimizer를 설정
while True:
# train D
loss = criterion(D(x), 1) + criterion(D(G(z)), 0)
loss.backward() # 모든 weight에 대해 gradient값을 계산
d_optimizer.step()
# train G
loss = criterion(D(G(z)), 1)
loss.backward()
g_optimizer.step() # generator의 파라미터를 학습
왜 이론적으로 잘 되는가?
- 최적화하는 것이 서로 다른 확률분포간의 차이를 줄여주기 때문에 실제 Generator가 실제와 가까운 이미지를 만들 수 있다!
Variants of GAN
Deep Convolutional GAN (DCGAN, 2015)
- CNN을 사용해서 discriminator를 생성하고 deconvolutional network를 통해서 generator를 만든 모델
- 아직까지도 가장 선호되면서 간단히 만들 수 있는 모델
- 핵심 : Pooling Layer를 사용하지 않음! (사용하면 unpooling할 때 blocking한 이미지를 생성) 대신 Stride가 2 이상인 convolution과 deconvolution을 사용함
- Adam Optimizer의 모멘텀 텀이 0.5, 0.999 2개가 있음. 64x64를 생성할 때 저 파라미터를 사용하면 성능이 좋아서 이렇게 사용중
- Generator에 들어가는 latent vector를 통해 연산을 할 수 있음(Word2vec같이!) : Z vector가 선형적 관계를 가짐
Least Squares GAN (LSGAN)
- 기존엔 discriminator를 속이기만 하면 됨.
- discirminator를 완벽히 속인 친구들. 좋진 않음! 우리의 목적은 진짜 데이터와 가까워야 함
D = nn.Sequential(
nn.Linear(784 ,128),
nn.ReLU(),
nn.Linear(128, 1)) # sigmoid를 없앰
G = nn.Sequential(
nn.Linear(100, 128),
nn.ReLU(),
nn.Linear(128, 784),
nn.Tanh())
D_loss = torch.mean((D(x) -1)**2) + torch.mean(D(G(z))**2)
G_loss = torch.mean((D(G(z))-1)**2)
# cross entropy loss와 차이 : decision boundary가 1에 가깝도록 만듬
Semi-Supervised GAN
- discriminator가 진짜, 가짜를 구분하지 않고 클래스를 구분하게 됨. 기존 10개의 클래스 + fake
- 위쪽은 discriminator쪽은 Supervised Learning, generator는 Unsupervised Learning
Auxiliary Classifier GAN(ACGAN, 2016)
- discriminator가 하는 일이 2가지
- 진짜를 구분 (sigmoid)
- 진짜든 가짜든 클래스를 구분 (softmax)
- Multi-task learning
- generator도 하는 일이 2가지
- 진짜, 가짜를 구분
- 클래스 구분
- 여태는 generator에 집중했던 GAN들이 많다면 discriminator에 집중!
Extensions of GAN
CycleGAN : Unpaired Image to Image Translation
- 얼룩말을 말로 바꾸고 여름을 겨울로 바꾸는 이미지의 스타일을 바꾸는 것
- pair example이 없이 unsupervised learning을 통해 모델을 학습
- generator는 latent vector를 받지 않고 이미지를 input으로 받음. 인코더-디코더같은 느낌
- discriminator는 얼룩말 이미지를 말 이미지로 바꾸고 싶음. 말 이미지를 주고 이 이미지가 진짜다라고 학습을 하고 generator는 얼룩말을 받고 말로 바꿈
- 얼룩말 사진을 주고 말을 뛰는 모습을 생성하면 속을 수 있음! style transfer는 모양을 유지
- 얼룩말 이미지를 다시 말로 바꿈!!
StackGAN
- 텍스트를 주고 텍스트에 해당하는 이미지를 만듬
- Task에 집중하지 않고
- 한번에 고해상도를 학습할 경우 힘들다는 한계가 존재함. 64x64를 먼저 만들고 업샘플링 과정을 겪음
Visual Attribute Transfer
- CycleGAN처럼 Input이 왼쪽, 오른쪽 2개로 들어감
User-Interactive Image Colorization
- 흑백 사진을 칼라로 변환해줌(사용자가 원하는 색으로 변경 가능)
Future of GAN
Boundary Equilibrium GAN (BEGAN)
- 해당 로스함수가 줄어들 때 학습이 잘되더라(하지만 이것은 휴리스틱하게 나온 결과)
- discriminator auto encoder 구조라 복잡한 편
Reconstruction Loss
- weight normalization을 사용
- 단점 : 로스를 얻기 위해 z값을 학습해야함
Deconvolution Checkboard Artifacts
- 좋은 upsampling을 찾아야 하지 않을까-
- Deconvolution을 많이 사용했었는데 이건 output을 불균형하게 생김(체크보드 패턴처럼 이미지가 생김)
- Resize-Convolution을 사용하면 up sampling은 룰베이스 방식으로 한 후, convolution stride=1을 하고 필터링을 여러 레이어로 쌓게 됨! - 골고루 고려할 수 있음
- BEGAN에선 resize-convlution을 사용했음
Machine Translation (Seq2Seq)
- Supervised learning은 영어로 주면 한글의 한 문장을 나옴. 이게 Supervised learning의 한계라 생각하고 GAN을 활용
Reference
카일스쿨 유튜브 채널을 만들었습니다. 데이터 사이언스, 성장, 리더십, BigQuery 등을 이야기할 예정이니, 관심 있으시면 구독 부탁드립니다 :)
PM을 위한 데이터 리터러시 강의를 만들었습니다. 문제 정의, 지표, 실험 설계, 문화 만들기, 로그 설계, 회고 등을 담은 강의입니다
이 글이 도움이 되셨거나 다양한 의견이 있다면 댓글 부탁드립니다 :)