# 生成对抗神经网络GAN入门,有这篇就够了🎉

2016年某日,有人在Quora (opens new window)上抛出问题: 在深度学习领域有哪些正在或将要爆发的大突破? (opens new window) 不曾料到Facebook AI首席科学家杨立昆 (opens new window)对这个问题做出了详细的回答。他提到:

在深度学习领域最近有太多的发展让我没有办法在这里一一列举。 在我看来最重要的是对抗训练(也被叫做生成对抗神经网络, 简称GAN). 最开始是由Goodfellow (opens new window)在学生时期提出来的。 我想这个(指GAN)和那些正在被提出的变形是过去10年来最有趣的想法。

可以看到他给了GAN非常高的评价🎉,一方法说最重要的是GAN,另外一方面更是说GAN是过去10年来最有趣的想法。 杨立昆最重要的贡献是在图像处理领域,他提出了卷积神经网络(CNN), 因此也被叫做卷积神经网络之父🧔🏻。我猜想他 对GAN给出这么高的评价很有可能是GAN在图像领域有着广泛的应用。

杨立昆(Yann Lecun)正在向学习GAN的我们微笑

例如,文章开头的图片就是GAN应用的一个很好的例子。这是一个叫做Deep Dream Generator (opens new window) 的网页应用,用户JosieArt上传一张大象🐘图片,然后选择了图片左下角的风格,通过训练好的生成器 就可以生成相应风格的图片。得到的效果是否还很不错呢?

当然这里提到的只是GAN应用的冰山一角,在Gans Awesome Applications (opens new window) 上可以查看大量GAN的应用。

此刻或许大家就会好奇🤔,到底什么是GAN?

# 什么是GAN🤔️

GAN有两个神经网络🕸。

第一个叫判别器(Discriminator),记做D(Y)D(Y)。它得到输入YY(比如一张图)后输出一个 值,这个值表示了YY看起来是否"真实"。D(Y)D(Y)可以看作某种能量函数,当YY是真实样本时,函数的值接近0, 反之,当图片YY的噪声很大或者很奇怪时,函数值为正。

Generative adversarial network(GAN)

另一个网络叫做生成器(Generator), 记为G(Z)G(Z)。 这里的ZZ通常是从一个简单分布(例如高斯分布)随机抽样得到的向量,生成器G(Z)G(Z)的作用是生成图片,这些 生成的图片会被用来训练判别器D(Y)D(Y)(给真实图片较低的值,其他的图片较高的值)。

判别器DD和生成器GG的训练时交替进行的; 训练DD的时候,GG的参数保持不变,而训练GG的时候,DD的参数是固定的。 通常生成器GG要学会生成真实的图片要花费更多的时间,而要让判别器DD学会像不像真实照片,相对较为容易。因此,通常我们大量地训练生成器后,才会"弱弱地"训练一下判别器。

具体来说:训练DD的过程中,给它一张真实的图片,使其调整参数输出较低的值;再给它一张GG生成的图片,让它调整参数 输出较大的值D(G(Z))D(G(Z))

另一方面,在训练GG的时候, 它会调整内部的参数使得它生成的图片越来越真实。也就是它一直在优化使得它产生的 图片能够骗过DD, 想要让DD认为它生成的图片是真实的。

也就是说,对这些生成的图片,GG想要最小化DD的输出,而DD想要最大化DD的输出。呈现出对抗的样子,所以这样的训练就叫做对抗训练(adversarial training), 也叫做GAN

# 训练目标🎯

就拿生成图片来说清楚GAN背后是在做什么。日剧《半分、青い》 (opens new window)讲述了一个活泼小女孩,玲爱,在一次生病后造成左耳失聪,却没因此而气馁,在父母与青梅竹马的鼓励下继续开朗的活着,成为漫画家的故事。

你看,她开心地加入GAN训练项目

说这个故事是因为我们这里需要漫画家😉。假设我们有很多玲爱画的卡通头像。每一张图可以看做是高维空间中的一个点xx, 由于所有的这些卡通头像都是她画的,因此认为所有的这些点都服从玲爱风格的某个分布Pdata(x)P_{data}(x);也就是说,只要从这个分布里面随便采样一个点xx, 那么这个点xx对应的头像都应该和她的风格一样。

GAN的目标

让机器学会生成玲爱风格的卡通头像。

因此, 我们的目的是要找到一个生成器,它所产生的样本xx对应的分布PG(x)P_{G}(x)和玲爱对应的分布Pdata(x)P_{data}(x)越接近越好。那如何衡量分布的接近程度呢?这是一个问题,我们继续看。

# 训练思路📕

训练GAN可以分成三步:

把大象放进冰箱的3个步骤

