一、什么是对抗网络:

生成式对抗网络(Generative adversarial network, GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

二、对抗网络能干什么:

(1)数据生成,主要指图像生成。图像生成:基于训练的模型,生成类似于训练集的新的图片。

(2)图像数据增强:增强图像中的有用信息,改善图像的视觉效果。

(3)图像外修复:从受限输入图像生成具有语义意义结构的新的视觉和谐内容来扩展图像边界。

(4)图像超分辨率:由一幅低分辨率图像或图像序列恢复出高分辨率图像。

(5)图像风格迁移:通过某种方法,把图像从原风格转换到另一个风格,同时保证图像内容没有变化。

(6)语音合成

意义:GAN网络可以帮助我们建立模型,相比于在已有模型上进行参数更新的传统网络,更具研究价值。

三、对抗网络由哪些部分组成:

(1)生成器(Generator):生成器要不断优化自己生成的数据让判别器判断不出来。

(2)判别器(Discriminator):判别器要进行优化让自己判断得更准确。

二者关系形成对抗,因此叫生成式对抗网络。

接下来我简述下,对抗网络的过程是怎么走的,这是重点:

先给大家说下什么是BCE_LOSS(二元交叉熵):

他是一个专注与做二分类任务的损失函数,目的是求损失,梯度更新,在这里,里面weight(权重参数)不用写。建议大家去搜下这损失函数。

第一步:

先生成一组标签分别是0和1,稍后用作BCE_LOSS损失的输入。

第二步:

训练判别器

会先把真实数据送入判别模型,会返回一个值,然后我们把这个值,和真实值打的标签1求BCE_LOSS损失。

然后把假的的数据(噪音)送入生成模型,也会返回一个值,我们再把这个值,和假的标签0求BECE_LOSS损失。

最后把真实值损失和假数据的损失加到一起,一起求梯度,进行更新。

第三步:

训练生成器

因为我们在判别阶段,已经更新了生成器的参数,所以可以直接再次更新(其实就是参数共享)。

最后:可以根据对抗效果设置迭代次数。

可以参考图片理解,如果不行可以翻译看代码

我下面有一个用对抗网络生成图片的代码,大家可以参考参考:

可以直接复制到pycharm,需要改下手写体数据路径:

 
import os import torch import torch.nn as nn import torchvision from
torchvision import transforms from torchvision.utils import save_image from
torch.utils.data import DataLoader # 设备配置 device = torch.device('cuda' if
torch.cuda.is_available() else 'cpu') # 优选超参数 latent_size = 64 hidden_size =
256 image_size = 784 num_epochs = 200 bathc_size = 100 sample_dir = 'samples' #
如果不存在目录创建一个目录 if not os.path.exists(sample_dir): os.makedirs(sample_dir) # 图像处理
transform = transforms.Compose([ transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5]) ]) # 加载手写体数据集 mnist =
torchvision.datasets.MNIST(
root=r'C:\Users\qiuhongsen\PycharmProjects\pythonProject\NLP--2\My self
dai\tensorflow1\MNIST_data', train=True, transform=transform, download=True) #
数据加载器 data_loader = DataLoader(dataset=mnist, batch_size=bathc_size,
shuffle=True) # 判别器 D = nn.Sequential( nn.Linear(image_size, hidden_size),
nn.LeakyReLU(0.2), nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2),
nn.Linear(hidden_size, 1), nn.Sigmoid() ) # 生成器 G = nn.Sequential(
nn.Linear(latent_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size,
hidden_size), nn.ReLU(), nn.Linear(hidden_size, image_size), nn.Tanh() ) #
二分类交叉熵损失函数 loss = nn.BCELoss() d_optimizer = torch.optim.Adam(D.parameters(),
lr=0.0002) g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002) def
denorm(x): out = (x + 1) / 2 return out.clamp(0, 1) # 优化器初始化 def reset_grad():
d_optimizer.zero_grad() g_optimizer.zero_grad() # 开始训练 total_step =
len(data_loader) for epoch in range(num_epochs): for i, (images, _) in
enumerate(data_loader): images = images.reshape(bathc_size, -1).to(device) #
创建标签,稍后用作BCE丢失的输入 real_labels = torch.ones(bathc_size, 1).to(device)
fake_labels = torch.zeros(bathc_size, 1).to(device) #
================================================================== # # 训练判断器 #
# ================================================================== # #
使用真实计算图计算BCE_LOSS # 损失的第二项总是1因为真实的标签是1 oupouts = D(images) d_loss_real =
loss(oupouts, real_labels) real_score = oupouts # 使用假的计算图计算BCE_LOSS #
损失的第一项总是0,因为假的标签是0 z = torch.randn(bathc_size, latent_size).to(device)
fake_images = G(z) oupout = D(fake_images) d_loss_fake = loss(oupout,
fake_labels) fake_score = oupout # backprop 和 optimizer d_loss = d_loss_real +
d_loss_fake reset_grad() d_loss.backward() d_optimizer.step() #
================================================================== # # 训练生成器 #
# ================================================================== # #
用假图计算损失 z = torch.randn(bathc_size, latent_size).to(device) fake_images = G(z)
oupouts = D(fake_images) g_loss = loss(oupout, real_labels) reset_grad()
g_loss.backward() g_optimizer.step() if (i + 1) % 200 == 0:
print('Epoch[{}/{}],Step[{}/{}],d_loss{:.4f},g_loss{:.4f},D(x):{.2f},D(G(z)):{.2f}'
.format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))

这是真实的图片↑:

※大家有兴趣可以跑跑试试,迭代200次。

下面是迭代50词的图片↓:

 

技术
下载桌面版
GitHub
Microsoft Store
SourceForge
Gitee
百度网盘(提取码:draw)
云服务器优惠
华为云优惠券
京东云优惠券
腾讯云优惠券
阿里云优惠券
Vultr优惠券
站点信息
问题反馈
邮箱:[email protected]
吐槽一下
QQ群:766591547
关注微信