-
-
Save RaphaelMeudec/39b85509f9d8f41caffaf83525adced8 to your computer and use it in GitHub Desktop.
Keras implementation of the DeblurGAN generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Input, Activation, Add | |
from keras.layers.advanced_activations import LeakyReLU | |
from keras.layers.convolutional import Conv2D, Conv2DTranspose | |
from keras.layers.core import Lambda | |
from keras.layers.normalization import BatchNormalization | |
from keras.models import Model | |
from layer_utils import ReflectionPadding2D, res_block | |
# Generator hyper-parameters, read at module level by generator_model().
ngf = 64                      # filter count of the first generator conv layer
input_nc, output_nc = 3, 3    # channel counts of the input / output images
# Channels-last input shape: 256x256 spatial size with input_nc channels.
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9              # number of residual blocks in the generator trunk
def generator_model(input_shape=None):
    """Build the DeblurGAN generator as a Keras ``Model``.

    Architecture, in order:
      * Stem: reflection padding + 7x7 'valid' conv to ``ngf`` filters,
        BatchNorm, ReLU.
      * Two stride-2 3x3 convs, each doubling the filter count.
      * ``n_blocks_gen`` residual blocks (``res_block`` from ``layer_utils``,
        called with ``use_dropout=True``) at ``ngf * 4`` filters.
      * Two stride-2 3x3 transposed convs, each halving the filter count
        back down to ``ngf``.
      * Head: reflection padding + 7x7 'valid' conv to ``output_nc``
        channels, tanh.
      * Global skip: the tanh output is added to the input and halved.

    Args:
        input_shape: Shape passed to ``keras.layers.Input`` (excluding the
            batch axis). Defaults to the module-level
            ``input_shape_generator``. Because the input is added to the
            output, its channel count must equal ``output_nc``. Its spatial
            dims should be divisible by 4 so the two stride-2 up/down stages
            restore the original size.

    Returns:
        A Keras ``Model`` named ``'Generator'``.
    """
    if input_shape is None:
        input_shape = input_shape_generator

    inputs = Input(shape=input_shape)

    # Stem. ReflectionPadding2D((3, 3)) presumably pads 3 px per side so the
    # 'valid' 7x7 conv preserves H and W (layer defined in layer_utils).
    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Downsampling: stride 2 halves H and W; the filter count doubles.
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2 ** i
        x = Conv2D(filters=ngf * mult * 2, kernel_size=(3, 3), strides=2,
                   padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Residual trunk at the bottleneck width ngf * 2**n_downsampling.
    mult = 2 ** n_downsampling
    for _ in range(n_blocks_gen):
        x = res_block(x, ngf * mult, use_dropout=True)

    # Upsampling: mirror of the downsampling path, halving filters to ngf.
    for i in range(n_downsampling):
        mult = 2 ** (n_downsampling - i)
        x = Conv2DTranspose(filters=ngf * mult // 2, kernel_size=(3, 3),
                            strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Head: project to output_nc channels and squash to [-1, 1] with tanh.
    x = ReflectionPadding2D((3, 3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7, 7), padding='valid')(x)
    x = Activation('tanh')(x)

    # Global skip connection: the network predicts a residual on top of the
    # input. Halving (tanh + input) keeps the result in [-1, 1], assuming the
    # caller normalizes inputs to [-1, 1] (not enforced here).
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z / 2)(outputs)

    return Model(inputs=inputs, outputs=outputs, name='Generator')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.