GAN(Generative Adversarial Network) 기초
in Data on Deep-Learning
최근 핫한 Generative Adversarial Network, GAN을 학습한 내용을 정리한 문서입니다
Introduction
Supervised Learning
- Discriminative Model : Input이 들어갔을 때 Output이 나옴
Unsupervised Learning
- Generative Model : train 데이터 분포를 학습
- 파란색이 실제 데이터의 분포고 빨간색은 실제 데이터의 분포를 근사! 두개의 차이값을 줄이는 것이 목표
Generative Adversarial Network
- Discriminator를 먼저 학습 : 진짜 이미지가 들어가면 진짜로 구분, 가짜 이미지는 가짜로 구분
- input : 이미지의 고정된 벡터
- output : 진짜 / 가짜 : 1 (sigmoid를 통해 0.5 기준으로 classification)
- generator는 랜덤한 코드를 받아서 이미지를 생성 -> 그리고 discriminator를 속여야 함(1이 나오도록 학습)
discriminator object 함수(loss)
- discriminator는 목적함수를 최대화하는 것이 목표
x~p_data(x)
: 확률 밀도함수z
: 랜덤한 벡터(표준 정규분포나 100차원 벡터로 샘플링)을 g에게 줬을 경우 이미지를 생성- 처음에 형편없는 이미지를 만들기 때문에 discriminator는 가짜라고 확신을 함. 그러면 D(G(z))가 0에 가깝게 나옴(= 기울기의 절대값이 엄청 작음) 기울기를 크게 하기 위해 log(x)그래프를 max!!! => 기울기가 무한대
- 초반에 generator가 굉장히 안좋은 상황에서discriminator가 가짜라고 확신하는 상황을 빨리 벗어나기 위해 이런 트릭을 사용
generator 목적 함수
- generator는 우측의 식을 최소화
코드
import torch
import torch.nn. as nn
D = nn.Sequential(
nn.Linear(784 ,128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid())
G = nn.Sequential(
nn.Linear(100, 128),
nn.ReLU(),
nn.Linear(128, 784),
nn.Tanh()) # 생성된 값이 -1 ~ 1
criterion = nn.BCELoss() # Binary Cross Entropy Loss(h(x), y), Sigmoid Cross Entropy Loss 함수라고도 불림. -ylogh(x)-(1-y)log(1-h(x))
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.01)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.01)
# 충돌하기에 2개의 optimizer를 설정
while True:
# train D
loss = criterion(D(x), 1) + criterion(D(G(z)), 0)
loss.backward() # 모든 weight에 대해 gradient값을 계산
d_optimizer.step()
# train G
loss = criterion(D(G(z)), 1)
loss.backward()
g_optimizer.step() # generator의 파라미터를 학습
왜 이론적으로 잘 되는가?
- 최적화하는 것이 서로 다른 확률분포간의 차이를 줄여주기 때문에 실제 Generator가 실제와 가까운 이미지를 만들 수 있다!
Variants of GAN
Deep Convolutional GAN (DCGAN, 2015)
- CNN을 사용해서 discriminator를 생성하고 deconvolutional network를 통해서 generator를 만든 모델
- 아직까지도 가장 선호되면서 간단히 만들 수 있는 모델
- 핵심 : Pooling Layer를 사용하지 않음! (사용하면 unpooling할 때 blocking한 이미지를 생성) 대신 Stride가 2 이상인 convolution과 deconvolution을 사용함
- Adam Optimizer의 모멘텀 텀이 0.5, 0.999 2개가 있음. 64x64를 생성할 때 저 파라미터를 사용하면 성능이 좋아서 이렇게 사용중
- Generator에 들어가는 latent vector를 통해 연산을 할 수 있음(Word2vec같이!) : Z vector가 선형적 관계를 가짐
Least Squares GAN (LSGAN)
- 기존엔 discriminator를 속이기만 하면 됨.
- discirminator를 완벽히 속인 친구들. 좋진 않음! 우리의 목적은 진짜 데이터와 가까워야 함
D = nn.Sequential(
nn.Linear(784 ,128),
nn.ReLU(),
nn.Linear(128, 1)) # sigmoid를 없앰
G = nn.Sequential(
nn.Linear(100, 128),
nn.ReLU(),
nn.Linear(128, 784),
nn.Tanh())
D_loss = torch.mean((D(x) -1)**2) + torch.mean(D(G(z))**2)
G_loss = torch.mean((D(G(z))-1)**2)
# cross entropy loss와 차이 : decision boundary가 1에 가깝도록 만듬
Semi-Supervised GAN
- discriminator가 진짜, 가짜를 구분하지 않고 클래스를 구분하게 됨. 기존 10개의 클래스 + fake
- 위쪽은 discriminator쪽은 Supervised Learning, generator는 Unsupervised Learning
Auxiliary Classifier GAN(ACGAN, 2016)
- discriminator가 하는 일이 2가지
- 진짜를 구분 (sigmoid)
- 진짜든 가짜든 클래스를 구분 (softmax)
- Multi-task learning
- generator도 하는 일이 2가지
- 진짜, 가짜를 구분
- 클래스 구분
- 여태는 generator에 집중했던 GAN들이 많다면 discriminator에 집중!
Extensions of GAN
CycleGAN : Unpaired Image to Image Translation
- 얼룩말을 말로 바꾸고 여름을 겨울로 바꾸는 이미지의 스타일을 바꾸는 것
- pair example이 없이 unsupervised learning을 통해 모델을 학습
- generator는 latent vector를 받지 않고 이미지를 input으로 받음. 인코더-디코더같은 느낌
- discriminator는 얼룩말 이미지를 말 이미지로 바꾸고 싶음. 말 이미지를 주고 이 이미지가 진짜다라고 학습을 하고 generator는 얼룩말을 받고 말로 바꿈
- 얼룩말 사진을 주고 말을 뛰는 모습을 생성하면 속을 수 있음! style transfer는 모양을 유지
- 얼룩말 이미지를 다시 말로 바꿈!!
StackGAN
- 텍스트를 주고 텍스트에 해당하는 이미지를 만듬
- Task에 집중하지 않고
- 한번에 고해상도를 학습할 경우 힘들다는 한계가 존재함. 64x64를 먼저 만들고 업샘플링 과정을 겪음
Visual Attribute Transfer
- CycleGAN처럼 Input이 왼쪽, 오른쪽 2개로 들어감
User-Interactive Image Colorization
- 흑백 사진을 칼라로 변환해줌(사용자가 원하는 색으로 변경 가능)
Future of GAN
Boundary Equilibrium GAN (BEGAN)
- 해당 로스함수가 줄어들 때 학습이 잘되더라(하지만 이것은 휴리스틱하게 나온 결과)
- discriminator auto encoder 구조라 복잡한 편
Reconstruction Loss
- weight normalization을 사용
- 단점 : 로스를 얻기 위해 z값을 학습해야함
Deconvolution Checkboard Artifacts
- 좋은 upsampling을 찾아야 하지 않을까-
- Deconvolution을 많이 사용했었는데 이건 output을 불균형하게 생김(체크보드 패턴처럼 이미지가 생김)
- Resize-Convolution을 사용하면 up sampling은 룰베이스 방식으로 한 후, convolution stride=1을 하고 필터링을 여러 레이어로 쌓게 됨! - 골고루 고려할 수 있음
- BEGAN에선 resize-convlution을 사용했음
Machine Translation (Seq2Seq)
- Supervised learning은 영어로 주면 한글의 한 문장을 나옴. 이게 Supervised learning의 한계라 생각하고 GAN을 활용
Reference
카일스쿨 유튜브 채널을 만들었습니다. 데이터 분석, 커리어에 대한 내용을 공유드릴 예정입니다.
PM을 위한 데이터 리터러시 강의를 만들었습니다. 문제 정의, 지표, 실험 설계, 문화 만들기, 로그 설계, 회고 등을 담은 강의입니다
이 글이 도움이 되셨거나 의견이 있으시면 댓글 남겨주셔요.