步骤1
搭建一个生成器神经网络GG, 网络所有参数用θ\theta表示。GG目的是用来生成图片, 我们认为生成的图片样本xx都服从生成器对应的一个分布PG(x;θ)P_{G}(x; \theta)

步骤2
在现有头像数据库中抽出nn张头像, 对应高维空间中nn个点{x1,x2,...,xn}\{x_1,x_2,...,x_n \}。""的这个动作相当于在分布Pdata(x)P_{data}(x)里采样。我们既然能够抽到到{x1,x2,...,xn}\{x_1,x_2,...,x_n\}, 那就是说Pdata(x1)P_{data}(x_1),Pdata(x2)P_{data}(x_2),...,Pdata(xn)P_{data}(x_n)这些概率值很大。前面已经提到, 生成器训练的目标是让PG(x)P_{G}(x)Pdata(x)P_{data}(x)越接近越好。因此我们希望PG(x1;θ),PG(x2;θ),...,PG(xn;θ)P_{G}(x_1; \theta),P_{G}(x_2; \theta),...,P_{G}(x_n; \theta)这些概率值每一个都很高, 换句话说就是想要k=1nPG(xk;θ)\prod_{k=1}^n P_{G}(x_k; \theta)的值越大越好。

提示

x1x_1被抽到了,那么一定是Pdata(x1)P_{data}(x_1)的概率高才会被抽到。我们通常假设:一件事情发生了,就说明这件事发生的 概率比较高。

步骤3
训练网络找到参数

θ=argmaxθk=1nPG(xk;θ)\theta^* = \arg \max_{\theta} \prod_{k=1}^n P_{G}(x_k; \theta)

其中, k=1nPG(xk;θ)\prod_{k=1}^n P_{G}(x_k; \theta)叫做样本的Likelihood。 可以证明:

argmaxθk=1nPG(xk;θ)=argminθ[KL(PdataPG)]\arg \max_{\theta} \prod_{k=1}^n P_{G}(x_k; \theta) = \arg \min_{\theta} [KL(P_{data}||P_{G})]

这里KLKL指的是KL Divergence, 它可以表示两个分布的接近程度。 上面这个式子说的是最大化Likelihood和最小化KL Divergence是同一个意思。因此这个步骤变为:训练网络找到参数

θ=argminθ[KL(PdataPG)]\theta^* = \arg \min_{\theta} [KL(P_{data}||P_{G})]

证明: 最大化Likelihood = 最小化KL Divergence

argmaxθk=1nPG(xk;θ)=argmaxθlogk=1nPG(xk;θ)\arg \max_{\theta} \prod_{k=1}^n P_{G}(x_k; \theta) = \arg \max_{\theta} \log \prod_{k=1}^n P_{G}(x_k; \theta)

=argmaxθk=1nlogPG(xk;θ)= \arg \max_{\theta} \sum_{k=1}^n \log P_{G}(x_k; \theta), 由于{x1,x2,...,xn}\{x_1,x_2,...,x_n\}都来自于Pdata(x)P_{data}(x) 这个分布, 因此:

argmaxθExPdata(x)[logPG(x;θ)]\approx \arg \max_{\theta} \mathbb{E}_{x \sim P_{data}(x)}[\log P_{G}(x; \theta)]

=argmaxθ[xPdata(x)logPG(x;θ)dx]= \arg \max_{\theta} [\int_x P_{data}(x)\log P_{G}(x;\theta)dx] 下一步减一个常数,对找θ\theta ^*没有影响。

=argmaxθ[xPdata(x)logPG(x;θ)dxxPdata(x)logPdata(x)dx]= \arg \max_{\theta} [\int_x P_{data}(x)\log P_{G}(x;\theta)dx - \int_x P_{data}(x)\log P_{data}(x)dx]

=argmaxθ[xPdata(x)logPG(x;θ)Pdata(x)dx]= \arg \max_{\theta} [\int_x P_{data}(x)\log \frac{P_{G}(x;\theta)}{P_{data}(x)}dx]

=argmaxθ[KL(PdataPG)]= \arg \max_{\theta} [- KL(P_{data}||P_{G})]

=argminθ[KL(PdataPG)]= \arg \min_{\theta} [KL(P_{data}||P_{G})]

证毕

# 具体步骤

# 搭建生成器

等到我们设计的生成器神经网络GG训练好了之后,我们就可以利用它来生成玲爱风格的卡通头像了。这个神经网络具体应该是什么样子的呢? 我们看下面这张图:

GAN中生成器(Generator)的结构

输入是128维的一个向量, 这些输入的值可以控制输出头像的一些特征,例如头发的颜色,头发的长短, 肤色,性别等等。 make.girls.moe (opens new window)是一个在线的卡通头像生成项目,去玩一下就很容易理解这里神经网络G的输入是什么了。输出构成一张图片,这里的64x64个像素点最后可以拼成一张灰度图。

