Create an optimized version of tf.data Dataset for an image deblurring task
from pathlib import Path

import tensorflow as tf


def select_patch(sharp, blur, patch_size_x, patch_size_y):
    """
    Select a patch on both the sharp and blur images at the same location.

    Args:
        sharp (tf.Tensor): Tensor for the sharp image
        blur (tf.Tensor): Tensor for the blurred image
        patch_size_x (int): Size of the patch along the x axis
        patch_size_y (int): Size of the patch along the y axis

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: Tuple of tensors with shape (patch_size_x, patch_size_y, 3)
    """
    # Stack the two images so a single random crop selects the same region in both
    stack = tf.stack([sharp, blur], axis=0)
    patches = tf.image.random_crop(stack, size=[2, patch_size_x, patch_size_y, 3])
    return (patches[0], patches[1])


class TensorflowDatasetLoader:
    def __init__(self, dataset_path, batch_size=4, patch_size=(256, 256), n_epochs=10, n_images=None, dtype=tf.float32):
        # List all sharp image paths
        sharp_images_paths = [str(path) for path in Path(dataset_path).glob("*/sharp/*.png")]
        if n_images is not None:
            sharp_images_paths = sharp_images_paths[0:n_images]

        # Generate the corresponding blurred image paths
        blur_images_paths = [path.replace("sharp", "blur") for path in sharp_images_paths]

        # Load the sharp and blurred images
        sharp_dataset = tf.data.Dataset.from_tensor_slices(sharp_images_paths).map(
            lambda path: self.load_image(path, dtype),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
        blur_dataset = tf.data.Dataset.from_tensor_slices(blur_images_paths).map(
            lambda path: self.load_image(path, dtype),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )

        dataset = tf.data.Dataset.zip((sharp_dataset, blur_dataset))
        dataset = dataset.cache()

        # Select the same patch on the sharp image and its corresponding blurred one
        dataset = dataset.map(
            lambda sharp_image, blur_image: select_patch(
                sharp_image, blur_image, patch_size[0], patch_size[1]
            ),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )

        # Define dataset characteristics (batch size, shuffling, repetition, prefetching)
        dataset = dataset.batch(batch_size)
        dataset = dataset.shuffle(buffer_size=50)
        dataset = dataset.repeat()
        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

        self.dataset = dataset

    @staticmethod
    def load_image(image_path, dtype):
        image = tf.io.read_file(image_path)
        image = tf.image.decode_png(image, channels=3)
        image = tf.image.convert_image_dtype(image, dtype)
        # Rescale pixel values from [0, 1] to [-1, 1]
        image = (image - 0.5) * 2
        return image
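
A minimal usage sketch, assuming a GoPro-style directory layout in which each scene folder contains a sharp/ and a blur/ subfolder of PNGs (the path "datasets/gopro" below is a placeholder, not part of the original gist):

# Build the loader and pull one batch from the pipeline (TF 2.x eager mode assumed)
loader = TensorflowDatasetLoader(
    dataset_path="datasets/gopro",  # placeholder path; expects <scene>/sharp/*.png and <scene>/blur/*.png
    batch_size=4,
    patch_size=(256, 256),
)
sharp_batch, blur_batch = next(iter(loader.dataset))
# Each batch has shape (4, 256, 256, 3) with values scaled to [-1, 1]

Since the dataset repeats indefinitely, a training loop would typically cap iteration with steps_per_epoch (e.g. in model.fit) rather than exhausting the iterator.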