"""Implements a variational autoencoder""" import sys import os import shutil import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import matplotlib.pyplot as plt import numpy as np def config(): """Return a dict of configuration settings used in the program""" conf = {} # Whether you are training or testing. What is implemented is one of the following # {"train, "test_reconstruction", "test_generation", "test_interpolation"} conf['mode'] = "train" # Whether you are restoring variables from checkpoint, or initializing them from "scratch" # Restoring variables can be useful in testing a trained model, or if you want to continue # training from a checkpoint conf['restore_from_checkpoint'] = False if "test" in conf['mode']: conf['restore_from_checkpoint'] = True conf['job_dir'] = "/tmp/variational_autoencoder/" # Relevant dataset will be put in this location after download conf['data_dir'] = "/tmp/mnist_data" # Location to place checkpoints conf['checkpoint_dir'] = os.path.join(conf['job_dir'], "train/checkpoints") # Location of the summary events to be used by tensorboard conf['summary_dir'] = os.path.join(conf['job_dir'], "{}/events".format(conf['mode'])) # Location to place output conf['output_dir'] = os.path.join(conf['job_dir'], "{}/output".format(conf['mode'])) # Path of checkpoint you want to restore variables from (notice that the postfix is ommitted) conf['restore_path'] = os.path.join(conf['job_dir'], "train/checkpoints/model.ckpt-30000") # Create directories if not os.path.exists(conf['checkpoint_dir']): os.makedirs(conf['checkpoint_dir']) if not os.path.exists(conf['output_dir']): os.makedirs(conf['output_dir']) if not os.path.exists(conf['summary_dir']): os.makedirs(conf['summary_dir']) else: # Remove content of old run before creating a new directory shutil.rmtree(conf['summary_dir']) os.makedirs(conf['summary_dir']) # Number of layers and nodes in the autoencoder. # Number of nodes in the input (the dimensions of an input example) conf['height'] = 28 conf['width'] = 28 conf['channels'] = 1 # This specifies only the hidden layers in the encoder, as the decoder will mirror this setup. conf['hidden_dimensions'] = [128, 64, 10] # Implemented: {'cross_entropy', 'mean_squared_error'} conf['cost_function'] = 'cross_entropy' # Implemented: {'RMSProp', 'Adam', 'GradientDescent} conf['optimization_function'] = 'Adam' # Implemented: {'relu', 'elu', 'sigmoid', 'tanh'} conf['activation_function'] = 'relu' # The number of steps to run before termination of training. One step is one forward->backward # pass of a mini-batch conf['max_steps'] = 30000 # The batch size used in training. conf['batch_size'] = 128 # The step size used by the optimization routine. conf['learning_rate'] = 1.0e-4 # How often (in steps) to log the training progress (to stdout) conf['monitor_progress'] = 1000 # How often (in steps) to save checkpoints conf['periodic_checkpoint'] = 5000 # How many test results to show in a plotted mosaic at the end of training. Preferably a square # number conf['num_display'] = 4 return conf def activation_function(name): """Return an activation function according to the input name""" if name == 'relu': return tf.nn.relu elif name == 'elu': return tf.nn.elu elif name == 'selu': return tf.nn.selu elif name == 'sigmoid': return tf.sigmoid elif name == 'tanh': return tf.tanh else: print("Please specify a valid activation function") sys.exit(1) def kl_divergence(mu, log_sigma_squared): """ Computes the KL divergence between two Gaussian distributions p and q D_KL(p||q) = sum_x p(x) ( log p(x) / q(x) ) where p ~ N(mu, sigma^2) q ~ N(0, 1) """ # Sum over all latent nodes return tf.reduce_sum(1.0/2.0 * (mu*mu + tf.exp(log_sigma_squared) - log_sigma_squared - 1.0), axis=1) def reconstruction_cost(name, references, predictions): """Return a cost function value given the input""" if name == 'mean_squared_error': cost = tf.reduce_mean(tf.pow(predictions - references, 2)) elif name == 'cross_entropy': # Binary cross entropy is used to compare the predictions and the reference epsilon = 1e-10 # To avoid log(0) loss = -tf.reduce_sum(references * tf.log(epsilon + predictions) + \ (1.0 - references) * tf.log(epsilon + 1.0 - predictions), axis=1) cost = tf.reduce_mean(loss) # Average over examples else: print("Please specify an implemented cost function") sys.exit(1) tf.summary.scalar("Reconstruction cost", cost) return cost def latent_cost(mu, log_sigma_squared): """Return latent cost function. Regularizes the spread.""" cost = tf.reduce_mean(kl_divergence(mu, log_sigma_squared)) # Average over all examples tf.summary.scalar("Latent cost", cost) return cost def optimization_function(name, learning_rate, cost): """Return an optimization function given the input""" if name == 'RMSProp': return tf.train.RMSPropOptimizer(learning_rate).minimize(cost) elif name == 'Adam': return tf.train.AdamOptimizer(learning_rate).minimize(cost) elif name == 'sgd': return tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) else: print("Please specify an implemented optimizer") sys.exit(1) def dense_layer(name, input_activations, num_nodes, activation): """Creates a dense layer with num_nodes""" num_nodes_prev = input_activations.get_shape().as_list()[-1] print("Layer {0:<10}: {1:>4} -> {2:>4}".format(name, num_nodes_prev, num_nodes)) # Samples weights from an uniform distribution in the range +/- sqrt(6 / (in_dim + out_dim)) weight_initializer = tf.contrib.layers.xavier_initializer() bias_initializer = tf.zeros_initializer() weights = tf.get_variable(name="W_" + name, shape=[num_nodes_prev, num_nodes], dtype=tf.float32, initializer=weight_initializer) biases = tf.get_variable(name="b_" + name, shape=[num_nodes], dtype=tf.float32, initializer=bias_initializer) linear = tf.add(tf.matmul(input_activations, weights), biases) if activation is None: return linear activation_fn = activation_function(activation) return activation_fn(linear) def encoder(x, conf): """Builds the encoder. This assumes that the input image has been flattened""" # Note that we iterate up to (excluding) the latent layer for layer_ind, num_nodes in enumerate(conf['hidden_dimensions'][:-1], start=1): x = dense_layer("encoding_{}".format(layer_ind), x, num_nodes, conf['activation_function']) mu = dense_layer("mu", x, conf['hidden_dimensions'][-1], None) # We sample the log(sigma_squared) in stead of sigma_squared to allow for negative samples log_sigma_squared = dense_layer("sigma", x, conf['hidden_dimensions'][-1], None) standard_normal = tf.random_normal(shape=tf.shape(mu), mean=0.0, stddev=1.0, dtype=tf.float32) latent = mu + standard_normal * tf.sqrt(tf.exp(log_sigma_squared)) return latent, mu, log_sigma_squared def decoder(x, conf): """Builds the decoder and returns both the predictions and logits""" for layer_ind, num_nodes in enumerate(reversed(conf['hidden_dimensions'][:-1]), start=1): x = dense_layer("decoding_{}".format(layer_ind), x, num_nodes, conf['activation_function']) num_nodes = conf['height'] * conf['width'] * conf['channels'] layer_ind = len(conf['hidden_dimensions']) logits = dense_layer("decoding_{}".format(layer_ind), x, num_nodes, None) predictions = tf.sigmoid(logits) return predictions def autoencoder(input_batch, conf): """Define the autoencoder model""" latent_batch, mu, log_sigma_squared = encoder(input_batch, conf) prediction = decoder(latent_batch, conf) return prediction, latent_batch, mu, log_sigma_squared def train(conf, data): """Train the model""" learning_rate = conf['learning_rate'] input_ph = tf.placeholder(name="input_batch", shape=[None, conf['height'] * conf['width'] * conf['channels']], dtype=tf.float32) predictions, _, mu, log_sigma_squared = autoencoder(input_ph, conf) references = input_ph rec_cost = reconstruction_cost(conf['cost_function'], references, predictions) lat_cost = latent_cost(mu, log_sigma_squared) cost = tf.reduce_mean(rec_cost + lat_cost) tf.summary.scalar("Total cost", cost) train_op = optimization_function(conf['optimization_function'], learning_rate, cost) summaries = tf.get_collection(tf.GraphKeys.SUMMARIES) summary_op = tf.summary.merge(summaries) saver = tf.train.Saver(tf.global_variables(), max_to_keep=20) if conf['restore_from_checkpoint']: restorer = tf.train.Saver() else: init = tf.global_variables_initializer() with tf.Session() as sess: if conf['restore_from_checkpoint']: print('Restoring variables from {}'.format(conf['restore_path'])) restorer.restore(sess, conf['restore_path']) else: sess.run(init) summary_writer = tf.summary.FileWriter(conf['summary_dir'], sess.graph) # Training loop step = 0 while step <= conf['max_steps']: # Get data, do not need labels input_batch, _ = data.train.next_batch(conf['batch_size']) # Run optimization and backpropagation _, cost_val, recon_cost_val, latent_cost_val, summary_str = sess.run( [train_op, cost, rec_cost, lat_cost, summary_op], feed_dict={input_ph: input_batch}) summary_writer.add_summary(summary_str, step) step += 1 last_step = True if step == conf['max_steps'] else False if (step % conf['monitor_progress'] == 0) or last_step: print('Step {0:>6}: Total cost: {1:>7,.4e}, ' 'reconstruction cost {2:>7,.4e} latent cost {3:>7,.4e}'.format( step, cost_val, recon_cost_val, latent_cost_val)) if (step % conf['periodic_checkpoint'] == 0) or last_step: print('Writing checkpoint') checkpoint_path = os.path.join(conf['checkpoint_dir'], 'model.ckpt') saver.save(sess, checkpoint_path, global_step=step) print("Training finished") def test_reconstruction(conf, data): """Test a trained model. This reconstructs a bunch of test images.""" input_ph = tf.placeholder(name="input_batch", shape=[None, conf['height'] * conf['width'] * conf['channels']], dtype=tf.float32) predictions, _, _, _ = autoencoder(input_ph, conf) restorer = tf.train.Saver() with tf.Session() as sess: print('Restoring variables from {}'.format(conf['restore_path'])) restorer.restore(sess, conf['restore_path']) # Testing originals = [] reconstructed = [] for _ in range(conf['num_display']): input_batch, _ = data.test.next_batch(1) preds = sess.run(predictions, feed_dict={input_ph: input_batch}) originals.append(input_batch[0].reshape([conf['height'], conf['width']])) reconstructed.append(preds[0].reshape([conf['height'], conf['width']])) mosaic_size = int(np.ceil(np.sqrt(conf['num_display']))) plt.figure(0, figsize=(mosaic_size*2, mosaic_size*2)) for ind, im in enumerate(originals): plt.subplot(mosaic_size, mosaic_size, ind+1) plt.imshow(im, origin="upper", cmap="gray", clim=(0.0, 1.0)) plt.axis('off') plt.savefig(os.path.join(conf['output_dir'], 'mosaic_original.png')) plt.figure(1, figsize=(mosaic_size*2, mosaic_size*2)) for ind, im in enumerate(reconstructed): plt.subplot(mosaic_size, mosaic_size, ind+1) plt.imshow(im, origin="upper", cmap="gray", clim=(0.0, 1.0)) plt.axis('off') plt.savefig(os.path.join(conf['output_dir'], 'mosaic_reconstructed.png')) plt.show() def test_generation(conf): """Test a trained model. This generates an example from random latent variables.""" latent_vec_ph = tf.placeholder(name="latent_vector", shape=[None, conf['hidden_dimensions'][-1]], dtype=tf.float32) predictions = decoder(latent_vec_ph, conf) restorer = tf.train.Saver() with tf.Session() as sess: print('Restoring variables from {}'.format(conf['restore_path'])) restorer.restore(sess, conf['restore_path']) # Testing generated = [] for _ in range(conf['num_display']): # Sample a single vector from the standard normal distribution latent_vec = np.random.normal(size=[1, conf['hidden_dimensions'][-1]]) preds = sess.run(predictions, feed_dict={latent_vec_ph: latent_vec}) generated.append(preds[0].reshape([conf['height'], conf['width']])) mosaic_size = int(np.ceil(np.sqrt(conf['num_display']))) plt.figure(0, figsize=(mosaic_size*2, mosaic_size*2)) for ind, im in enumerate(generated): plt.subplot(mosaic_size, mosaic_size, ind+1) plt.imshow(im, origin="upper", cmap="gray", clim=(0.0, 1.0)) plt.axis('off') plt.savefig(os.path.join(conf['output_dir'], 'mosaic_generated.png')) plt.show() def test_interpolation(conf, data): """ Test a trained model. This generates an example that should be an interpolation between two test examples. """ # Note batch size = 2 input_batch_ph = tf.placeholder(name="input_batch", shape=[2, conf['height'] * conf['width'] * conf['channels']], dtype=tf.float32) _, mu_vec, _ = encoder(input_batch_ph, conf) mu_1 = mu_vec[0, :] mu_2 = mu_vec[1, :] latent_vector = tf.expand_dims(1.0 / 2.0 * (mu_1 + mu_2), 0) predictions = decoder(latent_vector, conf) restorer = tf.train.Saver() with tf.Session() as sess: print('Restoring variables from {}'.format(conf['restore_path'])) restorer.restore(sess, conf['restore_path']) # Testing interpolated = [] for _ in range(conf['num_display']): # Sample a single vector from the standard normal distribution latent_vec = np.random.normal(size=[1, conf['hidden_dimensions'][-1]]) input_batch, label_batch = data.test.next_batch(2) # Note batch size = 2 preds = sess.run(predictions, feed_dict={input_batch_ph: input_batch}) # Push a tuple of examples to the list interpolated.append((input_batch[0].reshape([conf['height'], conf['width']]), np.argmax(label_batch[0]), input_batch[1].reshape([conf['height'], conf['width']]), np.argmax(label_batch[1]), preds[0].reshape([conf['height'], conf['width']]))) plt.figure(0, figsize=(5, conf['num_display']*2)) for ind, tup in enumerate(interpolated): ex1, lab1, ex2, lab2, pred = tup ax1 = plt.subplot(conf['num_display'], 3, 3 * ind + 1) ax1.set_title("Label " + str(lab1)) ax1.imshow(ex1, origin="upper", cmap="gray", clim=(0.0, 1.0)) ax1.axis('off') ax2 = plt.subplot(conf['num_display'], 3, 3 * ind + 2) ax2.set_title("Label " + str(lab2)) ax2.imshow(ex2, origin="upper", cmap="gray", clim=(0.0, 1.0)) ax2.axis('off') ax3 = plt.subplot(conf['num_display'], 3, 3 * ind + 3) ax3.set_title("Interpolated") ax3.imshow(pred, origin="upper", cmap="gray", clim=(0.0, 1.0)) ax3.axis('off') plt.savefig(os.path.join(conf['output_dir'], 'mosaic_interpolated.png')) plt.show() def main(): """Main""" print("Start program") conf = config() mnist = input_data.read_data_sets(conf['data_dir'], one_hot=True) if conf['mode'] == 'train': train(conf, mnist) elif conf['mode'] == 'test_reconstruction': test_reconstruction(conf, mnist) elif conf['mode'] == 'test_generation': test_generation(conf) elif conf['mode'] == 'test_interpolation': test_interpolation(conf, mnist) if __name__ == "__main__": main()