Tip

这里的输入可以看做是128维空间中的一个点;输出是4096维空间中的一个点; 输入数据我们从一个固定的分布中去采样(例如:128维的高斯分布); 我们期望神经网络(G)的输出数据(图片)的分布PG(x)P_{G}(x)能够和Pdata(x)P_{data}(x)的分布越接近越好。

# 确认目的

训练最终目的: 找到

G=argminGDiv(PG,Pdata)G^* = \arg \min_{G} Div(P_{G}, P_{data})

就是说我们想要找到一个最好的生成器GG^*, 它能够使得PGP_{G}PdataP_{data}的某种Divergence越小越好(前面提到的KLKL是一种具体的Divergence)。 有目标后,一步一步走下去。 自然地,你就会问Div(PG,Pdata)Div(P_{G}, P_{data})我们该怎么算?

# 计算Div(PG,Pdata)Div(P_{G}, P_{data})

任何一种Divergence都是有公式的,我在文章"什么是熵"中有提到。原则上直接带入公式就可以算出来了。但是这里的问题是我们不知道PG,PdataP_{G}, P_{data}的表达式是什么。

一方面,尽管我们有玲爱画的漫画头像库,但是我们也不知道她画的头像服从什么分布, 即PdataP_{data}未知; 另一方面,即使可以利用生成器生成很多的头像,但是我们也不知道这些生成的头像服从什么分布, 即PGP_{G}未知。 那应该如何计算两个分布的Divergence呢? 我们的手上只有两个分布的样本? 可以直接比较样本吗? 答案是肯定的。这件事漫画导师秋風羽織 (opens new window) 可以帮我们大忙。

《半分、青い。》中的秋風羽織正在查看两类图片

我们将玲爱和生成器GG画的头像分别拿去给秋风先生过目,让他说说谁画得更真实。毕竟是玲爱是秋风先生 的得意弟子,因此他总是想也不想就直接说玲爱的画的更好,还总是说生成的差太多。生成器器没有办法,为了得到秋风先生的的夸赞它就会像玲爱学习。这样的结果是机器或许真的可以变得和玲爱一样优秀。

我们的期待

机器真的可以变得和玲爱一样优秀;也就是Pdata(x)P_{data}(x)PG(x)P_{G}(x)可以很接近;也就是Div(PG,Pdata)Div(P_{G},P_{data})很小。

可以看到,在训练生成器的过程中, 有一个举足轻重的角色: 秋風羽織先生。 在训练过程中,判别器可能会产生上万张图片,难道你真的准备让秋風先生去做这么无聊的事?😳

当然不行,因此我们需要另外一个角色来代替秋風先生,那便是判别器。 判别器就做一件事情: 给生成器的生成的头像打低分,给玲爱画的打高分。 开始时判别器的技能当然比不过秋風先生,每当生成器升级之后, 判别器也需要更新技能才能够很好地判断。 好在有很好的资料(训练数据,来自两个分布的样本)可以供它升级(更新网络参数参数)。因此我们希望最终可 以得到一个像秋風先生那般厉害的判别器DD^*

不能够直接算Div(PG,Pdata)Div(P_{G}, P_{data}), 我们只能间接地训练一个判别器来判断一张图片有多像玲爱画的, 很像判别器输出的值就接近1, 不像就接近0; 通过判别器的输出DD,我们再构造一个函数VV, 这个函数就 反映了Div(PG,Pdata)Div(P_{G}, P_{data})

# 训练判别器

要反映Div(PG,Pdata)Div(P_{G}, P_{data}), 就需要有一个判别器网络。 判别器的结构是什么样的呢?

GAN中判别器(Discriminator)的结构

可以看到,判别器的输入是图片,输出是一个数,这个数反映的是这张图有多像玲爱画的。我么如何才能得到像秋風先生那样厉害的判别器DD^*呢? 我们还需要做一件事,使用判别器构造出一个二分类器,即是否判断输入图是否为玲爱的。

分类器的输出:

V=ExPdata[logD(x)]+ExPG[log(1D(x))]V = \mathbb{E}_{x \sim P_{data}}[\log D(x)] + \mathbb{E}_{x \sim P_{G}}[\log (1-D(x))]

解释

VV的表达式可以这样理解,在训练判别器计算分类器输出时:

  1. 当图片xx来自玲爱数据集时,xx带入ExPdata[logD(x)]\mathbb{E}_{x \sim P_{data}}[\log D(x)]计算。
  2. 当图片xx来自生成器数据集时,xx带入ExPG[log(1D(x))]\mathbb{E}_{x \sim P_{G}}[\log (1-D(x))]计算。

可以看到在1中VV和判别器DD相关; 在2中,由于图片xx来自生成器,因此VV还与GG有关,因此VV写作 V(D,G)V(D,G)

