Custom image data generator for TF Keras that supports the modern augmentation module albumentations

mjkvaak | Last update: Sep 05, 2022

NOTICE!

  • Support has moved from the keras framework to tensorflow.keras.
  • There were major updates in Dec 2020; see the Changelog for what has changed.

ImageDataAugmentor

ImageDataAugmentor is a custom image data generator for tensorflow.keras that supports albumentations.


Installation

For installing the prerequisites, see these two gists: NVIDIA-driver installation and TF2.x installation.

$ pip install git+https://github.com/mjkvaak/ImageDataAugmentor
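
To verify the installation, you can try importing the package (a quick sanity check; the import path is the one used in the examples below):

$ python -c "from ImageDataAugmentor.image_data_augmentor import ImageDataAugmentor"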

How to use

The usage is analogous to tensorflow.keras.ImageDataGenerator, with the exception that the image transformations are generated using the external augmentation library albumentations.

Tip: A complete list of albumentations.transforms can be found here. See also this handy tool for testing the different transforms.

The most notable added features are:

  • Augmentations are passed to ImageDataAugmentor as a single albumentations transform (e.g. albumentations.HorizontalFlip()) or as a composition of multiple transforms in an albumentations.Compose object
  • albumentations can transform various types of data, e.g. imagery, segmentation masks, bounding boxes and keypoints. input_augment_mode (resp. label_augment_mode) can be used to select which type of transforms to apply to the model inputs (resp. model labels)
  • .show_data() can be used to visualize a random batch of images generated by ImageDataAugmentor (see the short sketch after this list)
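
For instance, a minimal sketch of these features (the directory path and the particular transforms are illustrative only):

import albumentations
from ImageDataAugmentor.image_data_augmentor import ImageDataAugmentor

# a single transform can be passed as-is...
datagen = ImageDataAugmentor(augment=albumentations.HorizontalFlip(p=0.5))

# ...or several transforms can be combined into an albumentations.Compose
datagen = ImageDataAugmentor(
    augment=albumentations.Compose([
        albumentations.HorizontalFlip(p=0.5),
        albumentations.RandomBrightnessContrast(p=0.3),
    ]),
    input_augment_mode='image',  # apply the transforms to the model inputs
)

generator = datagen.flow_from_directory('data/train', target_size=(224, 224))
generator.show_data()  # visualize a random batch of augmented images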

Below are a few examples of some commonly encountered use cases. More complete examples can be found in the ./examples folder.

Example of using .flow_from_directory(directory) with albumentations:

import tensorflow as tf
from ImageDataAugmentor.image_data_augmentor import *
import albumentations
...

AUGMENTATIONS = albumentations.Compose([
    albumentations.Transpose(p=0.5),
    albumentations.Flip(p=0.5),
    albumentations.OneOf([
        albumentations.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        albumentations.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1)
    ], p=1),
    albumentations.GaussianBlur(p=0.05),
    albumentations.HueSaturationValue(p=0.5),
    albumentations.RGBShift(p=0.5),
])

# dataloaders
train_datagen = ImageDataAugmentor(
        rescale=1./255,
        augment=AUGMENTATIONS,
        preprocess_input=None)
train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(224, 224),
        batch_size=32,
        class_mode='binary')
val_datagen = ImageDataAugmentor(rescale=1./255)
validation_generator = val_datagen.flow_from_directory(
        'data/validation',
        target_size=(224, 224),
        batch_size=32,
        class_mode='binary')
#train_generator.show_data() #<- visualize a bunch of augmented data

# train the model with real-time data augmentation
model.fit(
        train_generator,
        steps_per_epoch=len(train_generator),
        epochs=50,
        validation_data=validation_generator,
        validation_steps=len(validation_generator))
...

Example of using .flow(x, y) with albumentations:

import numpy as np
import tensorflow as tf
from ImageDataAugmentor.image_data_augmentor import *
import albumentations
...

AUGMENTATIONS = albumentations.Compose([
    albumentations.HorizontalFlip(p=0.5),  # horizontally flip 50% of all images
    albumentations.VerticalFlip(p=0.2),    # vertically flip 20% of all images
    albumentations.ShiftScaleRotate(p=0.5)
])

# fetch data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
num_classes = len(np.unique(y_train))
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

# dataloaders
datagen = ImageDataAugmentor(
    featurewise_center=True,
    featurewise_std_normalization=True,
    augment=AUGMENTATIONS,
    validation_split=0.2)

# compute quantities required for featurewise normalization
datagen.fit(x_train, augment=True)
train_generator = datagen.flow(x_train, y_train, batch_size=32, subset='training')
validation_generator = datagen.flow(x_train, y_train, batch_size=32, subset='validation')
# train_generator.show_data()

# train the model with real-time data augmentation
model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=50,
    validation_data=validation_generator,
    validation_steps=len(validation_generator))

# evaluate the model with test data
test_datagen = ImageDataAugmentor(
    featurewise_center=True,
    featurewise_std_normalization=True,
    augment=albumentations.HorizontalFlip(p=0.5))
test_datagen.mean = datagen.mean  #<- stats from the training dataset
test_datagen.std = datagen.std    #<- stats from the training dataset
test_generator = test_datagen.flow(x_test, y_test, batch_size=32)
model.evaluate(test_generator)
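
Since featurewise normalization bakes the training-set statistics into the generator, it can be handy to persist them for later inference. A minimal sketch, assuming plain NumPy files (the file names are illustrative):

import numpy as np

# save the statistics computed by datagen.fit(...)
np.save('featurewise_mean.npy', datagen.mean)
np.save('featurewise_std.npy', datagen.std)

# ...and restore them later into a fresh generator
inference_datagen = ImageDataAugmentor(
    featurewise_center=True,
    featurewise_std_normalization=True)
inference_datagen.mean = np.load('featurewise_mean.npy')
inference_datagen.std = np.load('featurewise_std.npy')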

Example of using .flow_from_directory() with segmentation masks and albumentations:

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from ImageDataAugmentor.image_data_augmentor import *
import albumentations
...

SEED = 123
AUGMENTATIONS = albumentations.Compose([
    albumentations.HorizontalFlip(p=0.5),
    albumentations.ElasticTransform(),
])

# Assume that DATA_DIR has subdirs "images" and "masks",
# where the masks have been saved as grayscale images whose pixel values
# denote the segmentation label
DATA_DIR = ...
N_CLASSES = ...  # number of segmentation classes in the masks

def one_hot_encode_masks(y: np.ndarray, classes=range(N_CLASSES)):
    ''' One-hot encodes the target masks for segmentation '''
    y = y.squeeze()
    masks = [(y == v) for v in classes]
    mask = np.stack(masks, axis=-1).astype('float')
    # add background if the mask is not binary
    if mask.shape[-1] != 1:
        background = 1 - mask.sum(axis=-1, keepdims=True)
        mask = np.concatenate((mask, background), axis=-1)
    return mask

img_data_gen = ImageDataAugmentor(
    augment=AUGMENTATIONS,
    input_augment_mode='image',
    validation_split=0.2,
    seed=SEED,
)
mask_data_gen = ImageDataAugmentor(
    augment=AUGMENTATIONS,
    input_augment_mode='mask',  #<- notice the different augment mode
    preprocess_input=one_hot_encode_masks,
    validation_split=0.2,
    seed=SEED,
)
print("training:")
tr_img_gen = img_data_gen.flow_from_directory(DATA_DIR,
                                              classes=['images'],
                                              class_mode=None,
                                              subset="training",
                                              shuffle=True)
tr_mask_gen = mask_data_gen.flow_from_directory(DATA_DIR,
                                                classes=['masks'],
                                                class_mode=None,
                                                color_mode='gray',  #<- notice the color mode
                                                subset="training",
                                                shuffle=True)
print("validation:")
val_img_gen = img_data_gen.flow_from_directory(DATA_DIR,
                                               classes=['images'],
                                               class_mode=None,
                                               subset="validation",
                                               shuffle=True)
val_mask_gen = mask_data_gen.flow_from_directory(DATA_DIR,
                                                 classes=['masks'],
                                                 class_mode=None,
                                                 color_mode='gray',  #<- notice the color mode
                                                 subset="validation",
                                                 shuffle=True)
#tr_img_gen.show_data()
#tr_mask_gen.show_data()
train_generator = zip(tr_img_gen, tr_mask_gen)
validation_generator = zip(val_img_gen, val_mask_gen)

# visualize images
rows = 5
image_batch, mask_batch = next(train_generator)
fig, ax = plt.subplots(rows, 2, figsize=(4, rows * 2))
for i, (img, mask) in enumerate(zip(image_batch, mask_batch)):
    if i > rows - 1:
        break
    ax[i, 0].imshow(np.uint8(img))
    ax[i, 1].imshow(mask.argmax(-1))
plt.show()

# train the model with real-time data augmentation
model.fit(
    train_generator,
    steps_per_epoch=len(tr_img_gen),  # zip objects have no len()
    epochs=50,
    validation_data=validation_generator,
    validation_steps=len(val_img_gen))
...
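
Note that plain zip objects have no len(), which is why model.fit above is given the lengths of the underlying generators. Also, both generators must be constructed with the same seed so that the random transforms stay in sync between images and masks. If you prefer an explicitly infinite stream over zip, a small wrapper like this hypothetical helper does the same job:

def combine_generators(img_gen, mask_gen):
    ''' Yields (image_batch, mask_batch) tuples indefinitely '''
    while True:
        yield next(img_gen), next(mask_gen)

train_generator = combine_generators(tr_img_gen, tr_mask_gen)
validation_generator = combine_generators(val_img_gen, val_mask_gen)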

Citing (BibTeX):

@misc{Tukiainen:2019,
  author = {Tukiainen, M.},
  title = {ImageDataAugmentor},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {https://github.com/mjkvaak/ImageDataAugmentor/}
}

License

This project is distributed under the MIT license. The code is heavily adapted from https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/ (also MIT licensed).
