我这个是在TF2.0的版本上实现的,废话不多说,相信大家都知道了什么是gan。
或者做个简单的比喻:
生成器网络:生成虚假的图片数据。判别器网络:用来判别虚假图片的网络。两个网络的目的都只有一个,就是打败对方,然后不断地成长,这就是对抗的原理。
还是直接导入相关的工具包
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, layers, optimizers

from PIL import Image
构建相应的生成器和判别器
class Generator(keras.Model):
    """DCGAN generator: maps a latent noise vector to a fake 28x28x1 image.

    A dense layer projects the noise to a 3x3x512 feature map, which a
    stack of transposed convolutions (each followed by batch norm and
    leaky ReLU) upsamples; a final transposed convolution plus tanh
    emits pixel values in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # 3*3*512 units so the dense output can be reshaped to (3, 3, 512).
        self.fc = layers.Dense(3 * 3 * 512)
        # Spatial sizes with 'valid' padding: 3 -> 9 -> 21 -> 24 -> 28.
        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2DTranspose(32, 4, 1, 'valid')
        self.bn3 = layers.BatchNormalization()
        # Final layer: single channel, no batch norm before the tanh.
        self.conv4 = layers.Conv2DTranspose(1, 5, 1, 'valid')

    def call(self, x, training=None):
        """Generate images from latent codes x (assumed (batch, z_dim) — confirm with caller)."""
        h = self.fc(x)
        h = tf.reshape(h, [-1, 3, 3, 512])
        h = tf.nn.leaky_relu(h)
        h = tf.nn.leaky_relu(self.bn1(self.conv1(h), training=training))
        h = tf.nn.leaky_relu(self.bn2(self.conv2(h), training=training))
        h = tf.nn.leaky_relu(self.bn3(self.conv3(h), training=training))
        # tanh squashes the output into [-1, 1], matching the data scaling.
        return tf.tanh(self.conv4(h))
class Discriminator(keras.Model):
    """DCGAN discriminator: maps a 28x28x1 image batch to real/fake logits.

    Three strided convolutions downsample the input (leaky ReLU after
    each; batch norm on the second and third only), then a dense layer
    produces one raw logit per image — no sigmoid is applied here.
    """

    def __init__(self):
        super().__init__()
        # Note: the first conv layer is not followed by batch norm.
        self.conv1 = layers.Conv2D(64, kernel_size=4, strides=2, padding='valid')
        self.conv2 = layers.Conv2D(128, kernel_size=3, strides=2, padding='valid')
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(256, kernel_size=4, strides=2, padding='valid')
        self.bn3 = layers.BatchNormalization()
        self.flatten = layers.Flatten()
        # One output unit: an unnormalised real/fake logit.
        self.fc = layers.Dense(1)

    def call(self, x, training=None):
        """Return a (batch, 1) tensor of raw logits for the images in x."""
        h = tf.nn.leaky_relu(self.conv1(x))
        h = tf.nn.leaky_relu(self.bn2(self.conv2(h), training=training))
        h = tf.nn.leaky_relu(self.bn3(self.conv3(h), training=training))
        logits = self.fc(self.flatten(h))
        return logits
训练判别器网络
def loss_d(g, d, batch_z, batch_x, training):
    """Discriminator loss: real images should score 1, generated images 0.

    Args:
        g: generator model; called as g(batch_z, training).
        d: discriminator model returning raw logits.
        batch_z: batch of latent noise vectors fed to the generator.
        batch_x: batch of real images.
        training: forwarded to both models (affects batch norm behaviour).

    Returns:
        Scalar tensor: BCE(real logits vs 1) + BCE(fake logits vs 0).
    """
    fake_img = g(batch_z, training)
    d_fake_logits = d(fake_img, training)
    d_real_logits = d(batch_x, training)
    # FIX: the original called loss_ones/loss_zeros, which are not defined
    # anywhere in this file (NameError at runtime). Inline the equivalent
    # sigmoid cross-entropy on raw logits instead.
    d_real_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_real_logits), logits=d_real_logits))
    d_fake_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_fake_logits), logits=d_fake_logits))
    return d_real_loss + d_fake_loss
训练生成器网络
def loss_g(g, d, batch_z, training):
    """Generator loss: generated images should be classified as real (1).

    Args:
        g: generator model; called as g(batch_z, training).
        d: discriminator model returning raw logits.
        batch_z: batch of latent noise vectors fed to the generator.
        training: forwarded to both models (affects batch norm behaviour).

    Returns:
        Scalar tensor: BCE(fake logits vs 1).
    """
    fake_img = g(batch_z, training)
    d_fake_logits = d(fake_img, training)
    # FIX: the original called loss_ones, which is not defined anywhere in
    # this file (NameError at runtime). Inline the equivalent sigmoid
    # cross-entropy against all-ones labels instead.
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_fake_logits), logits=d_fake_logits))
对整个网络进行训练
def main():
    """Train the DCGAN on MNIST, alternating discriminator/generator steps."""
    z_dim = 100
    epochs = 100          # NOTE: each iteration consumes one batch, not a full pass
    batch_size = 32
    learning_rate = 0.002
    training = True

    # Load MNIST digits and rescale pixels from [0, 255] to [-1, 1]
    # to match the generator's tanh output range.
    (x_train, _), (_, _) = keras.datasets.mnist.load_data()
    x_train = tf.expand_dims(x_train, axis=3)
    x_train = 2 * tf.cast(x_train, dtype=tf.float32) / 255. - 1

    train_db = tf.data.Dataset.from_tensor_slices(x_train).batch(batch_size)
    # Sanity-check one batch: shape and value range.
    sample = next(iter(train_db))
    print(sample.shape, tf.reduce_max(sample).numpy(), tf.reduce_min(sample).numpy())
    # Repeat forever so next() never exhausts the iterator.
    db_iter = iter(train_db.repeat())

    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 28, 28, 1))

    g_optimizer = optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    for epoch in range(epochs):
        # Fresh latent noise plus the next batch of real images.
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # Discriminator step: gradients flow only into its variables.
        with tf.GradientTape() as tape:
            d_loss = loss_d(generator, discriminator, batch_z, batch_x, training)
        d_grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))

        # Generator step, reusing the same noise batch.
        with tf.GradientTape() as tape:
            g_loss = loss_g(generator, discriminator, batch_z, training)
        g_grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))

        if epoch % 10 == 0:
            print(epoch, 'd_loss: ',d_loss, "g_loss: ", g_loss)
最后,训练结果就留给大家自己去验证了:我没能跑完完整的训练,因为我的电脑只有CPU,没有GPU,训练起来异常缓慢。
如果没有错误,整体的设想的结果应该是下面这个样子的
生成的噪声图
训练n多遍之后的结果应该理论上是这个样子的
转载请注明原文地址:https://blackberry.8miu.com/read-597.html