How to create a k-way n-shot tf.data.Dataset
import pandas as pd
import tensorflow as tf


def build_k_way_n_shot_dataset(annotations, n_shot, k_way, classes=None, to_categorical=True, training=True):
    """Build a dataset where each batch contains only N elements of K classes among all classes."""
    # `annotations` is a dataframe with "image_name", "label", "x1", "y1", "x2", "y2" columns
    annotations = annotations.assign(label=pd.Categorical(annotations.label, categories=classes))

    # Prepare labels as one-hot vectors
    targets = annotations.label.cat.codes
    if to_categorical:
        targets = (
            pd.get_dummies(targets)
            .reindex(list(range(len(targets.unique()))), axis=1)
            .fillna(0)
        )

    num_classes = len(targets.columns)
    batch_size = n_shot * k_way

    def load_image_and_crop(image_path, x1, y1, x2, y2):
        """Load an image, rescale it to [-1, 1] and crop it at the given box coordinates."""
        image = tf.io.read_file(image_path)
        image = tf.image.decode_png(image)
        image = (tf.image.convert_image_dtype(image, tf.float32) - 0.5) * 2
        image = tf.image.crop_to_bounding_box(image, y1, x1, y2 - y1, x2 - x1)
        image = tf.image.resize_with_crop_or_pad(image, 224, 224)
        return image

    def data_aug(image):
        """Example data augmentations."""
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        return image

    def build_datasets_for_class(annotations, targets, index_class):
        """Build a dataset restricted to a given class."""
        print(f"Building for {index_class}")
        # Filter the annotations to keep only those of the selected class
        class_targets = targets[targets[index_class] > 0]
        class_annotations = annotations.loc[class_targets.index]

        # Create the dataset from the filtered annotations
        dataset = tf.data.Dataset.from_tensor_slices((
            class_annotations["image_name"],
            class_annotations["x1"],
            class_annotations["y1"],
            class_annotations["x2"],
            class_annotations["y2"],
            class_targets.values.astype("float32"),
        ))
        dataset = dataset.map(
            lambda image_name, x1, y1, x2, y2, target: (load_image_and_crop(image_name, x1, y1, x2, y2), target),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )

        if training:
            dataset = dataset.cache()
            dataset = dataset.map(
                lambda image, target: (data_aug(image), target),
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
            )

        return dataset

    # Create one filtered, repeated dataset per class
    datasets_by_class = [
        build_datasets_for_class(annotations, targets, index_class=index_class).repeat()
        for index_class in targets.columns
    ]

    # Create the choice dataset that defines from which per-class dataset each element is picked.
    # Typically, if n_shot=4, it will produce something like:
    # [1, 1, 1, 1, 12, 12, 12, 12, 37, 37, 37, 37, 25, 25, 25, 25, ...]
    choice_dataset = tf.data.Dataset.range(num_classes).shuffle(buffer_size=num_classes).repeat().interleave(
        lambda index: tf.data.Dataset.from_tensors(index).repeat(n_shot),
        cycle_length=1,
        block_length=n_shot,
    )

    # Build the final dataset with choose_from_datasets: it picks elements according to the indices
    # generated by the choice dataset, then groups them into batches of n_shot * k_way samples
    dataset = tf.data.experimental.choose_from_datasets(datasets_by_class, choice_dataset).batch(batch_size)

    return dataset
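
A minimal usage sketch, assuming the annotations live in a hypothetical annotations.csv file with the image_name, label, x1, y1, x2, y2 columns expected above:

import pandas as pd

# "annotations.csv" is a hypothetical file name; any DataFrame with the expected columns works
annotations = pd.read_csv("annotations.csv")

# Each batch holds n_shot * k_way = 20 samples: 4 crops for each of 5 classes
dataset = build_k_way_n_shot_dataset(annotations, n_shot=4, k_way=5)

for images, targets in dataset.take(1):
    print(images.shape)   # e.g. (20, 224, 224, 3)
    print(targets.shape)  # (20, num_classes) one-hot targets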