Generative Adversarial Networks
Generative Adversarial Networks
은 두 개 이상의 신경망이 서로를 향하게 하고, 서로 대항하듯이 훈련하게 함으로써, 결과적으로 생성 모델(generative model)
을 산출해낸다.
- GAN의 이점
- 데이터가 한정된 상황에서도 일반화(Generalization)를 할 수 있다.
- 작은 데이터셋을 가지고도 새로운 장면을 생각할 수 있다.
- 모조 데이터(simulated data)를 더욱 진짜처럼 보이게 할 수 있다.
Generative Modeling & Discriminative Modeling
판별 모델링(Discriminative Modeling)
- 그림을 살펴본 다음에 해당 그림의 style을 정하는 일은 무엇인가를 판단하는 일이다.
- 머신러닝에서 판별 모델링을 수행하는 방식은 다음과 같다.
- 데이터 내 각 부분(divisions)을 이해하기 위해 합성곱 계층을 만들거나, 기타 학습된 특징들을 사용하는 머신러닝 모델 구성
- 훈련 집합과 검증 집합이 모두 포함된 데이터셋 수집
- 데이터셋을 통한 머신러닝 모델 훈련
- 머신러닝 모델을 통해 데이터 점이 속하는 class를 예측
- 데이터의 분포(distributions)에 대한 계급 간의
경계 조건(boundary conditions)
을 학습한다.
- 데이터가 많을수록 판별 모델의 성능은 좋아지며, 레이블(label)이 지정된 데이터를 사용해야 한다. 즉,
지도 학습(supervised learning)
이다.
생성 모델링(Generative Modeling)
- style에 대한 지식을 바탕으로 그림의 style을 결정한 다음, 해당 그림을 재현한다.
- 머신러닝에서 생성 모델링을 수행하는 방식은 다음과 같다.
- 다양한 그림의 style을 복제(reproduce)하는 방법을 학습하는 머신러닝 모델부터 작성한다.
- 훈련 데이터셋과 검증 데이터셋을 수집한다.
- 해당 데이터를 통해 머신러닝 모델을 훈련한다.
- 유사도(similarity)라고 하는 metric을 사용해 모델에서 style을 재현하는 기능을 확인한다.
- 주어진 입력의 분포에 대한 class들의 분포를 모델링한다. 해당 분포를 추정하기 위해 각 class에 대한 확률 모델을 만든다.
- 훈련 도중 알아서 레이블을 학습하게 되므로 레이블이 없는 데이터를 사용해도 된다. 즉,
비지도 학습(unsupervised learning)
이다.
위조 지폐 예시
GAN의 아키텍처를 설명하는데 가장 많이 활용되는 것이 위조 지폐의 예시이다. 생성기(Generator)는 위조범이라 할 수 있고, 판별기(Discriminator)는 FBI 요원이라고 할 수 있다. 위조범은 FBI 요원의 검사를 통과할 수 있는 위조 지폐를 만드는 새로운 방법을 끊임없이 모색하고, FBI 요원은 위조 지폐와 진짜 지폐를 최대한 구별하는 것이 목적이다. Generator와 Disriminator의 목표를 정리하면 다음과 같다.
생성기(Generator)
의 목적 : 진짜 같은 가짜 출력을 생성해 내어 판별기가 실수로 진짜와 가짜를 잘못 분류하게 되는 가능성을 극대화한다.판별기(Discriminator)
의 목적 : 판별기가 진짜 이미지와 생성된 이미지를 구별하지 못하는 확률 목표인 0.5를 달성하게 되기까지 최대화한다.
GAN Architecture
GAN 프레임워크에서 생성기는 판별기와 동시에 훈련을 시작해야 하지만, 실제로는 판별기가 이미지를 분류할 수 있어야 하므로, 대체로 훈련을 시작하기 전에 몇 개의 에포크만큼 먼저 판별기부터 훈련해야 한다. 위 그림에서의 아키텍처는 생성기와 판별기, 손실 함수로 구성된다.
Generator Architecture
생성기 아키텍처의 구성 요소에는 잠재 공간, 생성기, 이미지 생성 부분이 있다. 생성기는 다음과 같은 훈련 과정이 완료된 후에 표현하려고 하는 출력(이미지)을 생성한다.
- 생성기는 잠재 공간(latent space, 가상의 데이터로 채워진 벡터 공간)에서 표본(samples)을 추출해 잠재 공간과 출력 간의 관계를 생성하는 것이 역할이다.
- 입력(잠재 공간)에서 출력(대부분 이미지)을 향해 가는 신경망을 만든다. -> 생성기
- 한 모델 안에서 생성기와 판별기가 서로 연결해 적대 모드(adversarial mode)를 취하게 함으로써 생성기를 훈련한다.
- 생성기의 훈련을 끝낸 뒤에는 생성기를 추론에 사용할 수 있다.
이러한 각 부분에 대한 코드의 구조를 클래스로 정의하면 다음과 같다.
class Generator:
def __init__(self):
self.initVariable = 1
def lossFunction(self): # 모델을 훈련할 때 쓸 사용자 정의 손실 함수
return
def buildModel(self): # 신경망 모델 구성
return
def trainModel(self, inputX, inputY):
return
Discriminator Architecture
판별기는 앞서 설명했다싶이 생성기가 '출력한 내용'과 '진짜 이미지'를 놓고 둘 중에 어느게 진짜인지 또는 가짜인지 여부를 결정하는 데 사용된다. 판별기 역할을 담당할 신경망으로는 일반적으로 간단한 아키텍처로 이뤄진 합성곱 신경망(convolution neural network, CNN)
을 사용한다. 판별기 아키텍처의 구성 요소로는 진짜 이미지, 생성기가 출력해 낸 이미지, 판별기, 출력(진짜 또는 가짜)가 있다.
판별기를 구축하는 과정은 다음과 같다.
- 진짜와 가짜를 분류(이진 분류)하는 데 쓸 합성곱 신경망을 만든다.
- 진짜 데이터로만 구성된 데이터셋을 만들고, 생성기를 사용해 가짜 데이터로만 구성된 데이터셋도 만든다.
- 진짜 데이터와 가짜 데이터를 사용해 판별기 모델을 훈련한다.
- 생성기를 훈련함으로써 훈련된 판별기와 서로 균형을 잡는 방법을 학습하게 된다. 즉, 판별기가 너무 뛰어나게 되면 생성기가 발산(diverges)하게 된다는 점을 이용한다는 뜻이다.
궁극적으로 판별기는 진짜 이미지가 진짜인지 가짜인지, 생성된 이미지가 진짜인지 가짜인지를 평가할 것이다. 초기에 진짜 이미지는 metric 점수를 높이는 반면에, 생성된 이미지는 metric 점수를 낮춘다. 점차적으로, 생성기가 점점 더 진짜에 가까운 이미지를 생성해 낼 것이기 때문에 판별기는 생성된 이미지와 진짜 이미지를 구분하는 데 어려움을 겪게 된다. 판별기는 모델을 구축하는 일에 의존하게 될 것이고, 잠재적으로는 초기 손실 함수에 의존할 것이다.
판별기에 해당하는 클래스 템플릿은 다음과 같다.
class Discriminator:
def __init__(self):
self.initVariable = 1
def lossFunction(self):
return
def buildModel(self):
return
def trainMOdel(self, inputX, inputY):
return
Loss function
각 신경망에는 훈련하는 데 필요한 어떤 구조 요소(structural components)가 있다. 신경망에서는 주어진 문제 집합에 대해 훈련 과정 중에 가중치를 조절함으로써 손실 함수가 최적화되게 한다. 즉, 손실 함수를 metric으로 삼아 신경망을 적합(fit)시켜 나간다. 신경망이 좋은 결과를 산출해 내면서도 수렴(converges)하도록 하려면 신경망을 사용하는 목적에 맞게 손실 함수(loss function)
를 선택하는 것이 필수적이다.
앞서 보았듯 GAN은 Generator와 Discriminator 두 개의 모델이 서로 경쟁하며 성능을 향상시키는 알고리즘이다. 따라서 그에 맞는 손실 함수는 다음과 같은 minimax loss
이다.
각 구성 요소를 살펴보면 다음과 같다.
- D : 판별기, G : 생성기
- min G : (생성기 입장에서) V의 값을 최소화 / max D : (판별기 입장에서) V의 값을 최대화
- x~Pdata(x) : 진짜 데이터 집합에서 샘플링한 데이터(진짜 데이터)
- z~PZ(z) : 정규분포를 사용하는 임의의 노이즈 집합에서 샘플링한 데이터(가짜 데이터)
- Z : 잠재 공간 / z : 노이즈
- D(x) : x가 판별 모델에 입력되었을 때 해당 데이터를 판별 모델이 진짜 데이터라고 판단할 확률(0~1의 값)
- G(z) : z(노이즈)가 생성 모델에 입력되었을 때, 생성 모델이 생성한 가짜 데이터
- D(G(z)) : G(z)(가짜 데이터)가 판별 모델에 입력되었을 때, 판별 모델이 해당 데이터를 진짜 데이터라고 판단할 확률(0~1의 값)
- 1-D(G(z)) : 판별 모델이 해당 데이터를 가짜 데이터라고 판단할 확률(1-진짜 데이터라고 판단할 확률)
이제 각 식을 살펴보면 다음과 같다.
- 첫 번째 항 : 진짜 데이터 x를 판별 모델에 입력했을 때, 판별 모델이 진짜 데이터라고 판단할 확률에 로그를 취한 값
- 두 번째 항 : 가짜 데이터 G(z)를 판별 모델에 입력했을 때, 판별 모델이 가짜 데이터라고 판단할 확률에 로그를 취한 값
이제 손실 함수를 각각 판별기와 생성기의 입장에서 바라본다.
- 판별기(D)의 입장에서
- 첫번째 항에선 진짜 데이터가 들어갔을 때 D(x) = 1이 되야하므로 log1 = 0의 값이 된다.
- 두번째 항에선 가짜 데이터가 들어갔을 때 1-D(G(z)) = 1 - 0 = log1 = 0의 값이 된다.
- 따라서 판별기(D)는 해당 손실 함수에서 최댓값 0을 갖기 위해 학습하게 된다.
- 이 때 판별기(D)는 진짜 혹은 가짜에 대한 학습을 하므로
Binary Cross Entropy Loss
를 사용한다.
- 생성기(G)의 입장에서
- 첫번째 항에선 생성기가 관여할 수 있는 부분이 없다.
- 두번째 항에서 1-D(G(z))의 값이 1이 되야하므로 log(1-1) = log0 = -∞가 된다.
- 따라서 생성기(G)는 해당 손실 함수에서 최솟값 -∞를 갖기 위해 학습하게 된다.
결론적으로 판별기(D)는 진짜 데이터와 가짜 데이터를 잘 구별(D(x)=1, 1-D(G(z))=0)할 수 있게끔 판별하기 위해 V(D,G)를 최대화하는 방향으로 학습하게 된다. 또한 생성기(G)의 경우 1-D(G(z))를 1로 만들기 위해, 즉 V(D,G)의 값을 최소화할 수 있게끔 학습하게 된다. 이러한 적대적인 학습이 곧 GAN과 minimax loss function의 핵심이다.
손실 함수용 템플릿 클래스는 다음과 같다. 해당 클래스는 각 부분적 모델들을 구축할 때 사용할 여러 가지 손실 함수를 나타낸다.
class Loss:
def __init__(self):
self.initVariable = 1
def lossBaseFunction1(self):
return
def lossBaseFunction2(self):
return
def lossBaseFunction3(self):
return
GAN 학습과정
추가적으로 Generative Adversarial Networks, Ian Goodfello et al, NIPS 2014 논문에서 GAN의 학습과정에 대한 설명을 다음과 같이 나타냈다. 검은색 선은 data generating distribution(real), 즉 진짜 데이터의 분포를 나타내며, 파란색 선은 discriminative distribution, 녹색 선은 generative distribution(fake)를 나타낸다.
- (a) : 학습 초기 과정으로, real 데이터와 fake 데이터의 분포가 전혀 다르다. 또한 판별기(D)의 성능 또한 좋지 못하다.
- (b) : 판별기(D)의 성능이 올라가 real 데이터와 fake 데이터에 대한 확률을 분명하게 판별하고 있다. 하지만 fake 데이터의 분포는 real 데이터와 차이가 있다.
- (c) : 판별기(D)가 어느 정도 학습을 한 후엔 생성기(G)가 실제 데이터의 분포를 어느 정도 반영하여 판별기(D)가 구별하기 힘들게 학습을 시작한다.
- (d) : 계속된 반복 학습을 통해 real 데이터와 fake 데이터의 분포가 거의 유사해져 판별기(D)가 구분할 수 없을 때까지 생성기(G)가 학습을 하게 된다. 그 결과 D가 real 데이터와 fake 데이터에 대한 구분을 할 수 없게 되고, 이 때의 확률이 1/2가 된다.
- 이러한 과정을 통해 real 데이터와 매우 유사한 fake 데이터를 생성기(G)가 생성하게 되며, 이것이 GAN의 최종 결과물이다.
Reference
- https://jun-story.tistory.com/31
- https://sites.google.com/site/aidysft/generativeadversialnetwork
- https://velog.io/@hyebbly/Deep-Learning-Loss-%EC%A0%95%EB%A6%AC-1-GAN-loss
- https://github.com/lilly9117/Tobigs_13/blob/master/week9_gan_assignment1_%EC%B5%9C%ED%98%9C%EB%B9%88.ipynb