生成对抗神经网络GAN入门,有这篇就够了🎉

2016年某日,有人在Quora (opens new window)上抛出问题: 在深度学习领域有哪些正在或将要爆发的大突破? (opens new window)
不曾料到Facebook AI首席科学家杨立昆 (opens new window)对这个问题做出了详细的回答。他提到:
在深度学习领域最近有太多的发展让我没有办法在这里一一列举。
在我看来最重要的是对抗训练(也被叫做生成对抗神经网络, 简称GAN). 最开始是由Goodfellow (opens new window)在学生时期提出来的。
我想这个(指GAN)和那些正在被提出的变形是过去10年来最有趣的想法。
可以看到他给了GAN非常高的评价🎉,一方法说最重要的是GAN,另外一方面更是说GAN是过去10年来最有趣的想法。
杨立昆最重要的贡献是在图像处理领域,他提出了卷积神经网络(CNN), 因此也被叫做卷积神经网络之父🧔🏻。我猜想他
对GAN给出这么高的评价很有可能是GAN在图像领域有着广泛的应用。
杨立昆(Yann Lecun)正在向学习GAN的我们微笑 例如,文章开头的图片就是GAN应用的一个很好的例子。这是一个叫做Deep Dream Generator (opens new window)
的网页应用,用户JosieArt上传一张大象🐘图片,然后选择了图片左下角的风格,通过训练好的生成器
就可以生成相应风格的图片。得到的效果是否还很不错呢?
当然这里提到的只是GAN应用的冰山一角,在Gans Awesome Applications (opens new window)
上可以查看大量GAN的应用。
此刻或许大家就会好奇🤔,到底什么是GAN?
什么是GAN🤔️
GAN有两个神经网络🕸。
第一个叫判别器(Discriminator),记做D(Y)。它得到输入Y(比如一张图)后输出一个
值,这个值表示了Y看起来是否"真实"。D(Y)可以看作某种能量函数,当Y是真实样本时,函数的值接近0,
反之,当图片Y的噪声很大或者很奇怪时,函数值为正。
Generative adversarial network(GAN) 另一个网络叫做生成器(Generator), 记为G(Z)。
这里的Z通常是从一个简单分布(例如高斯分布)随机抽样得到的向量,生成器G(Z)的作用是生成图片,这些
生成的图片会被用来训练判别器D(Y)(给真实图片较低的值,其他的图片较高的值)。
判别器D和生成器G的训练时交替进行的; 训练D的时候,G的参数保持不变,而训练G的时候,D的参数是固定的。
通常生成器G要学会生成真实的图片要花费更多的时间,而要让判别器D学会像不像真实照片,相对较为容易。因此,通常我们大量地训练生成器后,才会"弱弱地"训练一下判别器。
具体来说:训练D的过程中,给它一张真实的图片,使其调整参数输出较低的值;再给它一张G生成的图片,让它调整参数
输出较大的值D(G(Z))。
另一方面,在训练G的时候, 它会调整内部的参数使得它生成的图片越来越真实。也就是它一直在优化使得它产生的
图片能够骗过D, 想要让D认为它生成的图片是真实的。
也就是说,对这些生成的图片,G想要最小化D的输出,而D想要最大化D的输出。呈现出对抗的样子,所以这样的训练就叫做对抗训练(adversarial training), 也叫做GAN。
训练目标🎯
就拿生成图片来说清楚GAN背后是在做什么。日剧《半分、青い》 (opens new window)讲述了一个活泼小女孩,玲爱,在一次生病后造成左耳失聪,却没因此而气馁,在父母与青梅竹马的鼓励下继续开朗的活着,成为漫画家的故事。
你看,她开心地加入GAN训练项目 说这个故事是因为我们这里需要漫画家😉。假设我们有很多玲爱画的卡通头像。每一张图可以看做是高维空间中的一个点x, 由于所有的这些卡通头像都是她画的,因此认为所有的这些点都服从玲爱风格的某个分布Pdata(x);也就是说,只要从这个分布里面随便采样一个点x, 那么这个点x对应的头像都应该和她的风格一样。
因此, 我们的目的是要找到一个生成器,它所产生的样本x对应的分布PG(x)和玲爱对应的分布Pdata(x)越接近越好。那如何衡量分布的接近程度呢?这是一个问题,我们继续看。
训练思路📕
训练GAN可以分成三步:
把大象放进冰箱的3个步骤
步骤1
搭建一个生成器神经网络G, 网络所有参数用θ表示。G目的是用来生成图片, 我们认为生成的图片样本x都服从生成器对应的一个分布PG(x;θ)。
步骤2
在现有头像数据库中抽出n张头像, 对应高维空间中n个点{x1,x2,...,xn}。"抽"的这个动作相当于在分布Pdata(x)里采样。我们既然能够抽到到{x1,x2,...,xn}, 那就是说Pdata(x1),Pdata(x2),...,Pdata(xn)这些概率值很大。前面已经提到, 生成器训练的目标是让PG(x)
和Pdata(x)越接近越好。因此我们希望PG(x1;θ),PG(x2;θ),...,PG(xn;θ)这些概率值每一个都很高, 换句话说就是想要∏k=1nPG(xk;θ)的值越大越好。
提示
x1被抽到了,那么一定是Pdata(x1)的概率高才会被抽到。我们通常假设:一件事情发生了,就说明这件事发生的
概率比较高。
步骤3
训练网络找到参数
θ∗=argθmaxk=1∏nPG(xk;θ)
其中, ∏k=1nPG(xk;θ)叫做样本的Likelihood。 可以证明:
argθmaxk=1∏nPG(xk;θ)=argθmin[KL(Pdata∣∣PG)]
这里KL指的是KL Divergence, 它可以表示两个分布的接近程度。
上面这个式子说的是最大化Likelihood和最小化KL Divergence是同一个意思。因此这个步骤变为:训练网络找到参数
θ∗=argθmin[KL(Pdata∣∣PG)]
证明: 最大化Likelihood = 最小化KL Divergence
argmaxθ∏k=1nPG(xk;θ)=argmaxθlog∏k=1nPG(xk;θ)
=argmaxθ∑k=1nlogPG(xk;θ), 由于{x1,x2,...,xn}都来自于Pdata(x)
这个分布, 因此:
≈argmaxθEx∼Pdata(x)[logPG(x;θ)]
=argmaxθ[∫xPdata(x)logPG(x;θ)dx] 下一步减一个常数,对找θ∗没有影响。
=argmaxθ[∫xPdata(x)logPG(x;θ)dx−∫xPdata(x)logPdata(x)dx]
=argmaxθ[∫xPdata(x)logPdata(x)PG(x;θ)dx]
=argmaxθ[−KL(Pdata∣∣PG)]
=argminθ[KL(Pdata∣∣PG)]
证毕
具体步骤
搭建生成器
等到我们设计的生成器神经网络G训练好了之后,我们就可以利用它来生成玲爱风格的卡通头像了。这个神经网络具体应该是什么样子的呢? 我们看下面这张图:
GAN中生成器(Generator)的结构 输入是128维的一个向量, 这些输入的值可以控制输出头像的一些特征,例如头发的颜色,头发的长短, 肤色,性别等等。
make.girls.moe (opens new window)是一个在线的卡通头像生成项目,去玩一下就很容易理解这里神经网络G的输入是什么了。输出构成一张图片,这里的64x64个像素点最后可以拼成一张灰度图。
Tip
这里的输入可以看做是128维空间中的一个点;输出是4096维空间中的一个点; 输入数据我们从一个固定的分布中去采样(例如:128维的高斯分布);
我们期望神经网络(G)的输出数据(图片)的分布PG(x)能够和Pdata(x)的分布越接近越好。
确认目的
训练最终目的: 找到
G∗=argGminDiv(PG,Pdata)
就是说我们想要找到一个最好的生成器G∗, 它能够使得PG和Pdata的某种Divergence越小越好(前面提到的KL是一种具体的Divergence)。
有目标后,一步一步走下去。 自然地,你就会问Div(PG,Pdata)我们该怎么算?
计算Div(PG,Pdata)
任何一种Divergence都是有公式的,我在文章"什么是熵"中有提到。原则上直接带入公式就可以算出来了。但是这里的问题是我们不知道PG,Pdata的表达式是什么。
一方面,尽管我们有玲爱画的漫画头像库,但是我们也不知道她画的头像服从什么分布, 即Pdata未知;
另一方面,即使可以利用生成器生成很多的头像,但是我们也不知道这些生成的头像服从什么分布, 即PG未知。
那应该如何计算两个分布的Divergence呢? 我们的手上只有两个分布的样本? 可以直接比较样本吗? 答案是肯定的。这件事漫画导师秋風羽織 (opens new window)
可以帮我们大忙。
《半分、青い。》中的秋風羽織正在查看两类图片 我们将玲爱和生成器G画的头像分别拿去给秋风先生过目,让他说说谁画得更真实。毕竟是玲爱是秋风先生
的得意弟子,因此他总是想也不想就直接说玲爱的画的更好,还总是说生成的差太多。生成器器没有办法,为了得到秋风先生的的夸赞它就会像玲爱学习。这样的结果是机器或许真的可以变得和玲爱一样优秀。
我们的期待
机器真的可以变得和玲爱一样优秀;也就是Pdata(x)和PG(x)可以很接近;也就是Div(PG,Pdata)很小。
可以看到,在训练生成器的过程中, 有一个举足轻重的角色: 秋風羽織先生。
在训练过程中,判别器可能会产生上万张图片,难道你真的准备让秋風先生去做这么无聊的事?😳
当然不行,因此我们需要另外一个角色来代替秋風先生,那便是判别器。
判别器就做一件事情: 给生成器的生成的头像打低分,给玲爱画的打高分。
开始时判别器的技能当然比不过秋風先生,每当生成器升级之后, 判别器也需要更新技能才能够很好地判断。
好在有很好的资料(训练数据,来自两个分布的样本)可以供它升级(更新网络参数参数)。因此我们希望最终可
以得到一个像秋風先生那般厉害的判别器D∗。
不能够直接算Div(PG,Pdata), 我们只能间接地训练一个判别器来判断一张图片有多像玲爱画的,
很像判别器输出的值就接近1, 不像就接近0; 通过判别器的输出D,我们再构造一个函数V, 这个函数就
反映了Div(PG,Pdata)。
训练判别器
要反映Div(PG,Pdata), 就需要有一个判别器网络。
判别器的结构是什么样的呢?
GAN中判别器(Discriminator)的结构 可以看到,判别器的输入是图片,输出是一个数,这个数反映的是这张图有多像玲爱画的。我么如何才能得到像秋風先生那样厉害的判别器D∗呢? 我们还需要做一件事,使用判别器构造出一个二分类器,即是否判断输入图是否为玲爱的。
分类器的输出:
V=Ex∼Pdata[logD(x)]+Ex∼PG[log(1−D(x))]
解释
V的表达式可以这样理解,在训练判别器计算分类器输出时:
- 当图片x来自玲爱数据集时,x带入Ex∼Pdata[logD(x)]计算。
- 当图片x来自生成器数据集时,x带入Ex∼PG[log(1−D(x))]计算。
可以看到在1中V和判别器D相关; 在2中,由于图片x来自生成器,因此V还与G有关,因此V写作
V(D,G)
于是判别器最优解,
D∗=argDmaxV(D,G)
通过调整参数,尝试找到最好的判别器D∗,它使得V的值越大越好; 要使V更大,
判别器就要学会:
- 给玲爱的画高分,即让D(x)的值越接近1越好;
- 给生成的画低分,即让D(x)的值越接近0越好。
新的问题是:
V(D,G)=Ex∼Pdata[logD(x)]+Ex∼PG[log(1−D(x))]
该怎么算?
使用均值的方式去近似期望,于是
V~=m1i=1∑mlogD(xi)+m1i=1∑mlog(1−D(x~i))
其中
- xi来自玲爱数据集, 标签为1,
- x~i来自生成器数据集, 标签为0。
算法总结
始终要记得GAN我们要的是G, 因此第一步搭建G网络,第二步是准备训练数据,第三步是训练。
但是训练的过程我们需要一个判别器D来帮助我们分清好坏。因此我们要顺便训练一个D。
总结算法如下:
GAN算法
- 初始化D,G的参数分别为θd,θg
- 循环训练:
- 训练D, k次:
- 从分布Pdata(x)中随机采样m个 {x1,x2,...,xm}
- 从已知分布Pprior(z)中随机采样m个 {z1,z2,...,zm}
- 使用x~i=G(zi)获得生成数据 {x~1,x~2,...,x~m}
- 更新D的参数θd来最大化
- V~=m1∑i=1mlogD(xi)+m1∑i=1mlog(1−D(x~i))
- θd←θd+η∇V~(θd)
- 训练G, 1次:
- 重新从已知分布Pprior(z)中随机采样m个 {z1,z2,...,zm}
- 更新G的参数来最小化
- V~=m1∑i=1mlog(1−D(x~i))
- θg←θg−η∇V~(θg)
GAN在训练判别器时,生成器的参数固定不变;在训练生成器时,判别器的参数固定不变。GAN的核心思想在这篇文章中算是介绍得差不多了。当然最初的GAN也有很多不足之处,例如很难训练,不容易收敛等问题。因此又有很多人对原始的算法进行优化,出现了各种各样衍生的GAN,我们以后有机会在聊一聊其他的GAN。
希望这篇文章能够帮助到你。我是阿梁,我在机器学习之路way2ml (opens new window)分享高效工具, 编程,AI的内容。下次见!
参考:
- 首页图片来自JosieArt Deep Dream Generator (opens new window)
- 国立台湾大李宏毅老师GAN课程 GAN Lecture 4 (2018): Basic Theory (opens new window)
- 维基百科: Generative adversarial network (opens new window)
- Medium: Understanding Generative Adversarial Networks (GANs) (opens new window)
- GAN示意图: Generative Adversarial Network (GAN) (opens new window)
- 维基百科:半分、青い (opens new window)
- 步骤示意图: 一个大象放进冰箱的过程 (opens new window)
- 火熱的生成對抗網路(GAN),你究竟好在哪裡 (opens new window)
- Stackoverflow image with caption - vuepress (opens new window)
- html使用简单标签改变字体(加粗、斜体...) (opens new window)