Variational Autoencoders: Training Procedures (Part 2)
Welcome to Part 2 of our three-part series on Variational Autoencoders (VAEs). Building on the introduction and implementation details covered in Part 1, this part focuses on the training procedure: how the data pipeline, the model, the loss metric, and the callbacks fit together, and what role each parameter plays. The complete code for this series is available in our GitHub repository at https://github.com/asokraju/ImageAutoEncoder.
In this article, we outline the steps involved in setting up and managing the training of a VAE, a deep generative model that is particularly effective for image generation and modification.
The training script begins by accepting a range of parameters and hyperparameters that shape both the model and the training process. These include the location of the image data, the directory for training logs, the number of training epochs, the batch size, and the learning rate, among others.
Once these parameters are set, we establish the paths to the training and testing data. The data is split into a training subset for the model to learn from, and a testing subset to gauge the model’s performance on data it hasn’t encountered before.
Next, we set up data generators, which are essential when a dataset is too large to hold in memory at once. The training generator applies random transformations (data augmentation) to the training images, which helps the model generalize and mitigates overfitting. The test images are only rescaled; no augmentation is applied to them.
The generators then load and preprocess the images in small batches of the configured batch size. This lets the model update its weights incrementally, one batch at a time, instead of processing the entire dataset at once.
With the data ready and the generators primed, the VAE model is now ready to be trained using the specified parameters and hyperparameters. This process is logged for future analysis, improvements, and debugging purposes.
Parsing the hyperparameters
We start by defining a function, parse_arguments(), that receives the parameters and hyperparameters for model training from the command line.
```python
import argparse


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--image-dir', type=str, help='Path to the image data', default=r'Data')
    parser.add_argument('--logs-dir', type=str, help='Path to store logs', default=r'logs')
    parser.add_argument('--output-image-shape', type=int, default=56)
    parser.add_argument('--filters', type=int, nargs='+', default=[32, 64])
    parser.add_argument('--dense-layer-dim', type=int, default=16)
    parser.add_argument('--latent-dim', type=int, default=6)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--learning-rate', type=float, default=1e-4)
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--train-split', type=float, default=0.8)
    args = parser.parse_args()
    return args
```
These arguments cover the paths to the image data and the logs directory, the output image size, the number of filters in the convolutional layers, the dimensions of the dense and latent layers, the KL weight, the batch size, the learning rate, the early-stopping patience, the number of epochs, and the train/test split.
- --image-dir: The directory containing your training images. Subdirectories are searched as well.
- --logs-dir: The directory where training logs are stored. These logs are useful for debugging and for visualizing training progress.
- --output-image-shape: The side length, in pixels, of the square images fed to the model; every image is resized to this size. The VAE must output an image with the same shape as its input.
- --filters: The number of filters for each convolutional layer of the encoder and decoder, given as a list (for example, --filters 32 64).
- --dense-layer-dim: The dimensionality of the dense layer in the encoder model.
- --latent-dim: The dimensionality of the latent space. Choose it carefully, because it determines the VAE's capacity to represent complex data.
- --beta: Presumably the weight on the KL divergence term, as in a β-VAE. It is parsed here but not used in the code shown in this article.
- --batch-size: The number of training examples processed in one iteration. It affects training speed, memory use, and how noisy each gradient update is.
- --learning-rate: The step size the optimizer takes while learning. Tune it carefully: a value too small may cause slow convergence, while a value too large may prevent convergence.
- --patience: Intended as the early-stopping patience. Note that the EarlyStopping callback shown later uses a fixed patience=5 rather than this value.
- --epochs: The number of complete passes through the training dataset. The right number usually depends on how soon the model starts to overfit.
- --train-split: The fraction of images used for training; the remainder is held out for testing.
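Two argparse behaviors matter here: dashes in flag names become underscores in attribute names (--image-dir is read as args.image_dir), and nargs='+' collects one or more values into a list. The snippet below shows both using a trimmed copy of a few of the flags above, parsing an explicit argument list instead of the real command line. The values (such as my_images) are made up for illustration.

```python
import argparse

# A trimmed copy of a few flags from parse_arguments(), for illustration only
parser = argparse.ArgumentParser()
parser.add_argument('--image-dir', type=str, default='Data')
parser.add_argument('--filters', type=int, nargs='+', default=[32, 64])
parser.add_argument('--latent-dim', type=int, default=6)
parser.add_argument('--learning-rate', type=float, default=1e-4)

# Parse an explicit list instead of sys.argv (handy in a notebook or a test)
args = parser.parse_args(['--image-dir', 'my_images', '--filters', '32', '64', '128'])
print(args.image_dir)      # my_images  (dashes in flag names become underscores)
print(args.filters)        # [32, 64, 128]  (nargs='+' gathers the ints into a list)
print(args.latent_dim)     # 6  (the default, because the flag was omitted)
print(args.learning_rate)  # 0.0001
```

On the command line, the equivalent would be something like python train.py --image-dir my_images --filters 32 64 128, where train.py is a placeholder for the script's actual name.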
Next, the main script parses these arguments, stores them in constants, and creates a timestamped log directory:
```python
import os
from datetime import datetime

args = parse_arguments()
IMAGE_DIR = args.image_dir
LOGS_DIR = args.logs_dir
TRAIN_SPLIT = args.train_split
OUTPUT_IMAGE_SHAPE = args.output_image_shape
INPUT_SHAPE = (OUTPUT_IMAGE_SHAPE, OUTPUT_IMAGE_SHAPE, 1)  # square grayscale images, one channel
FILTERS = args.filters
DENSE_LAYER_DIM = args.dense_layer_dim
LATENT_DIM = args.latent_dim
BATCH_SIZE = args.batch_size
EPOCHS = args.epochs
LEARNING_RATE = args.learning_rate

# One log directory per run, named by its start time (e.g. logs/20240101-120000)
LOGDIR = os.path.join(LOGS_DIR, datetime.now().strftime("%Y%m%d-%H%M%S"))
os.makedirs(LOGDIR, exist_ok=True)  # also creates LOGS_DIR if it does not exist yet
```
Data Extraction, Preprocessing, and Segregation for Training and Testing
As outlined earlier, the next step is to extract the image data and split it into training and testing sets.
We first gather all the image paths into a list using the get_image_data() function, and then count the total number of images so we can divide them between training and testing.
```python
import os


def get_image_data(all_dirs):
    # Accept a single directory path and wrap it in a list
    all_dirs = [all_dirs]
    # List to store all image file paths
    all_image_paths = []
    # Loop through all directories and subdirectories in the data directory
    for data_dir in all_dirs:
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                # Keep only image files (add more extensions as needed); the check ignores case
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    all_image_paths.append(os.path.join(root, file))
        print(data_dir)
    image_count = len(all_image_paths)
    print("Total number of images:", image_count)
    return all_image_paths
```
Once we have our list of images, we split it into two parts: the first TRAIN_SPLIT fraction of the paths (80% by default) for training, and the rest for testing. We do this by slicing the list of image paths at the split index, and we store the two resulting lists in the dataframes df_train and df_test. Note that the list is sliced in the order in which os.walk returns the files, without shuffling.
```python
import pandas as pd

all_image_paths = get_image_data(IMAGE_DIR)
image_count = len(all_image_paths)
split_index = int(image_count * TRAIN_SPLIT)  # e.g. 1000 images * 0.8 -> 800 training images
df_train = pd.DataFrame({'image_paths': all_image_paths[:split_index]})
df_test = pd.DataFrame({'image_paths': all_image_paths[split_index:]})
```
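Because the slice follows the order in which os.walk yields files, images from the last subdirectories can end up entirely in the test set. If that matters for your data, one option is to shuffle the paths with a fixed seed before slicing. The helper below, shuffled_split, is a hypothetical addition (it is not part of the repository) that uses only the standard library; its seed of 42 is an arbitrary choice.

```python
import random


def shuffled_split(paths, train_split=0.8, seed=42):
    """Hypothetical helper: shuffle a copy of `paths` reproducibly, then split it."""
    paths = sorted(paths)      # sort first so the result does not depend on os.walk order
    rng = random.Random(seed)  # private RNG: reproducible and leaves global state alone
    rng.shuffle(paths)
    split_index = int(len(paths) * train_split)
    return paths[:split_index], paths[split_index:]


train_paths, test_paths = shuffled_split([f"img_{i:03d}.png" for i in range(10)], train_split=0.8)
print(len(train_paths), len(test_paths))  # 8 2
```

The two lists can then be wrapped in df_train and df_test exactly as above. Sorting before shuffling means the same set of files always yields the same split, regardless of filesystem order.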
The next step is to prepare the data generators: pipelines that handle image loading, preprocessing, and batching. We define separate generators for the training and testing data. For the training data, we apply random augmentations (shearing, zooming, horizontal flipping, rotation, and small width and height shifts) on the fly, so each epoch sees slightly different versions of the same images. This does not increase the number of images per epoch, but it does increase their variability, which helps the model generalize. For the testing data, we only rescale the pixel values to [0, 1].
With the data generators defined, we call flow_from_dataframe() on the training and testing dataframes. This method yields batches of images directly from the file paths in a dataframe: it loads each image in grayscale, resizes it to the target shape, and shuffles the order of the images so that batches differ from epoch to epoch.
```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen_args = dict(
    rescale=1.0 / 255,  # Scale pixel values from [0, 255] to [0, 1]
    shear_range=0.2,  # random shear intensity
    zoom_range=0.2,  # random zoom factor in [0.8, 1.2]
    horizontal_flip=True,  # randomly flip images horizontally
    rotation_range=90,  # random rotations of up to 90 degrees
    width_shift_range=0.1,  # random horizontal shifts of up to 10% of the width
    height_shift_range=0.1,  # random vertical shifts of up to 10% of the height
)
test_datagen_args = dict(rescale=1.0 / 255)  # test images are only rescaled, not augmented
train_datagen = ImageDataGenerator(**train_datagen_args)
test_datagen = ImageDataGenerator(**test_datagen_args)

# Use flow_from_dataframe to generate data batches
train_data_generator = train_datagen.flow_from_dataframe(
    dataframe=df_train,
    color_mode='grayscale',
    x_col='image_paths',
    y_col=None,
    target_size=(OUTPUT_IMAGE_SHAPE, OUTPUT_IMAGE_SHAPE),  # Specify the desired size of the input images
    batch_size=BATCH_SIZE,
    class_mode=None,  # Set to None since there are no labels
    shuffle=True,  # Set to True for randomizing the order of the images
)
test_data_generator = test_datagen.flow_from_dataframe(
    dataframe=df_test,
    color_mode='grayscale',
    x_col='image_paths',
    y_col=None,
    target_size=(OUTPUT_IMAGE_SHAPE, OUTPUT_IMAGE_SHAPE),  # Specify the desired size of the input images
    batch_size=BATCH_SIZE,
    class_mode=None,  # Set to None since there are no labels
    shuffle=True,  # Set to True for randomizing the order of the images
)
```
This process results in a stream of processed batches of images ready for training and testing of our Variational Autoencoder model.
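To make the batching concrete: Keras iterators such as the one returned by flow_from_dataframe report ceil(n / batch_size) batches per epoch, with a smaller final batch when n is not a multiple of the batch size. The arithmetic below uses the default --batch-size 128 and --train-split 0.8 with a hypothetical dataset of 1,000 images.

```python
def batches_per_epoch(num_images, batch_size):
    """Number of batches per epoch, i.e. ceil(num_images / batch_size) in integer arithmetic."""
    return (num_images + batch_size - 1) // batch_size


def last_batch_size(num_images, batch_size):
    """Size of the final batch; smaller than batch_size when the division is not exact."""
    remainder = num_images % batch_size
    return remainder if remainder else batch_size


num_images = 1000  # hypothetical dataset size
train_split = 0.8
batch_size = 128
n_train = int(num_images * train_split)  # 800
n_test = num_images - n_train  # 200
print(batches_per_epoch(n_train, batch_size), last_batch_size(n_train, batch_size))  # 7 32
print(batches_per_epoch(n_test, batch_size), last_batch_size(n_test, batch_size))  # 2 72
```

So in this example each training epoch performs 7 weight updates, the last one on only 32 images.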
Training
Following the establishment of our data pipeline, we proceed to the crucial step of building our Variational Autoencoder model. This is a two-part process, involving the assembly of an encoder and a decoder, which were discussed in the first part of this series. Once we’ve put together these individual components, we combine them to create the complete VAE model.
In a nutshell, the encoder translates each input image into the parameters (a mean, z_mean, and a log-variance, z_log_var) of a Gaussian distribution in the latent space. The decoder then takes a point z sampled from this distribution and maps it back to the original image space. In this way the VAE learns to encode useful information about the input data in the latent space, which can then be used for purposes such as generating new images or finding similar images.
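The sampling step is usually implemented with the reparameterization trick: draw noise ε from a standard normal and compute z = μ + exp(½ log σ²) · ε, so that gradients can flow through μ and log σ². The NumPy sketch below illustrates that math under the assumption that the encoder outputs the mean and log-variance of a diagonal Gaussian (consistent with z_mean and z_log_var in TotalLoss); it is not necessarily the code used in Part 1.

```python
import numpy as np


def sample_latent(z_mean, z_log_var, rng):
    """Draw z ~ N(z_mean, exp(z_log_var)) with the reparameterization trick."""
    epsilon = rng.standard_normal(np.shape(z_mean))  # noise from a standard normal
    return z_mean + np.exp(0.5 * z_log_var) * epsilon  # exp(0.5 * log var) is the std dev


rng = np.random.default_rng(0)
z_mean = np.array([[0.0, 2.0]])
z_log_var = np.array([[0.0, np.log(0.25)]])  # variances 1.0 and 0.25, i.e. std devs 1.0 and 0.5
z = sample_latent(z_mean, z_log_var, rng)
print(z.shape)  # (1, 2)
```

Because the randomness lives in ε rather than in z itself, z is a differentiable function of z_mean and z_log_var.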
In addition to building the model, we define a custom metric class, TotalLoss, to monitor the performance of our VAE during training.
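```python
# Assumed imports: tf_square and tf_exp are taken to be aliases of tf.square and tf.exp
from tensorflow import exp as tf_exp
from tensorflow import reduce_mean, reduce_sum
from tensorflow import square as tf_square
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.metrics import Metric


# custom metrics
class TotalLoss(Metric):
    def __init__(self, name="total_loss", **kwargs):
        super().__init__(name=name, **kwargs)
        self.total_loss = self.add_weight(name="tl", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_pred holds the four VAE outputs: latent mean, log-variance, sample, and reconstruction
        z_mean, z_log_var, z, reconstruction = y_pred
        # Reconstruction term: per-pixel binary cross-entropy, summed over height and width,
        # then averaged over the batch
        reconstruction_loss = reduce_mean(
            reduce_sum(binary_crossentropy(y_true, reconstruction), axis=(1, 2))
        )
        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I), summed over latent dims
        kl_loss = -0.5 * (1 + z_log_var - tf_square(z_mean) - tf_exp(z_log_var))
        kl_loss = reduce_mean(reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
        # Store the latest batch's value (this overwrites rather than accumulates)
        self.total_loss.assign(total_loss)

    def result(self):
        return self.total_loss

    def reset_state(self):
        # Reset at the start of each epoch; newer Keras versions call reset_state()
        # (reset_states() is the older, deprecated name)
        self.total_loss.assign(0.)
```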
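This metric computes the total loss, the sum of the reconstruction loss and the KL divergence loss. It gives a single measure of how well the VAE is learning both to recreate the input images and to keep a well-behaved distribution in the latent space. Because update_state overwrites the stored value with assign(), the reported value reflects the most recent batch rather than an average over the epoch.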
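Writing the same computation out in NumPy makes the two terms explicit. For q(z|x) = N(μ, σ²) with a standard normal prior, the KL term has the closed form -½ Σ (1 + log σ² - μ² - σ²), which is exactly the kl_loss line in TotalLoss. The function below is an illustrative mirror of the metric, not code from the repository; the tiny 4×4 inputs are made up.

```python
import numpy as np


def vae_total_loss(x, x_hat, z_mean, z_log_var, eps=1e-7):
    """Illustrative NumPy mirror of TotalLoss.

    x, x_hat: (batch, height, width, channels) arrays with values in [0, 1]
    z_mean, z_log_var: (batch, latent_dim) arrays
    """
    x_hat = np.clip(x_hat, eps, 1.0 - eps)  # avoid log(0)
    bce = -(x * np.log(x_hat) + (1.0 - x) * np.log(1.0 - x_hat))
    # Average over channels, sum over height and width, then average over the batch
    reconstruction_loss = np.mean(np.sum(np.mean(bce, axis=-1), axis=(1, 2)))
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims, averaged over the batch
    kl = -0.5 * (1.0 + z_log_var - np.square(z_mean) - np.exp(z_log_var))
    kl_loss = np.mean(np.sum(kl, axis=1))
    return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss


rng = np.random.default_rng(0)
x = rng.uniform(size=(2, 4, 4, 1))  # two tiny 4x4 grayscale "images"
x_hat = np.full_like(x, 0.5)  # a decoder that always predicts 0.5
total, rec, kl = vae_total_loss(x, x_hat, np.zeros((2, 6)), np.zeros((2, 6)))
# Each of the 16 pixels contributes -log(0.5) = log(2), so rec = 16 * log(2);
# kl is 0 because mu = 0 and log var = 0 match the prior exactly
```

Shifting every latent mean to 1 while keeping log σ² = 0 adds ½ per latent dimension, so with six dimensions the KL term becomes 3.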
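```python
from tensorflow.keras.optimizers import Adam

# encoder_model, decoder_model, and VAE come from Part 1 of this series (see the repository)
encoder, encoder_layers_dim = encoder_model(input_shape=INPUT_SHAPE, filters=FILTERS, dense_layer_dim=DENSE_LAYER_DIM, latent_dim=LATENT_DIM)
decoder = decoder_model(encoder_layers_dim)
vae = VAE(encoder, decoder)
vae.compile(optimizer=Adam(learning_rate=LEARNING_RATE), metrics=[TotalLoss()])
```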
To break this down further:
- Building the encoder: We first call the encoder_model function, which returns the encoder model and the dimensions of its layers. It takes the shape of the input images, the number of filters for each convolutional layer, the dimension of the dense layer, and the dimension of the latent space.
- Building the decoder: Next, we call the decoder_model function with the encoder's layer dimensions. The decoder uses these dimensions to build a series of transposed convolution (deconvolution) layers that mirror the encoder, so that it can reconstruct images of the input shape from points in the latent space.
- Combining the encoder and decoder: We then instantiate the VAE by passing the encoder and decoder to the VAE class, which combines them into a single end-to-end model that can be trained on our image data.
- Compiling with the total loss metric: Finally, we compile the model with the Adam optimizer at the chosen learning rate and pass an instance of the custom TotalLoss metric. TotalLoss subclasses the Keras Metric class and sums the reconstruction loss and the KL divergence loss into a single value that reflects the overall performance of the model.
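Part 1 defines the actual layers, and their configuration is not repeated here. As a rough sanity check, the sketch below assumes (an assumption, not confirmed in this article) that each entry in --filters adds one stride-2 convolution with 'same' padding, and that the decoder mirrors them with stride-2 transposed convolutions. Under that assumption it shows why the image size matters: the mirrored decoder reproduces the input size exactly only when --output-image-shape is divisible by 2 raised to the number of convolutional layers. Both helper names are hypothetical.

```python
import math


def encoder_feature_shapes(image_size, filters, stride=2):
    """Feature-map shape after each conv layer, ASSUMING stride-2 'same'-padded convolutions."""
    shapes = []
    size = image_size
    for num_filters in filters:
        size = math.ceil(size / stride)  # 'same' padding: output size = ceil(input size / stride)
        shapes.append((size, size, num_filters))
    return shapes


def mirrored_decoder_size(bottleneck_size, num_layers, stride=2):
    """Output size after num_layers stride-2 'same'-padded transposed convolutions."""
    return bottleneck_size * stride ** num_layers


print(encoder_feature_shapes(56, [32, 64]))  # [(28, 28, 32), (14, 14, 64)]
print(mirrored_decoder_size(14, 2))  # 56: matches the input
print(encoder_feature_shapes(50, [32, 64, 128])[-1])  # (7, 7, 128)
print(mirrored_decoder_size(7, 3))  # 56: does not match an input of 50
```

With the defaults (56 pixels, two layers), 56 is divisible by 4, so the shapes line up.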
Callbacks in TensorFlow and Keras provide a way to execute certain actions at various stages of training. They are an important part of the training process as they allow us to add custom behaviors during training.
```python
import os

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.callbacks import Callback


class VAECallback(Callback):
    """
    Every 10 epochs, take the first n images of a (shuffled) validation batch and
    save a figure comparing them with their VAE reconstructions.
    """
    def __init__(self, vae, validation_data, log_dir, n=5):
        super().__init__()
        self.vae = vae
        self.validation_data = validation_data
        self.n = n
        self.log_dir = log_dir

    def on_epoch_end(self, epoch, logs=None):
        # Only run on epochs 0, 10, 20, ... (zero-based) to limit the overhead
        if epoch % 10 == 0:
            # Take one batch of validation images and reconstruct it
            validation_batch = next(iter(self.validation_data))
            _, _, _, reconstructed_images = self.vae.predict(validation_batch, verbose=0)
            # Clip the reconstructions to the valid pixel range [0, 1]
            reconstructed_images = np.clip(reconstructed_images, 0.0, 1.0)
            # Plot the original (left) and reconstructed (right) images side by side
            plt.figure(figsize=(10, 2 * self.n))
            for i in range(self.n):
                plt.subplot(self.n, 2, 2 * i + 1)
                plt.imshow(validation_batch[i].squeeze(), cmap='gray')  # drop the channel axis
                plt.axis('off')
                plt.subplot(self.n, 2, 2 * i + 2)
                plt.imshow(reconstructed_images[i].squeeze(), cmap='gray')
                plt.axis('off')
            fig_name = os.path.join(self.log_dir, 'decoded_images_epoch_{:04d}.png'.format(epoch))
            plt.savefig(fig_name)
            plt.close()  # free the figure so figures do not pile up in memory
```
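```python
import os

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard

# vae_path is not defined in the snippets shown; this weights-file location is an assumption
vae_path = os.path.join(LOGDIR, "vae.weights.h5")

vae_callback = VAECallback(vae, test_data_generator, LOGDIR)
tensorboard_cb = TensorBoard(log_dir=LOGDIR, histogram_freq=1)
checkpoint_cb = ModelCheckpoint(filepath=vae_path, save_weights_only=True, verbose=1)
earlystopping_cb = EarlyStopping(monitor="total_loss", min_delta=1e-2, patience=5, verbose=1)
```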
In this code, four types of callbacks are used:
- VAECallback: This custom callback takes a few images from a validation batch and saves a figure comparing them with the VAE's reconstructions. It runs every 10 epochs (on epochs 0, 10, 20, and so on, counting from zero), so the saved images let us visually track how the reconstructions improve over time.
- TensorBoard: TensorBoard is TensorFlow's visualization toolkit. This callback writes the metrics for each epoch to LOGDIR, and with histogram_freq=1 it also records histograms of the layer weights every epoch, so you can monitor training and diagnose issues in TensorBoard.
- ModelCheckpoint: This callback saves the model weights so that you can resume training later or evaluate the model separately. As configured here (save_weights_only=True without save_best_only=True), it overwrites the weights file at the end of every epoch, so the file holds the most recent weights rather than the best ones. To keep only the best weights, you could set save_best_only=True together with an appropriate monitor.
- EarlyStopping: This callback stops training when a monitored metric stops improving, in this case the training total_loss. A decrease smaller than min_delta=1e-2 does not count as an improvement, and the patience parameter (5 here) is the number of consecutive epochs without improvement to wait before stopping. Because it monitors a training metric rather than a validation one, it mainly avoids wasted computation once the loss plateaus; guarding against overfitting would require monitoring a validation metric.
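To see how patience and min_delta interact, here is a simplified pure-Python model of the stopping rule described in the Keras EarlyStopping documentation for a metric that should decrease: a value counts as an improvement only if it beats the best value so far by more than min_delta, and training stops after patience epochs without one. It is not Keras' implementation (which also handles baselines, mode='max', and restoring the best weights), and the loss values are made up.

```python
def early_stopping_epoch(losses, min_delta=1e-2, patience=5):
    """Return the (zero-based) epoch at which training would stop, or None if it never stops.

    Simplified model of EarlyStopping for a decreasing metric.
    """
    best = float("inf")
    wait = 0  # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# After epoch 1 the loss still falls, but only by 0.001 per epoch, which is below min_delta
print(early_stopping_epoch([10.0, 9.0, 8.999, 8.998, 8.997, 8.996, 8.995, 8.0]))  # 6
print(early_stopping_epoch([5.0, 4.0, 3.0, 2.0, 1.0]))  # None
```

In the first example training stops at epoch 6, before the large drop at epoch 7 is ever seen, which illustrates the trade-off in choosing patience and min_delta.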
Finally, we call fit() on the VAE model to start training, passing it the training data, the number of epochs, the validation data, and the callbacks.
```python
history = vae.fit(
    train_data_generator,
    epochs=EPOCHS,
    validation_data=test_data_generator,
    callbacks=[tensorboard_cb, vae_callback, checkpoint_cb, earlystopping_cb],
)
```
At the end of training, fit() returns a History object whose history attribute holds the loss and metric values for each epoch, which can be used to plot learning curves and evaluate the model.
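history.history is a plain dict that maps each logged metric name to a list of per-epoch values. A minimal sketch for turning it into a learning-curve figure, using matplotlib as VAECallback does, could look like the following. The exact keys depend on what the VAE class logs, so the example dict in the usage note is made up, and plot_learning_curves is a hypothetical helper.

```python
import matplotlib

matplotlib.use("Agg")  # render to files; no display needed
import matplotlib.pyplot as plt


def plot_learning_curves(history_dict, out_path):
    """Plot every per-epoch series in a History.history-style dict and save the figure."""
    fig, ax = plt.subplots(figsize=(8, 4))
    for name, values in history_dict.items():
        epochs = range(1, len(values) + 1)  # number epochs from 1 for readability
        ax.plot(epochs, values, label=name)
    ax.set_xlabel("epoch")
    ax.set_ylabel("value")
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)
    return sorted(history_dict)  # the series that were plotted
```

After training, a call such as plot_learning_curves(history.history, os.path.join(LOGDIR, "learning_curves.png")) would save the curves next to the other logs; the file name is only an example.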
To conclude, in this part of the series we covered the process of training Variational Autoencoders: setting up the data pipeline, building the model, defining a custom metric, and using callbacks. In the next and final part, we will turn to hyperparameter tuning to optimize the model's performance. Stay tuned!