Last active
May 14, 2020 18:44
Visualize convolutional kernels with TensorFlow 2.0
import numpy as np
import tensorflow as tf

# Layer name to inspect
layer_name = 'block3_conv1'

epochs = 100
step_size = 1.
filter_index = 0

# Create a connection between the input and the target layer
model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=True)
submodel = tf.keras.models.Model([model.inputs[0]], [model.get_layer(layer_name).output])

# Initiate random noise
input_img_data = np.random.random((1, 224, 224, 3))
input_img_data = (input_img_data - 0.5) * 20 + 128.

# Cast random noise from np.float64 to tf.float32 Variable
input_img_data = tf.Variable(tf.cast(input_img_data, tf.float32))

# Iterate gradient ascents
for _ in range(epochs):
    with tf.GradientTape() as tape:
        outputs = submodel(input_img_data)
        loss_value = tf.reduce_mean(outputs[:, :, :, filter_index])
    grads = tape.gradient(loss_value, input_img_data)
    normalized_grads = grads / (tf.sqrt(tf.reduce_mean(tf.square(grads))) + 1e-5)
    input_img_data.assign_add(normalized_grads * step_size)
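The normalization step in the loop divides the gradient by its root mean square (RMS), so each update has roughly the same magnitude regardless of how small the raw gradient is. Below is a minimal NumPy sketch of that one step, for illustration only. The helper name rms_normalize and the synthetic gradient are assumptions, not part of the gist.

```python
import numpy as np


def rms_normalize(grads, eps=1e-5):
    """Divide grads by their root mean square (hypothetical helper).

    Mirrors: grads / (tf.sqrt(tf.reduce_mean(tf.square(grads))) + 1e-5)
    """
    rms = np.sqrt(np.mean(np.square(grads)))
    return grads / (rms + eps)


# Synthetic stand-in for a very small raw gradient (assumption for the demo).
rng = np.random.default_rng(0)
grads = rng.normal(scale=1e-3, size=(1, 224, 224, 3)).astype(np.float32)

normalized = rms_normalize(grads)

# After normalization the RMS is close to 1, so step_size alone
# controls how far each gradient-ascent update moves the image.
rms_after = float(np.sqrt(np.mean(np.square(normalized))))
print(rms_after)
```

The small eps term only guards against division by zero when the gradient vanishes.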
At the end of the loop, input_img_data is a 4D tensor of shape (1, 224, 224, 3) holding the generated image. Convert it to a NumPy array with .numpy() and display it, for example with matplotlib.
For matplotlib to render the image properly, either normalize the values to the 0-1 range or convert the image to an integer type. As the warning says, matplotlib clips float values to the 0-1 range, so an image whose values span the 0-255 range comes out almost entirely saturated, which is why it looks so poor.
@konradsemsch I solved it by converting input_img_data this way:
input_img_data = input_img_data.numpy().astype(np.uint8)
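A slightly more defensive version of this conversion, as a sketch: it drops the batch dimension and clips to the 0-255 range before casting, since astype(np.uint8) on its own does not clip out-of-range values. The helper name to_uint8_image, the random stand-in array, and the output file name are assumptions. In the real script you would pass input_img_data.numpy().

```python
import numpy as np


def to_uint8_image(batch):
    """Turn a (1, H, W, 3) float array into an (H, W, 3) uint8 image.

    Hypothetical helper, not part of the gist.
    """
    img = np.asarray(batch)[0]       # drop the batch dimension
    img = np.clip(img, 0, 255)       # keep values in the displayable range
    return img.astype(np.uint8)


# Stand-in for input_img_data.numpy(); values deliberately overshoot 0-255.
fake = (np.random.random((1, 224, 224, 3)) - 0.5) * 400 + 128.0
image = to_uint8_image(fake)

# To view or save the result, assuming matplotlib is installed:
# import matplotlib.pyplot as plt
# plt.imshow(image)
# plt.axis("off")
# plt.show()
# plt.imsave("filter.png", image)
```

Because the array is uint8, matplotlib interprets it on the 0-255 scale and the clipping warning no longer appears.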
Excuse my stupid question, but how do we actually see/save the image(s)?