Source code for stylish.transform

# :coding: utf-8

"""The image transformation network is a deep residual convolutional neural
network parameterized by weights.

The network body consists of five residual blocks. All non-residual
convolutional layers are followed by instance normalization and a ReLU
non-linearity, with the exception of the output layer, which instead uses a
scaled "tanh" to ensure that the output image has pixels in the range [0, 255].
Other than the first and last layers which use 9 × 9 kernels, all convolutional
layers use 3 × 3 kernels.

.. seealso::

    Johnson et al. (2016). Perceptual losses for real-time style transfer and
super-resolution. `CoRR, abs/1603.08155
    <https://arxiv.org/abs/1603.08155>`_.

.. seealso::

    Ulyanov et al. (2017). Instance Normalization: The Missing Ingredient for
    Fast Stylization. `CoRR, abs/1607.08022
    <https://arxiv.org/abs/1607.08022>`_.

"""

import tensorflow as tf


def network(input_node):
    """Apply the image transformation network.

    The network is built onto the current :term:`Tensorflow` graph and the
    last node of the graph is returned.

    Example::

        >>> g = tf.Graph()
        >>> with g.as_default(), tf.Session() as session:
        ...     ...
        ...     network(input_node)

    *input_node* should be a 4-D Tensor representing a batch of images. It
    will be the input of the network.

    """
    node = conv2d_layer(
        input_node, "conv1",
        in_channels=3, out_channels=32, kernel_size=9, strides=1,
        activation=True
    )
    node = conv2d_layer(
        node, "conv2",
        in_channels=32, out_channels=64, kernel_size=3, strides=2,
        activation=True
    )
    node = conv2d_layer(
        node, "conv3",
        in_channels=64, out_channels=128, kernel_size=3, strides=2,
        activation=True
    )

    node = residual_block(
        node, "residual_block_1",
        in_channels=128, out_channels=128, kernel_size=3, strides=1
    )
    node = residual_block(
        node, "residual_block_2",
        in_channels=128, out_channels=128, kernel_size=3, strides=1
    )
    node = residual_block(
        node, "residual_block_3",
        in_channels=128, out_channels=128, kernel_size=3, strides=1
    )
    node = residual_block(
        node, "residual_block_4",
        in_channels=128, out_channels=128, kernel_size=3, strides=1
    )
    node = residual_block(
        node, "residual_block_5",
        in_channels=128, out_channels=128, kernel_size=3, strides=1
    )

    node = conv2d_transpose_layer(
        node, "de_conv1",
        in_channels=64, out_channels=128, kernel_size=3, strides=2,
        activation=True
    )
    node = conv2d_transpose_layer(
        node, "de_conv2",
        in_channels=32, out_channels=64, kernel_size=3, strides=2,
        activation=True
    )
    node = conv2d_layer(
        node, "de_conv3",
        in_channels=32, out_channels=3, kernel_size=9, strides=1
    )

    # Scaled "tanh" output layer: tanh(x) * 150 + 127.5 maps the activations
    # onto the [0, 255] pixel range.
    output = tf.add(tf.nn.tanh(node) * 150, 255.0 / 2)
    return output
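# A minimal usage sketch, not part of the original module: feed a batch of
# RGB images through :func:`network` and fetch the stylized result. The
# helper name, the placeholder name and the 256 x 256 input size are
# illustrative assumptions.
def _example_network_usage(batch):
    """Run *batch*, a ``[N, 256, 256, 3]`` NumPy array, through the network."""
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as session:
        input_node = tf.placeholder(
            tf.float32, shape=(None, 256, 256, 3), name="input_images"
        )
        output_node = network(input_node)

        # Weights are freshly initialized here; a real run would restore a
        # trained checkpoint instead.
        session.run(tf.global_variables_initializer())
        return session.run(output_node, feed_dict={input_node: batch})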
def residual_block(
    input_node, operation_name, in_channels, out_channels, kernel_size, strides
):
    """Apply a residual block to the network.

    *input_node* will be the input of the block.

    *operation_name* should be the name of the block's scope within the graph.

    *in_channels* should be the number of channels at the input of the block.

    *out_channels* should be the number of channels at the output of the
    block.

    *kernel_size* should be the width and height of the convolution matrix
    used within the block.

    *strides* should indicate the stride of the sliding window for each
    dimension of *input_node*.

    """
    with tf.name_scope(operation_name):
        node = conv2d_layer(
            input_node, "rb{}_conv1".format(operation_name[-1]),
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, strides=strides
        )
        node = tf.nn.relu(node)
        node = conv2d_layer(
            node, "rb{}_conv2".format(operation_name[-1]),
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, strides=strides
        )

        # Skip connection: the block learns a residual on top of its input.
        return input_node + node
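# Illustrative check (hypothetical helper, not in the original module): with
# a stride of 1 and matching channel counts, a residual block preserves the
# input shape, so the skip connection ``input_node + node`` is well defined.
def _example_residual_block_shape():
    graph = tf.Graph()
    with graph.as_default():
        feature_map = tf.zeros([1, 64, 64, 128])
        block = residual_block(
            feature_map, "residual_block_1",
            in_channels=128, out_channels=128, kernel_size=3, strides=1
        )
        # Both shapes are (1, 64, 64, 128).
        return feature_map.get_shape(), block.get_shape()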
def conv2d_layer(
    input_node, operation_name, in_channels, out_channels, kernel_size,
    strides, activation=False
):
    """Apply a 2-D convolution layer to the network.

    *input_node* will be the input of the layer.

    *operation_name* should be the name of the layer's scope within the graph.

    *in_channels* should be the number of channels at the input of the layer.

    *out_channels* should be the number of channels at the output of the
    layer.

    *kernel_size* should be the width and height of the convolution matrix
    used within the layer.

    *strides* should indicate the stride of the sliding window for each
    dimension of *input_node*.

    *activation* should indicate whether a 'relu' node should be added after
    the convolution layer.

    """
    with tf.name_scope(operation_name):
        weights_shape = [kernel_size, kernel_size, in_channels, out_channels]
        weights_init = tf.Variable(
            tf.truncated_normal(weights_shape, stddev=0.1, seed=1),
            dtype=tf.float32, name="weights"
        )
        tf.summary.histogram("weights", weights_init)

        strides_shape = [1, strides, strides, 1]

        node = tf.nn.conv2d(
            input_node, weights_init, strides_shape, padding="SAME"
        )
        node = instance_normalization(node, out_channels)

        if activation:
            node = tf.nn.relu(node)
            tf.summary.histogram("activation", node)

        return node
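# Hypothetical sketch, not part of the original module: with "SAME" padding,
# a stride of 2 halves the spatial dimensions while the channel count follows
# *out_channels*, mirroring the "conv2" downsampling call in :func:`network`.
def _example_conv2d_layer_shape():
    graph = tf.Graph()
    with graph.as_default():
        feature_map = tf.zeros([1, 256, 256, 32])
        node = conv2d_layer(
            feature_map, "conv2",
            in_channels=32, out_channels=64, kernel_size=3, strides=2,
            activation=True
        )
        # Shape is (1, 128, 128, 64).
        return node.get_shape()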
def conv2d_transpose_layer(
    input_node, operation_name, in_channels, out_channels, kernel_size,
    strides, activation=False
):
    """Apply a transposed 2-D convolution layer to the network.

    *input_node* will be the input of the layer.

    *operation_name* should be the name of the layer's scope within the graph.

    Following the transposed filter layout expected by
    :func:`tf.nn.conv2d_transpose`, *in_channels* should be the number of
    channels at the output of the layer and *out_channels* the number of
    channels at its input.

    *kernel_size* should be the width and height of the convolution matrix
    used within the layer.

    *strides* should indicate the stride of the sliding window for each
    dimension of *input_node*.

    *activation* should indicate whether a 'relu' node should be added after
    the convolution layer.

    """
    with tf.name_scope(operation_name):
        weights_shape = [kernel_size, kernel_size, in_channels, out_channels]
        weights_init = tf.Variable(
            tf.truncated_normal(weights_shape, stddev=0.1, seed=1),
            dtype=tf.float32, name="weights"
        )
        tf.summary.histogram("weights", weights_init)

        shape = tf.shape(input_node)
        strides_shape = [1, strides, strides, 1]

        # Up-sample: each spatial dimension grows by the stride factor.
        new_rows = tf.multiply(shape[1], strides)
        new_columns = tf.multiply(shape[2], strides)
        new_shape = [shape[0], new_rows, new_columns, in_channels]
        tf_shape = tf.stack(new_shape)

        node = tf.nn.conv2d_transpose(
            input_node, weights_init, tf_shape, strides_shape, padding="SAME"
        )
        node = instance_normalization(node, in_channels)

        if activation:
            node = tf.nn.relu(node)
            tf.summary.histogram("activation", node)

        return node
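# Hypothetical sketch, not part of the original module: the transposed layer
# doubles the spatial dimensions with a stride of 2. Note the argument order
# mirrors the transposed filter layout: *in_channels* is the channel count of
# the output tensor and *out_channels* that of the input, as in the
# "de_conv1" call in :func:`network`.
def _example_conv2d_transpose_layer_shape():
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as session:
        feature_map = tf.zeros([1, 64, 64, 128])
        node = conv2d_transpose_layer(
            feature_map, "de_conv1",
            in_channels=64, out_channels=128, kernel_size=3, strides=2,
            activation=True
        )
        session.run(tf.global_variables_initializer())
        # The output shape is only known at run time here, since the layer
        # derives it from tf.shape; the result is (1, 128, 128, 64).
        return session.run(node).shape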
def instance_normalization(input_node, channels):
    """Apply an instance normalization to the network.

    *input_node* will be the input of the layer.

    *channels* should be the number of channels at the input of the layer.

    .. seealso::

        Ulyanov et al. (2017). Instance Normalization: The Missing Ingredient
        for Fast Stylization. `CoRR, abs/1607.08022
        <https://arxiv.org/abs/1607.08022>`_.

    """
    with tf.name_scope("instance_normalization"):
        # Compute the mean and variance over the spatial dimensions of each
        # instance in the batch, independently per channel.
        mu, sigma_sq = tf.nn.moments(input_node, [1, 2], keep_dims=True)
        shift = tf.Variable(tf.zeros([channels]), name="shift")
        scale = tf.Variable(tf.ones([channels]), name="scale")
        epsilon = 1e-3
        normalized = (input_node - mu) / (sigma_sq + epsilon) ** .5
        return tf.add(tf.multiply(scale, normalized), shift)
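# Hypothetical sketch, not part of the original module: each image in the
# batch is normalized independently over its spatial dimensions, so every
# instance/channel pair ends up with near-zero mean and near-unit variance
# before the learned "scale" and "shift" are applied (identity at init).
def _example_instance_normalization():
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as session:
        feature_map = tf.random_normal([2, 32, 32, 16], mean=5.0, stddev=3.0)
        normalized = instance_normalization(feature_map, channels=16)
        session.run(tf.global_variables_initializer())
        mean, variance = session.run(tf.nn.moments(normalized, [1, 2]))
        # "mean" is close to 0 and "variance" close to 1 for each of the
        # 2 x 16 instance/channel pairs.
        return mean, variance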