How to use the transpose convolution function in TensorFlow
import numpy as np
import tensorflow as tf

## a batch of 3 4x4x2 images as input. We will upsample these to 8x8
input_data = np.ones([3, 4, 4, 2])
filter_shape = [3, 3, 4, 2]  # height, width, channels-out, channels-in
tconv_filter = np.ones(filter_shape)  # make the filter all ones so that we can manually calculate the output
# obviously, having all the filter channels have the same coefficients defeats
# the purpose of having multiple channels, but this is just an example
output_shape = [8, 8, 4]  # height, width, channels-out (notice we don't have the batch size dimension -- more on that later)

## set up the slots for the data
xin = tf.placeholder(dtype=tf.float32, shape=(None, 4, 4, 2), name='input')
filt = tf.placeholder(dtype=tf.float32, shape=filter_shape, name='filter')
## Run the transpose convolution. With a stride of 2, this should upsample our image by a factor of 2.
## You have to specify the output shape, but it's barely a free parameter: the stride and padding pin it
## down to at most a small range of sizes, and stride * input size is the conventional choice for 'SAME'
## padding. What's annoying is that the first dimension (the batch size) is usually unknown, but we have
## to include it, so we have to extract it and concatenate it onto the front of output_shape.
dimxin = tf.shape(xin)
ncase = dimxin[0:1]
oshp = tf.concat([ncase, output_shape], axis=0)
z1 = tf.nn.conv2d_transpose(xin, filt, oshp, strides=[1, 2, 2, 1], name='xpose_conv')
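## (Sanity check on the shape arithmetic, added for illustration: with 'SAME' padding, a stride-s
## transpose convolution maps an h x w input to an (s*h) x (s*w) output, so stride 2 turns our
## 4x4 inputs into the 8x8 given in output_shape. With 'VALID' padding the conventional output
## size would instead be s*(h-1) + kernel_size.)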
## tf.layers has an all-in-one transpose convolution layer. It's a lot more convenient, but
## you don't get to specify the weights (they get initialized randomly). In this case I wanted to use
## specified weights so I could see what the actual effect is. Note that while the default padding for
## tf.nn.conv2d_transpose is 'SAME', the default for this function is 'valid'. You're @#$%ing killing
## me, Google.
z2 = tf.layers.conv2d_transpose(xin, 4, (3, 3), strides=(2, 2), padding='SAME')
with tf.Session() as sess:
    summary_writer = tf.summary.FileWriter('logs', sess.graph)
    sess.run(tf.global_variables_initializer())
    (z1out, z2out) = sess.run(fetches=[z1, z2],
                              feed_dict={xin: input_data, filt: tconv_filter})
    print(z1out.shape)
    print(z2out.shape)
    print(z1out[0, ..., 0])
Incidentally, in the output above you can see the well-known checkerboard pattern that transpose convolutions are often criticized for: with the all-ones filter, each output pixel is covered by 1, 2, or 4 of the stride-2 filter stamps, and each stamp contributes 2 (one from each all-ones input channel), so the values alternate among 2, 4, and 8. Many sources recommend following this operation with a same-size convolution to smooth out the artifacts.
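Here's a sketch of that smoothing step (my addition, not part of the original gist; the box-filter weights are an arbitrary illustrative choice). A stride-1 convolution with 'SAME' padding keeps the 8x8 spatial size while mixing each pixel with its neighbors, which evens out the checkerboard.

## Hypothetical follow-up smoothing convolution: a 3x3 box filter, summed across input channels.
## Note the filter layout for tf.nn.conv2d is [height, width, channels-in, channels-out],
## which differs from the conv2d_transpose layout used above.
smooth_weights = np.ones([3, 3, 4, 4]) / 9.0
smooth_filt = tf.constant(smooth_weights, dtype=tf.float32)
z1_smooth = tf.nn.conv2d(z1, smooth_filt, strides=[1, 1, 1, 1], padding='SAME', name='smooth_conv')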
Note that z2out will be filled with random values, since the weights of the tf.layers version are initialized randomly.
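If you want the convenience of tf.layers with deterministic weights, the layer does accept a kernel_initializer argument. A minimal sketch (assuming an all-ones filter like the hand-built version; z3 would be defined alongside z2, before the session runs):

## Hypothetical variant of z2 with pinned weights, so its output is reproducible.
z3 = tf.layers.conv2d_transpose(xin, 4, (3, 3), strides=(2, 2), padding='SAME',
                                kernel_initializer=tf.ones_initializer())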