The Principles Behind Style Transfer and a TensorFlow Implementation

http://blog.csdn.net/qq_25737169/article/details/79192211

Preface


This article walks through a small TensorFlow demo that implements style transfer; after reading it you should be able to build your own. The algorithm reproduced here comes from the paper
Perceptual Losses for Real-Time Style Transfer and Super-Resolution.

GitHub code: https://github.com/LDOUBLEV/style_transfer-perceptual_loss — if you find it useful, please give it a star.

This article is organized as follows:
Section 1: the principles behind deep-learning-based style transfer;
Section 2: a detailed walkthrough of the style transfer code;
Section 3: summary.

Image style transfer means transferring the style of image A onto image B to obtain a new image, call it new B, which contains both the content of image B and the style of image A, as shown below:
[Figure: from left to right, image A, image B, and image new B]

This article focuses on the principle and implementation of deep-learning-based style transfer. The tools used are:

  • Framework: TensorFlow 1.4.1
  • Language: Python 2.7
  • System: Ubuntu 16.04

Note: other configurations work as well; if you run into problems, feel free to comment or message me.

A few of the final results:

Original images:
[Figure: original content images]
Images after style transfer. The top-right one is clearly over-stylized; this can be adjusted by changing the weight on style_loss:
[Figure: stylized results]
[Figure: stylized results]
[Figure: stylized results] The top-left one is the result I am happiest with.

Section 1: The principles behind deep learning for style transfer


1.1 A brief overview of the principle

Deep learning has permeated everywhere, and nowhere more visibly than in computer vision: image classification, recognition, localization, super-resolution, image-to-image translation, style transfer, captioning and more can all be done with deep learning. The technique behind all of this can be summed up in one sentence: convolutional networks are extremely good at extracting image features.
The success of style transfer algorithms rests mainly on the following two points:

  1. If two images are fed through a pretrained classification network and the Euclidean distance between the extracted high-level features is small, then the two images are similar in content.
  2. If two images are fed through a pretrained classification network and the extracted low-level features are close in value, then the two images are similar in style; in other words, two images have similar style exactly when the difference of their features' Gram matrices has a small Frobenius norm.

Based on these two observations, a suitable loss function can be designed to optimize the network.
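To make the first point concrete, here is a minimal sketch (assuming feat_a and feat_b are feature maps already extracted from the same layer of a pretrained network such as VGG-16; this mirrors, but is not copied from, the content_loss computation that appears later in the repo code):

import tensorflow as tf

def content_distance(feat_a, feat_b):
    # feat_a, feat_b: feature maps of shape [batch, height, width, channels],
    # taken from the same layer of a pretrained classification network.
    # Squared Euclidean distance (tf.nn.l2_loss includes a factor of 1/2),
    # normalized by the number of elements in the feature map.
    return tf.nn.l2_loss(feat_a - feat_b) / tf.to_float(tf.size(feat_a))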

1.2 Interpreting the principle

Deep convolutional classification networks are excellent feature extractors, and the features extracted at different layers carry different meanings. Every trained network can be viewed as a good feature extractor. Moreover, a deep network is composed of layer upon layer of nonlinear functions, so it can be viewed as one complex multivariate nonlinear function that maps the input image to an output. It is therefore perfectly reasonable to use a trained deep network as a loss function calculator.

The Gram matrix has the mathematical form G_j(x) = A A^T, where A is the feature map at layer j reshaped into a matrix.
The Gram matrix is essentially a matrix of inner products: in the style transfer algorithm it computes the uncentered covariance between feature maps. A feature map encodes image features, each value indicating the strength of a feature, and the Gram matrix captures the correlations between those features. The Gram matrix can therefore be used to represent the style of an image, and differences in style can be measured by comparing Gram matrices.
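As a minimal sketch of how the Gram matrix and the style distance could be computed from a 4-D feature map in TensorFlow (an illustration only, not necessarily identical to the repo's model.styleloss helper):

import tensorflow as tf

def gram_matrix(feat):
    # feat: feature map of shape [batch, height, width, channels]
    shape = tf.shape(feat)
    b, h, w, c = shape[0], shape[1], shape[2], shape[3]
    feat = tf.reshape(feat, tf.stack([b, h * w, c]))
    # [batch, channels, channels], normalized by C*H*W as in the paper
    gram = tf.matmul(feat, feat, transpose_a=True)
    return gram / tf.to_float(h * w * c)

def style_distance(feat_gen, feat_style):
    # Squared Frobenius norm of the difference of the two Gram matrices
    return tf.reduce_sum(tf.square(gram_matrix(feat_gen) - gram_matrix(feat_style)))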

1.3 Reading the paper

The paper covered here is Perceptual Losses for Real-Time Style Transfer and Super-Resolution.
Straight to the figure:
[Figure: overall framework from the paper — image transform network followed by a fixed loss network]
The framework has two parts: an image transform network T (image transform net) and a pretrained loss-computation network, VGG-16 (loss network). The image transform network T takes the content image x as input and outputs the stylized image y'. Then the content image y_c (which is just x), the style image y_s, and y' are fed into VGG-16 to extract features, and the losses are computed as follows:
Content loss: l_feat^{φ,j}(y', y_c) = (1 / (C_j H_j W_j)) ||φ_j(y') − φ_j(y_c)||_2^2, where φ denotes the deep convolutional network VGG-16 and φ_j the activations of its j-th layer, of shape C_j × H_j × W_j.

Style loss: l_style^{φ,j}(y', y_s) = ||G_j^φ(y') − G_j^φ(y_s)||_F^2, where G is the Gram matrix, computed as follows:

G_j^φ(x)_{c,c'} = (1 / (C_j H_j W_j)) Σ_{h=1}^{H_j} Σ_{w=1}^{W_j} φ_j(x)_{h,w,c} φ_j(x)_{h,w,c'}

The total loss is defined as: Loss_total = γ_1 * l_feat + γ_2 * l_style
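Putting the pieces together, here is a hedged sketch of the overall objective using the content_distance and style_distance helpers sketched above; the tensor names, the layer choices, and the weights γ_1, γ_2 are illustrative assumptions, not values taken from the paper or the repo.

# phi_gen / phi_content are feature maps of y' and y_c from one VGG layer;
# phi_gen_layers / phi_style_layers are lists of feature maps from several VGG layers.
gamma_1, gamma_2 = 1.0, 100.0  # hypothetical weights; tune to control stylization strength
content_loss = content_distance(phi_gen, phi_content)
style_loss = sum(style_distance(g, s) for g, s in zip(phi_gen_layers, phi_style_layers))
total_loss = gamma_1 * content_loss + gamma_2 * style_loss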

The image transform network T is defined as in the figure below:
[Figure: image transform network architecture]

The network consists of three convolutional layers, followed by five residual blocks, then two upsampling stages (nearest-neighbor interpolation), and a final convolutional layer. The first and last layers use 9×9 kernels; all the others use 3×3. Each residual block contains two convolutional layers.

Section 2: Code walkthrough


The implementation is based mainly on TensorFlow's slim module, which is well encapsulated and convenient to call. The walkthrough covers the network structure, the loss function, and the training part in turn.

2.1 Network structure

slim = tf.contrib.slim

# Define the default convolution settings passed to slim via arg_scope
def arg_scope(weight_decay=0.0005):
    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.conv2d_transpose],
                        activation_fn=None,
                        weights_regularizer=slim.l2_regularizer(weight_decay),
                        biases_initializer=tf.zeros_initializer()):
        with slim.arg_scope([slim.conv2d, slim.conv2d_transpose], padding='SAME') as arg_sc:
            return arg_sc

Next comes the image transform network itself, modeled after the figure above. There is one trick: the input image is padded before being fed to the network, and the padded border is cropped off afterwards, to avoid edge artifacts in the stylized output (a small standalone sketch of this pad-then-crop pattern follows the network code below).

def gen_net(imgs, reuse, name, is_train=True):
    # Reflect-pad the input so the border can be cropped off after the network
    imgs = tf.pad(imgs, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT')
    with tf.variable_scope(name, reuse=reuse) as vs:
        # encoder : three conv layers
        out1 = slim.conv2d(imgs, 32, [9, 9], scope='conv1')
        out1 = relu(instance_norm(out1))
        out2 = slim.conv2d(out1, 64, [3, 3], stride=2, scope='conv2')
        out2 = instance_norm(out2)
        # out2 = relu(img_scale(out2, 0.5))
        out2 = slim.conv2d(out2, 128, [3, 3], stride=2, scope='conv3')
        out2 = instance_norm(out2)
        # out2 = relu(img_scale(out2, 0.5))
        # transform: residual blocks
        out3 = res_module(out2, 128, name='residual1')
        out3 = res_module(out3, 128, name='residual2')
        out3 = res_module(out3, 128, name='residual3')
        out3 = res_module(out3, 128, name='residual4')
        # decoder: upsample + conv
        out4 = img_scale(out3, 2)
        out4 = slim.conv2d(out4, 64, [3, 3], stride=1, scope='conv4')
        out4 = relu(instance_norm(out4))
        # out4 = img_scale(out4, 128)
        out4 = img_scale(out4, 2)
        out4 = slim.conv2d(out4, 32, [3, 3], stride=1, scope='conv5')
        out4 = relu(instance_norm(out4))
        # out4 = img_scale(out4, 256)
        out = slim.conv2d(out4, 3, [9, 9], scope='conv6')
        out = tf.nn.tanh(instance_norm(out))
        variables = tf.contrib.framework.get_variables(vs)
        # map the tanh output from [-1, 1] back to [0, 255]
        out = (out + 1) * 127.5
        height = out.get_shape()[1].value  # if is_train else tf.shape(out)[0]
        width = out.get_shape()[2].value   # if is_train else tf.shape(out)[1]
        # crop off the padded border
        out = tf.image.crop_to_bounding_box(out, 10, 10, height - 20, width - 20)
        # out = tf.reshape(out, imgs_shape)
        return out, variables
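As promised above, here is a small standalone illustration of the pad-then-crop trick (the 256×256 input size is an arbitrary assumption for the example; this is not code from the repo):

import tensorflow as tf

x = tf.placeholder(tf.float32, [1, 256, 256, 3])           # a content image batch
padded = tf.pad(x, [[0, 0], [10, 10], [10, 10], [0, 0]],    # reflect 10 px on each border
                mode='REFLECT')                              # shape becomes [1, 276, 276, 3]
out = padded                                                 # stand-in for the transform network's output
h = out.get_shape()[1].value
w = out.get_shape()[2].value
cropped = tf.image.crop_to_bounding_box(out, 10, 10, h - 20, w - 20)  # back to [1, 256, 256, 3]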

Here instance_norm is the normalization part [5], res_module is the residual block, and img_scale is the resampling part; a scale factor of 2 means upsampling, enlarging the feature map by a factor of 2:

def img_scale(x, scale):
    height = x.get_shape()[1].value
    width = x.get_shape()[2].value
    try:
        out = tf.image.resize_nearest_neighbor(x, size=(height * scale, width * scale))
    except:
        out = tf.image.resize_images(x, size=[height * scale, width * scale])
    return out


def res_module(x, outchannel, name):
    with tf.variable_scope(name_or_scope=name):
        out1 = slim.conv2d(x, outchannel, [3, 3], stride=1, scope='conv1')
        out1 = relu(out1)
        out2 = slim.conv2d(out1, outchannel, [3, 3], stride=1, scope='conv2')
        out2 = relu(out2)
        return x + out2


def instance_norm(x):
    epsilon = 1e-9
    # per-sample, per-channel mean and variance over the spatial dimensions
    mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
    return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))

2.2 Building the graph

The flow of this part: read the training data (the COCO dataset) → read the style image → feed them into the image transform network to compute the transformed image gen_img → feed the original images, the style image, and the transformed image into VGG to compute the loss → load the VGG weights.

def build_model(self):
    # data_path = '/home/liu/Tensorflow/BEGAN/Data/celeba/img_align_celeba'
    data_path = '/home/liu/Downloads/train2014'
    # Load the training data (COCO dataset)
    imgs = load_data.get_loader(data_path, self.batch_size, self.img_size)
    # Load the style image
    style_imgs = load_style_img()

    with slim.arg_scope(model.arg_scope()):
        # Image transform network
        gen_img, variables = model.gen_net(imgs, reuse=False, name='transform')

        with slim.arg_scope(vgg.vgg_arg_scope()):
            # Preprocess the generated images for VGG
            gen_img_processed = [load_data.img_process(image, True)
                                 for image in tf.unstack(gen_img, axis=0, num=self.batch_size)]
            # f1..f4 are the feature maps output by each VGG conv block;
            # exclude lists the VGG variables that should not be restored
            f1, f2, f3, f4, exclude = vgg.vgg_16(tf.concat([gen_img_processed, imgs, style_imgs], axis=0))
            gen_f, img_f, _ = tf.split(f3, 3, 0)
            # Compute the content loss and the style loss
            content_loss = tf.nn.l2_loss(gen_f - img_f) / tf.to_float(tf.size(gen_f))
            style_loss = model.styleloss(f1, f2, f3, f4)

            # Load the pretrained VGG weights
            vgg_model_path = '/home/liu/Tensorflow-Project/temp/model/vgg_16.ckpt'
            vgg_vars = slim.get_variables_to_restore(include=['vgg_16'], exclude=exclude)
            # vgg_init_var = slim.get_variables_to_restore(include=['vgg_16/fc6'])
            init_fn = slim.assign_from_checkpoint_fn(vgg_model_path, vgg_vars)
            init_fn(self.sess)
            # tf.initialize_variables(var_list=vgg_init_var)
            print 'vgg weights load done'

    self.gen_img = gen_img
    self.global_step = tf.Variable(0, name="global_step", trainable=False)
    self.content_loss = content_loss
    # The factor 100 is set arbitrarily; adjust it to control the strength of the stylization
    self.style_loss = style_loss * 100
    self.loss = self.content_loss + self.style_loss
    self.opt = tf.train.AdamOptimizer(0.0001).minimize(self.loss,
                                                       global_step=self.global_step,
                                                       var_list=variables)

    all_var = tf.global_variables()
    # init_var = [v for v in all_var if 'beta' in v.name or 'global_step' in v.name or 'Adam' in v.name]
    init_var = [v for v in all_var if 'vgg_16' not in v.name]
    init = tf.variables_initializer(var_list=init_var)
    self.sess.run(init)

    self.save = tf.train.Saver(var_list=variables)

The training code:

def train(self):
    print ('start training')
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop():
            # start_time = time.time()
            _, loss, step, cl, sl = self.sess.run([self.opt, self.loss, self.global_step,
                                                   self.content_loss, self.style_loss])
            if step % 100 == 0:
                gen_img = self.sess.run(self.gen_img)
                if not os.path.exists('gen_img'):
                    os.mkdir('gen_img')
                save_img.save_images(gen_img, './gen_img/{0}.jpg'.format(step / 100))
                print ('[{}/40000], loss:{}, content:{}, style:{}'.format(step, loss, cl, sl))
            if step % 2000 == 0:
                if not os.path.exists('model_saved_s'):
                    os.mkdir('model_saved_s')
                self.save.save(self.sess, './model_saved_s/wave{}.ckpt'.format(step / 2000))
            # Stop after 40000 iterations, roughly 2 epochs
            if step >= 40000:
                break
    except tf.errors.OutOfRangeError:
        self.save.save(self.sess, os.path.join(os.getcwd(), 'fast-style-model.ckpt-done'))
    finally:
        coord.request_stop()
    coord.join(threads)
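For completeness, a hedged sketch of how a driver script might tie these methods together; the class name Model and its constructor arguments are assumptions for illustration only, so check the repo for the actual entry point:

import tensorflow as tf

if __name__ == '__main__':
    with tf.Session() as sess:
        # Hypothetical constructor; the real class and arguments live in the repo's training script.
        trainer = Model(sess=sess, batch_size=4, img_size=256)
        trainer.build_model()
        trainer.train()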

Summary:


The paper reproduced here still has some shortcomings. For example, a model trained on one style image can only stylize in that one style, so applying many different styles requires training a separate model for each. This limitation has been addressed in later papers, which I will try to reproduce when I have time.


