博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
CycleGan总结及简易复现
阅读量:4131 次
发布时间:2019-05-25

本文共 4830 字,大约阅读时间需要 16 分钟。

CycleGan总结及代码简易复现

CycleGan论文地址:

简介

2017年以前的GAN都是通过配对好的一组图片去训练模型的,但是想要获得大量的成对图片比较难,而CycleGan是无监督生成对抗网络,其实是做的是一个domain adaption的工作,可以通过一些不配对的两组图片利用生成器-判别器模型和计算它的循环损失实现领域的自适应。即把原始图像(如马)导入生成器G1(马→斑马)生成目标图像(斑马),再把目标图像当作F(斑马→马)的输入,计算生成新的图像(马~)与最初的原始图像(马)的差别,即损失,让该损失尽可能地小即能确保生成器不会生成与原始图像无关的图片。如下图所示:

在这里插入图片描述
所以总的损失函数就是L = 两个生成器的损失(G1_loss + F_loss)+两个循环损失(cycle1_loss + cycle2_loss)+ 两个identity损失(即往G1输入斑马的图片,计算生成后的斑马图片与输入的真实斑马图片的差距,同理往F输入马的图片,且此项有时可以省去来提高计算效率)
生成器的损失用MSE,循环损失与identity损失用L1函数。

拓展: 回归损失函数的对比:L1 loss, L2 loss(MSE)以及Smooth L1 Loss的对比

L1 loss函数:指的是模型预测值f(x)和真实值y之间距离的均值,公式为:在这里插入图片描述

图像:在这里插入图片描述
由图像可知:
①当损失很小时,其梯度比较大,不利于模型的训练和收敛
②无论对于什么样的输入值,其梯度都是固定的,所以不会产生梯度爆炸的问题,也就是对偏离真实样本的比较大的值不怎么敏感,有利于模型的稳定。
③在y-f(x)= 0 处不可导,可能影响收敛

L2 loss函数:模型预测值f(x) 与真实样本值y 之间差值平方的均值。

公式:在这里插入图片描述
图像:在这里插入图片描述
由图可知:
①函数在所有输入范围内都是连续的
②随着损失的减小,梯度也在减小,这有利于模型的快速收敛
③对离群点比较敏感,受其影响比较大

Smooth L1 loss函数:

在Faster-Rcnn和SSD中都用到了该函数。
公式:
x为真实值与预测值的差值

图像:在这里插入图片描述

可以看出Smooth loss函数为前两者的结合,取其精华去其糟粕。

Smooth L1的优点;

①相比于L1损失函数,可以收敛得更快。
②相比于L2损失函数,对离群点、异常值不敏感,梯度变化相对更小,训练时不容易跑飞。

CycleGan网络结构

在这里插入图片描述

生成器的网络可简化为:

一个卷积块

两个下采样块
九个残差模块
2个上采样模块
一个卷积块(output_channel = 3)
经过tanh模块(将特征图的值归为-1至1之间)

代码如下:

import torchimport torch.nn as nnclass ConvBlock(nn.Module):    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):        super().__init__()        self.conv = nn.Sequential(            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)            if down            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),            nn.InstanceNorm2d(out_channels),            nn.ReLU(inplace=True) if use_act else nn.Identity()        )    def forward(self, x):        return self.conv(x)class ResidualBlock(nn.Module):    def __init__(self, channels):        super().__init__()        self.block = nn.Sequential(            ConvBlock(channels, channels, kernel_size=3, padding=1),            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),        )    def forward(self, x):        return x + self.block(x)class Generator(nn.Module):    def __init__(self, img_channels, num_features = 64, num_residuals=9):        super().__init__()        self.initial = nn.Sequential(            nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),            nn.InstanceNorm2d(num_features),            nn.ReLU(inplace=True),        )        self.down_blocks = nn.ModuleList(            [                ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),                ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),            ]        )        self.res_blocks = nn.Sequential(            *[ResidualBlock(num_features*4) for _ in range(num_residuals)]        )        self.up_blocks = nn.ModuleList(            [                ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),                ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),            ]        )        self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")    def forward(self, x):        x = self.initial(x)        for layer in self.down_blocks:            x = layer(x)        x = self.res_blocks(x)        for layer in self.up_blocks:            x = layer(x)        return torch.tanh(self.last(x))

判别器的网络

同理可得:总共5层卷积层,目标是生成特征图里面的值为0-1之间,方便待会跟生成器网络生成的图进行损失计算。代码如下:

import torchimport torch.nn as nnclass Block(nn.Module):    def __init__(self, in_channels, out_channels, stride):        super().__init__()        self.conv = nn.Sequential(            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),            nn.InstanceNorm2d(out_channels),            nn.LeakyReLU(0.2, inplace=True),        )    def forward(self, x):        return self.conv(x)class Discriminator(nn.Module):    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):        super().__init__()        self.initial = nn.Sequential(            nn.Conv2d(                in_channels,                features[0],                kernel_size=4,                stride=2,                padding=1,                padding_mode="reflect",            ),            nn.LeakyReLU(0.2, inplace=True),        )        layers = []        in_channels = features[0]        for feature in features[1:]:            layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2))            in_channels = feature        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))        self.model = nn.Sequential(*layers)    def forward(self, x):        x = self.initial(x)        return torch.sigmoid(self.model(x))

训练模块和载入数据集的模块可以仿照原论文进行编写。

转载地址:http://utfvi.baihongyu.com/

你可能感兴趣的文章
JS中的call、apply、bind方法
查看>>
JavaScript中typeof
查看>>
Javascript闭包(Closure)
查看>>
CSS3中的变形与动画(上)【2D】 Transform 和 Transition
查看>>
CSS3中的变形与动画(下)
查看>>
CSS3 布局样式
查看>>
CSS3 Media Queries 与Responsive 设计
查看>>
CSS3的flexbox布局
查看>>
前端面试题系列
查看>>
HTML优化技巧
查看>>
Javascript模块化编程
查看>>
Bootstrap(五) 导航条、分页导航
查看>>
Bootstrap(六) 其它内置组件
查看>>
HTML DOM querySelector() 方法
查看>>
js 事件冒泡和事件捕获的区别
查看>>
Web前端性能优化(一)减少Http请求
查看>>
Web前端性能优化(二)使用内容分发网络
查看>>
Web前端性能优化(四)压缩组件
查看>>
Web前端性能优化(五)网站样式和脚本
查看>>
Web前端性能优化(六)减少DNS查找、避免重定向
查看>>