Loading [MathJax]/jax/output/CommonHTML/jax.js

Deep Learning/GAN

[GAN] Generative Adversarial Nets - 증명

미미수 2021. 6. 29. 23:31

이번 게시물에서는 GAN의 타당성에 대한 증명 2개를 해보겠습니다.

아직 GAN의 작동 원리에 대한 이해가 깊지 않다면 아래 게시물 ↓ 을 먼저 보고 와주세요 :)

 

 

 

증명에 앞서, GAN을 비롯한 모든 생성 모델 (Generative Model)의 목적을 다시 한번 떠올려보겠습니다.

 

내가 닮고자 하는 data의 분포와 가장 유사하도록 Generator의 분포를 형성하는것이 바로 그 목적이었습니다.

다른 말로, PdataPg 거리가 최소가 될 수 있도록 만들어주는 것입니다. 이를 수식으로 표현하면, ↓

Pdata=Pg

 

따라서 GAN은 아래 두가지 를 증명합니다.

 

1. Pdata=Pg일때가 Global Optimum인가

: Pdata gernerative model distribution이 정확히 일치할 때 global optimum이며, 그때 Pdata=Pg를 갖는가?

 

2.  Algorithm1을 수행했을 시 Converge(수렴) 하는가

: 아래의 알고리즘이 실제로 global optimality Pdata=Pg인 해를 찾을 수있는가?

Algorithm 1

 


1.  Global Optimality of Pdata=Pg


Proposition 1. 어떤 고정된 G에 대하여 최적의 Discriminator D는 다음과 같다.

Pdata(x)Pdata(x)+Pg(x)

Proof. [1단계]

GAN의 value function V(G,D)는 아래와 같습니다.

 

minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

 

여기서 G는 고정이 됐다고 가정을 하면, Value Function을 아래와 같이 새로 쓸 수 있습니다.

 

=D(x)=argmaxDV(D)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

D(x)는 optimal D

 

 

위 식을 다시 한번 아래와 같이 바꿀 수 있습니다.

 

=Expdata(x)[logD(x)]+Expg(x)[log(1D(x)]

※ z를 Pz에서 샘플링해서 Generator에 주고 G(z)가 생성한 fake image는 Pg를 따르게 됩니다.

※ 이는 곧, x를 Pg에서 샘플링하는것과 같다고 할 수 있습니다.

 

 

Expectation의 정의        Exp(x)[f(x)]=xp(x)f(x)dx        에 의해 적분식으로 바꿔 줍니다.

=xpdata(x)logD(x)dx+xpg(x)log(1D(x))dx

=xpdata(x)logD(x)dx+pg(x)log(1D(x))dx

 

결국 아래 적분식을 최대화 해야 한다는 결론에 도달합니다. 적분식을 최대화하는것은 적분식 안의 식을 최대화 하는것과 마찬가지입니다.

D(x)=argmaxDV(D)=xpdata(x)logD(x)dx+pg(x)log(1D(x))dx

 

 

pdata(x)logD(x)+pg(x)log(1D(x)) 에서,

Pdata(x)=a,D(x)=y,Pg(x)=b치환을 해주면,

 

alogy+blog(1y)라는 식이 됩니다.

식을 미분해 y=0인 지점을 구합니다. 그러면 y=aa+b 에서 최대값을 가지는것을 알 수 있습니다.

 

따라서 어떤 고정된 G에 대하여 최적의 Discriminator DPdata(x)Pdata(x))+Pg(x)라는 명제가성립합니다.


Proof.[2단계]

 

위에서 구한 D를 이용해 V(D,G)를 다시 쓸 수 있습니다.

minGmaxDV(G,D)=minGV(D,G)

※G의 역할은 minimize입니다

 

V(D,G)=Expdata(x)[logD(x)]+Expg[log(1D(x)]

 

또 다시 Expectation의 정의를 적용합니다.

=xpdata(x)logPdata(x)Pdata(x)+Pg(x)dx+xpg(x)logPdata(x)Pdata(x)+Pg(x)dx

 

이 식에 -log4 + log4를 더해줍니다. ( log4 = 2log2)

=log4+log4+xpdata(x)logPdata(x)Pdata(x)+Pg(x)dx+xpg(x)logPdata(x)Pdata(x)+Pg(x)dx

 

=log4+xpdata(x)log2Pdata(x)Pdata(x)+Pg(x)dx+xpg(x)log2Pdata(x)Pdata(x)+Pg(x)dx

 

 

이제 KL Divergence라는 개념이 등장합니다.

 

KL Divergence, 어떤 두 함수의 확률분포가 얼마나 다른지 그 차이를 측정하는 식입니다.


비교하려는 확률질량함수가 A, 기준이 되는 확률질량함수가 B일때,

비교함수-기준함수,  logA(x)(logB(x))

기준확률질량함수의 확률분포 기대값을 씌워줍니다.


DKL(B||A)=ExB[logB(x)A(x)]=xB(x)logB(x)A(x)

KL Divergence를 사용하면 integral(확률질량함수)을 아래와 같이 바꿀 수 있습니다.

xpdata(x)log2Pdata(x)Pdata(x)+Pg(x)dx=KL(pdata||pdata+pg2)

 

 

log4+xpdata(x)log2Pdata(x)Pdata(x)+Pg(x)dx+xpg(x)log2Pdata(x)Pdata(x)+Pg(x)dx

 

=log4+KL(pdata||pdata+pg2)+KL(pg||pdata+pg2)

 

위의 수식은 Jenson-Shannon divergence(JSD)로 다시 정리가 될수 있기 때문에, 최종적으로, 아래의 식을 도출할 수 있습니다.

V(D,G)=log4+2JSD(Pdata||Pg)

 

 

V(D,G)를 optimize하는것은 JSD값을 minimize하는것과 마찬가지이고, (log4는 고정값이기 때문)

JSD는 Pdata와 Pg가 일치할때만 0, 그외에는 양수값을 지니기 때문에,

V(D,G)의 global minimum은 Pg = Pdata 일 때라는 것을 다시한번 확인할 수 있습니다. 

 

Pdata(x)=Pg(x)D(x)=Pdata(x)Pdata(x)+Pg(x)=12

 

 

 

 

따라서 Pdata(x)=Pg(x) 일때가 unique solution 임이 증명되었다!

 

 


2.  Convergence of Algorithm1