GAN 네트워크의 판별자와 생성자의 목표는 뭘까?
업데이트:
GAN 네트워크의 판별자와 생성자의 목표는 뭘까?
GAN 네트워크를 공부하면서 판별자와 생성자가 loss를 가져가는 방식에 대해 이해한 것에 대해 적어보려고 합니다.
GAN 네트워크 구조
그림 1 |
그림 1은 GAN 네트워크의 구조를 보여줍니다. GAN 네트워크는 입력으로 임의의 랜덤 latent vector z를 사용하고 target으로 실제 image를 사용합니다. GAN 네트워크는 2개의 model을 갖고 있습니다. 하나는 Generator, 다른 것은 Discriminator로 각각 생성자, 판별자라고 부르겠습니다.
Generator란, 임의의 latent vector z를 입력으로 받아서 Generator network를 통해 가짜 이미지 (fake)를 만들어주는 네트워크입니다.
Discriminator란, Generator가 만든 가짜 이미지(fake)와 내가 target으로 가지고 있는 진짜 이미지(real)을 구분하는 네트워크입니다.
GAN네트워크에서 사용되는 loss는 다음과 같습니다.
\[L=\frac{1}{m} \sum_{i=1}^m \{ \log D(x^i)+ \log(1-D(G(z^i))) \}\]※ m : 데이터의 수, z : latent vector, D(*) : discriminator, G(*) : generator
Summation을 사용하는 이유는 각 데이터들의 forward-propagation을 진행하고, “평균(expectation)”을 취할 것이기 때문입니다.
loss를 그저 최소로 만드는 학습을 진행하면 좋겠지만.. GAN 네트워크는 방향성이 조금 다릅니다. GAN 네트워크의 목표는 D가 loss를 최대로 만들게하고, G가 loss를 최소로 만들게끔 학습을 한다라는 것을 꼭 기억합시다.
\[\min_G\;\max_D\quad\frac{1}{m} \sum_{i=1}^m \{ \log D(x^i)+ \log(1-D(G(z^i))) \}\]D는 loss를 최대로, G는 loss를 최소로?
우선 GAN 네트워크의 학습 순서를 알아둡시다.
- 생성자로 가짜 이미지 생성
- 판별자로 판별 & 판별자 업데이트
- 생성자로 가짜 이미지 생성
- 판별자로 판별 & 생성자 업데이트
- 1.로 돌아감
D는 loss를 최대로
다시 본론으로 돌아와, discriminator는 loss를 최대로 만든다는게 무슨 말 일까요? 이것에 대해 알아봅시다.
판별자는 입력된 이미지가 진짜 이미지라고 판단할 수록 1에 가까운, 가짜 이미지라고 판단할 수록 0에 가까운 확률 값을 반환합니다. 이 개념을 갖고 loss 식을 판별자의 입장에서 봐봅시다. 실제 갖고 있는 이미지를 $x$, 생성자로 만들어진 가짜 이미지를 $G(z)$라고 합시다.
그럼 목표 식을 D에 대해서만 가져와봅시다.
\[\max_D\quad\frac{1}{m} \sum_{i=1}^m \{ \log D(x^i)+ \log(1-D(G(z^i))) \}\]다음의 경우들에 따라 loss는 최대,최소 값을 가지게 됩니다.
- 최대 : 실제 이미지 $x$를 진짜로 판별(1) 그리고 가짜 이미지(G(z))를 가짜로 판별(0) -> $loss =0$
- 최소 : 실제 이미지 $x$를 가짜로 판별(0) 또는 가짜 이미지(G(z))를 진짜로 판별(1) -> $loss = -\infty$
Discriminator는 loss가 최대가 되어야, 실제 이미지를 진짜로, 가짜 이미지를 가짜로 판별하면서 자신의 존재 의미를 찾게됩니다. 그러니까 Discriminator는 $-\infty$에서 0으로 gradient ascending이 이루어지는 loss를 최대로 만들자는 목표를 갖게 됩니다.
G는 loss를 최소로
그럼 generator는 loss를 최소로 만든다는게 무슨 말 일까요? 이것에 대해 알아봅시다.
생성자는 자신이 만든 가짜 이미지를 판별자가 진짜라고 속게끔 하는게 목표입니다. 즉, $D(\ast)$를 항상 1으로 만들어버리고 싶어하죠! 이 개념을 갖고 loss 식을 생성자의 입장에서 봐봅시다. 생성자와 관련된 식은 우변의 두번째 항 밖에 없습니다.
\[\min_G\quad\frac{1}{m} \sum_{i=1}^m \log(1-D(G(z^i)))\]생성자가 이미지를 아주 잘~ 만들어서 $D(G(z^i))$를 1로 만들어버리면 loss는 $-\infty$가 되겠네요. 반대로 판별자가 아주 유능해서 가짜이미지라고 판별해버리면 loss는 0이 되구요. 그래서 Generator는 $0$에서 $-\infty$로 gradient descent가 이루어지는 loss를 최소로 만들자는 목표를 갖게 됩니다.
이제서야 GAN네트워크의 학습과 판별자,생성자의 존재 의미를 알게 되었습니다.