Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save RaphaelMeudec/abf6ece31c5492c053fd8a4b6a56d02e to your computer and use it in GitHub Desktop.
Save RaphaelMeudec/abf6ece31c5492c053fd8a4b6a56d02e to your computer and use it in GitHub Desktop.
Create an optimized `tf.data.Dataset` pipeline for an image-deblurring task.
from pathlib import Path
import tensorflow as tf
def select_patch(sharp, blur, patch_size_x, patch_size_y):
    """Select a patch on both sharp and blur images at the same location.

    Both images are stacked and cropped with a single ``random_crop`` call,
    so the same random offset is applied to each of them.

    Args:
        sharp (tf.Tensor): Sharp image, expected as (height, width, channels)
            -- the layout produced by ``tf.image.decode_png``.
        blur (tf.Tensor): Blurred image; must have the same shape as ``sharp``
            (``tf.stack`` fails otherwise).
        patch_size_x (int): Patch size along the first image axis (rows).
        patch_size_y (int): Patch size along the second image axis (columns).

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: (sharp_patch, blur_patch), each with shape
        (patch_size_x, patch_size_y, channels).
    """
    stack = tf.stack([sharp, blur], axis=0)
    # Keep all channels instead of assuming RGB. Prefer the static channel
    # count so the output shape stays fully known at trace time; fall back to
    # the dynamic value when the static one is unavailable.
    channels = tf.compat.dimension_value(stack.shape[-1])
    if channels is None:
        channels = tf.shape(stack)[-1]
    patches = tf.image.random_crop(stack, size=[2, patch_size_x, patch_size_y, channels])
    return patches[0], patches[1]
class TensorflowDatasetLoader:
    """Build a ``tf.data`` pipeline of (sharp, blur) training patches.

    Expected directory layout under ``dataset_path``::

        <dataset_path>/<subdir>/sharp/<name>.png
        <dataset_path>/<subdir>/blur/<name>.png

    Each sharp image is paired with the image of the same file name in the
    sibling ``blur`` directory. The finished pipeline is stored on
    ``self.dataset``.
    """

    def __init__(
        self,
        dataset_path,
        batch_size=4,
        patch_size=(256, 256),
        n_epochs=10,
        n_images=None,
        dtype=None,
    ):
        """Create the dataset pipeline.

        Args:
            dataset_path (str | Path): Root directory of the dataset.
            batch_size (int): Number of patch pairs per batch.
            patch_size (Tuple[int, int]): Passed to ``select_patch`` as
                (patch_size_x, patch_size_y).
            n_epochs (int): Currently unused; the dataset repeats indefinitely.
            n_images (int | None): If set, only the first ``n_images`` sharp
                images (in sorted path order) are used.
            dtype (tf.DType | None): Dtype of the loaded images. ``None`` means
                ``tf.float32``. Must be a floating dtype, because ``load_image``
                rescales the values arithmetically.

        Raises:
            ValueError: If no sharp image is selected.
        """
        if dtype is None:
            dtype = tf.float32
        autotune = tf.data.experimental.AUTOTUNE

        sharp_images_paths, blur_images_paths = self._list_image_pairs(dataset_path, n_images)

        # Load sharp and blurred images, decoding in parallel
        sharp_dataset = tf.data.Dataset.from_tensor_slices(sharp_images_paths).map(
            lambda path: self.load_image(path, dtype),
            num_parallel_calls=autotune,
        )
        blur_dataset = tf.data.Dataset.from_tensor_slices(blur_images_paths).map(
            lambda path: self.load_image(path, dtype),
            num_parallel_calls=autotune,
        )

        dataset = tf.data.Dataset.zip((sharp_dataset, blur_dataset))
        # Cache the full decoded images. Random cropping comes after the cache,
        # so every pass over the data draws fresh patches.
        dataset = dataset.cache()

        # Select the same patch on the sharp image and its corresponding blurred
        dataset = dataset.map(
            lambda sharp_image, blur_image: select_patch(
                sharp_image, blur_image, patch_size[0], patch_size[1]
            ),
            num_parallel_calls=autotune,
        )

        # Shuffle individual pairs *before* batching. Shuffling after batch()
        # only reorders fixed batches, so the same images always end up
        # grouped together.
        dataset = dataset.shuffle(buffer_size=50)
        dataset = dataset.batch(batch_size)
        # NOTE(review): n_epochs is not applied; repeat() makes the dataset
        # infinite, so callers must bound iteration themselves.
        dataset = dataset.repeat()
        dataset = dataset.prefetch(buffer_size=autotune)

        self.dataset = dataset

    @staticmethod
    def _list_image_pairs(dataset_path, n_images):
        """Return (sharp_paths, blur_paths) as lists of str, sorted by sharp path.

        Raises:
            ValueError: If no sharp image is selected.
        """
        # Sort so the n_images subset does not depend on filesystem order.
        sharp_paths = sorted(Path(dataset_path).glob("*/sharp/*.png"))
        if n_images is not None:
            sharp_paths = sharp_paths[:n_images]
        if not sharp_paths:
            raise ValueError(
                f"No images matching '*/sharp/*.png' selected under {str(dataset_path)!r} "
                f"(n_images={n_images!r})"
            )
        # Replace only the "sharp" directory component. A plain
        # str.replace("sharp", "blur") would also rewrite "sharp" appearing in
        # the root path or in file names.
        blur_paths = [path.parent.with_name("blur") / path.name for path in sharp_paths]
        return [str(path) for path in sharp_paths], [str(path) for path in blur_paths]

    @staticmethod
    def load_image(image_path, dtype):
        """Read a PNG file and return it as a 3-channel tensor scaled to [-1, 1].

        Args:
            image_path (tf.Tensor | str): Path to the PNG file.
            dtype (tf.DType): Target floating dtype. ``convert_image_dtype``
                maps values to [0, 1] for float dtypes, and the final line
                rescales them to [-1, 1].

        Returns:
            tf.Tensor: Image tensor of shape (height, width, 3).
        """
        image = tf.io.read_file(image_path)
        image = tf.image.decode_png(image, channels=3)
        image = tf.image.convert_image_dtype(image, dtype)
        image = (image - 0.5) * 2
        return image
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment