当前位置: 代码迷 >> 综合 >> TensorFlow 11——ch08-GAN和DCGAN入门
  详细解决方案

TensorFlow 11——ch08-GAN和DCGAN入门

热度:68   发布时间:2023-09-26 21:28:14.0

TensorFlow 11——ch08-GAN和DCGAN入门
代码:https://github.com/MONI-JUAN/Tensorflow_Study/ch08-GAN和DCGAN入门

目录

    • 一、基本概念
      • 1.GAN 的原理
      • 2.交叉熵损失
      • 3.DCGAN的原理
    • 二、生成MNIST图像
      • 1.下载数据集
      • 2.训练
      • 3.训练结果
    • 三、使用自己的数据集训练
      • 1.下载数据集
      • 2.训练模型
      • 3.测试模型
      • 4.测试效果

一、基本概念

GAN 的全称为 Generative Adversarial Networks,意为对抗生成网络。

DCGAN 将 GAN 的概念扩展到卷积神经网络中,可以生成质量较高的图片样本 。

1.GAN 的原理

有两个网络,一个是生成网络G(Generator),一个是判别网络D(Discriminator)

  • G:通过噪声z生成图片,记作 G(z) ;
  • D:判断图片是不是”真实的“,输入的x,输出 D(x) 代表是真实图片的概率

训练过程:G尽量生成真实图片去欺骗D,D尽量区分G生成的图片和真实图片。

2.交叉熵损失

V(D,G)=Ex?Pdata (x)[ln?D(x)]+Ez?pz(z)[ln?(1?D(G(z)))]V(D, G)=E_{x \sim P_{\text {data }}(x)}[\ln D(x)]+E_{z \sim p_{z}(z)}[\ln (1-D(G(z)))] V(D,G)=Ex?Pdata ?(x)?[lnD(x)]+Ez?pz?(z)?[ln(1?D(G(z)))]

  • 左边x部分代表真实图片,右边G(z)是生成的图片;
  • D(x) 和 D(G(z)) 都是判断的概率;
  • 生成网络 G 希望 D(G(z)) 变大,V(D, G)越大越好;
  • 判别网络 D 希望 D(x) 变大,V(D, G)越小越好;

3.DCGAN的原理

DCGAN 的全称是 Deep Convolutional Generative Adversarial Networks,即深度卷积对抗生成网络。从名字上来看,是在 GAN 的基础上增加深度卷积网络结构,专门生成图像样本。

事实上,GAN 并没再对D、 G 的具体结构做出任何限制 。DCGAN 中的 D、 G 的含义以及损失都和原始 GAN 中完全一,但是它在 D 和 G 中采用了较为特殊的结构,以便对图片进行高效建模。

DCGAN 中 G 的网络结构:

TensorFlow 11——ch08-GAN和DCGAN入门

  • 不采用池化层,D中用补偿(stride)的卷积代替池化;
  • 在 G、 D 中均使用 Batch Normalization 帮助模型收敛。
  • 在 G 中,激活函数除了最后一层都使用 ReLU 函数,而最后一层使用 tanh 函数。
  • 在 D 中,激活函数都使用 Leaky ReLU 作为激活函数。

TensorFlow 11——ch08-GAN和DCGAN入门

TensorFlow 11——ch08-GAN和DCGAN入门
TensorFlow 11——ch08-GAN和DCGAN入门

二、生成MNIST图像

1.下载数据集

用脚本下载(可能会下载失败,我也不知道为什么每次都失败)

python download.py mnist

或者百度云

链接:https://pan.baidu.com/s/1l-IHrXYvt4M8kj_C-Blklw
提取码:kgrw

这个数据集和chapter 01 的一样:https://blog.csdn.net/qq_34451909/article/details/108264641

TensorFlow 11——ch08-GAN和DCGAN入门

2.训练

python main.py --dataset mnist --input_height=28 --output_height=28 --train

TensorFlow 11——ch08-GAN和DCGAN入门
TensorFlow 11——ch08-GAN和DCGAN入门

3.训练结果

每过100步会保存一张当前训练情况的图

TensorFlow 11——ch08-GAN和DCGAN入门

对比一下 0_99 和 1_106,才训练了一千步左右,已经很有数字的样子了。

TensorFlow 11——ch08-GAN和DCGAN入门

看一下书中25个epoch,也就是2.5w步之后的图像:

TensorFlow 11——ch08-GAN和DCGAN入门