于是判别器最优解,

D=argmaxDV(D,G)D^* = \arg \max_{D}V(D,G)

通过调整参数,尝试找到最好的判别器DD^*,它使得VV的值越大越好; 要使VV更大, 判别器就要学会:

  • 给玲爱的画高分,即让D(x)D(x)的值越接近1越好;
  • 给生成的画低分,即让D(x)D(x)的值越接近0越好。

新的问题是:

V(D,G)=ExPdata[logD(x)]+ExPG[log(1D(x))]V(D,G) = \mathbb{E}_{x \sim P_{data}}[\log D(x)] + \mathbb{E}_{x \sim P_{G}}[\log (1-D(x))]

该怎么算?

使用均值的方式去近似期望,于是

V~=1mi=1mlogD(xi)+1mi=1mlog(1D(x~i))\tilde{V}=\frac{1}{m} \sum_{i=1}^{m} \log D\left(x^{i}\right)+\frac{1}{m} \sum_{i=1}^{m} \log \left(1-D\left(\tilde{x}^{i}\right)\right)

其中

  • xix^i来自玲爱数据集, 标签为1,
  • x~i\tilde{x}^i来自生成器数据集, 标签为0

# 算法总结

始终要记得GAN我们要的是GG, 因此第一步搭建GG网络,第二步是准备训练数据,第三步是训练。 但是训练的过程我们需要一个判别器DD来帮助我们分清好坏。因此我们要顺便训练一个DD。 总结算法如下:

GAN算法

  • 初始化D,GD,G的参数分别为θd,θg\theta_d, \theta_g
  • 循环训练:
    • 训练DD, kk次:
      • 从分布Pdata(x)P_{data}(x)中随机采样mm{x1,x2,...,xm}\{x^1,x^2,...,x^m\}
      • 从已知分布Pprior(z)P_{prior}(z)中随机采样mm{z1,z2,...,zm}\{z^1,z^2,...,z^m\}
      • 使用x~i=G(zi)\tilde{x}^i=G(z^i)获得生成数据 {x~1,x~2,...,x~m}\{\tilde{x}^1, \tilde{x}^2,...,\tilde{x}^m\}
      • 更新DD的参数θd\theta_d来最大化
        • V~=1mi=1mlogD(xi)+1mi=1mlog(1D(x~i))\tilde{V}=\frac{1}{m} \sum_{i=1}^{m} \log D\left(x^{i}\right)+\frac{1}{m} \sum_{i=1}^{m} \log \left(1-D\left(\tilde{x}^{i}\right)\right)
        • θdθd+ηV~(θd)\theta_{d} \leftarrow \theta_{d}+\eta \nabla \tilde{V}\left(\theta_{d}\right)
    • 训练GG, 11次:
      • 重新从已知分布Pprior(z)P_{prior}(z)中随机采样mm{z1,z2,...,zm}\{z^1,z^2,...,z^m\}
      • 更新GG的参数来最小化
        • V~=1mi=1mlog(1D(x~i))\tilde{V}=\frac{1}{m} \sum_{i=1}^{m} \log \left(1-D\left(\tilde{x}^{i}\right)\right)
        • θgθgηV~(θg)\theta_{g} \leftarrow \theta_{g}-\eta \nabla \tilde{V}\left(\theta_{g}\right)

GAN在训练判别器时,生成器的参数固定不变;在训练生成器时,判别器的参数固定不变。GAN的核心思想在这篇文章中算是介绍得差不多了。当然最初的GAN也有很多不足之处,例如很难训练,不容易收敛等问题。因此又有很多人对原始的算法进行优化,出现了各种各样衍生的GAN,我们以后有机会在聊一聊其他的GAN。

希望这篇文章能够帮助到你。我是阿梁,我在机器学习之路way2ml (opens new window)分享高效工具, 编程,AI的内容。下次见!

参考:

  1. 首页图片来自JosieArt Deep Dream Generator (opens new window)
  2. 国立台湾大李宏毅老师GAN课程 GAN Lecture 4 (2018): Basic Theory (opens new window)
  3. 维基百科: Generative adversarial network (opens new window)
  4. Medium: Understanding Generative Adversarial Networks (GANs) (opens new window)
  5. GAN示意图: Generative Adversarial Network (GAN) (opens new window)
  6. 维基百科:半分、青い (opens new window)
  7. 步骤示意图: 一个大象放进冰箱的过程 (opens new window)
  8. 火熱的生成對抗網路(GAN),你究竟好在哪裡 (opens new window)
  9. Stackoverflow image with caption - vuepress (opens new window)
  10. html使用简单标签改变字体(加粗、斜体...) (opens new window)
上次更新: 11/24/2021, 10:39:29 PM