三、使用自己的数据集训练

1.下载数据集

faces.zip

链接:https://pan.baidu.com/s/1l-IHrXYvt4M8kj_C-Blklw
提取码:kgrw

解压faces.zip ,把 anime放进 data 目录

2.训练模型

python main.py --input_height 96 --input_width 96 \ # 截取中心96*96--output_height 48 --output_width 48 \ # 缩放到48*48--dataset anime --crop -–train \ # 需要执行训练--epoch 300 --input_fname_pattern "*.jpg" # 找出所有.jpg训练
python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset anime --crop -–train --epoch 300 --input_fname_pattern "*.jpg"

这已经是训练3.7小时候的结果了,电脑太渣了

TensorFlow 11——ch08-GAN和DCGAN入门

TensorFlow 11——ch08-GAN和DCGAN入门

TensorFlow 11——ch08-GAN和DCGAN入门

对比训练模型:

  • 如果是 mnist 数据集:

    if config.dataset == 'mnist':
    # Update D network
    _, summary_str = self.sess.run([d_optim, self.d_sum],
    feed_dict={
           self.inputs: batch_images,self.z: batch_z,self.y:batch_labels,
    })
    self.writer.add_summary(summary_str, counter)# Update G network
    _, summary_str = self.sess.run([g_optim, self.g_sum],
    feed_dict={
          self.z: batch_z, self.y:batch_labels,
    })
    self.writer.add_summary(summary_str, counter)# Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
    _, summary_str = self.sess.run([g_optim, self.g_sum],
    feed_dict={
           self.z: batch_z, self.y:batch_labels })
    self.writer.add_summary(summary_str, counter)errD_fake = self.d_loss_fake.eval({
          self.z: batch_z, self.y:batch_labels
    })
    errD_real = self.d_loss_real.eval({
          self.inputs: batch_images,self.y:batch_labels
    })
    errG = self.g_loss.eval({
          self.z: batch_z,self.y: batch_labels
    })
    
  • 如果是其他数据:

    else:# Update D network_, summary_str = self.sess.run([d_optim, self.d_sum],feed_dict={
           self.inputs: batch_images, self.z: batch_z })self.writer.add_summary(summary_str, counter)# Update G network_, summary_str = self.sess.run([g_optim, self.g_sum],feed_dict={
           self.z: batch_z })self.writer.add_summary(summary_str, counter)# Run g_optim twice to make sure that d_loss does not go to zero (different from paper)_, summary_str = self.sess.run([g_optim, self.g_sum],feed_dict={
           self.z: batch_z })self.writer.add_summary(summary_str, counter)errD_fake = self.d_loss_fake.eval({
           self.z: batch_z })errD_real = self.d_loss_real.eval({
           self.inputs: batch_images })errG = self.g_loss.eval({
          self.z: batch_z})
    

3.测试模型

python main.py --input_height 96 --input_width 96 \--output_height 48 --output_width 48 \--dataset anime --crop
python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset anime --crop

main.py中的 OPTION 可以设置 0-4 ,在 utils.py 中的函数 visualize() 中可以看到不同的可视化选项,可以自己设置这个OPTION

# Below is codes for visualization
OPTION = 0
visualize(sess, dcgan, FLAGS, OPTION)

4.测试效果

因为默认都是生成到samples这个文件夹,比较乱,我改了一下路径,生成到五个文件夹。

又因为模型训练的程度不够,才一千多不就已经训练了两个半小时了,只能凑合看看。

OPTION = 0:用模型生成一张10*10的图片
OPTION = 1:生成100张10*10的图片,都差不多样子
OPTION = 2:生成100张10*10的图片,都差不多样子
OPTION = 3:生成100张10*10的图片组成的动画
OPTION = 4:生成100张10*10的图片组成的动画,最后汇合到一个gif
  • OPTION = 0

TensorFlow 11——ch08-GAN和DCGAN入门

  • OPTION = 1
    TensorFlow 11——ch08-GAN和DCGAN入门

  • OPTION = 2

    TensorFlow 11——ch08-GAN和DCGAN入门

  • OPTION = 3

    TensorFlow 11——ch08-GAN和DCGAN入门

  • OPTION = 4

TensorFlow 11——ch08-GAN和DCGAN入门

随便放了个gif上来看

TensorFlow 11——ch08-GAN和DCGAN入门

好吧,训练的太少看不出效果

  相关解决方案