Batch Normalization is a straightforward technique for improving the training of deep neural networks. It builds on the idea that the inputs to every layer of a neural network should be whitened (i.e. linearly transformed to have zero mean and unit variance, and decorrelated) before being fed into the activation function. This reduces what the authors of the paper introducing the technique call "Internal Covariate Shift": the change in the distribution of each layer's activations during training, caused by updates to the parameters of the preceding layers.

The term "Covariate Shift" refers to a change in the input distribution of a learning system and is usually attributed to Shimodaira (2000). When we consider a deep neural network not as a single learning system but as a composition of trainable layers, we can apply the concept of covariate shift to every layer in the network.

Let's take a closer look at the 2015 paper by Sergey Ioffe and Christian Szegedy that introduced batch normalization.

The problem

While training deep neural networks, the distribution of each layer's inputs changes as the parameters of the previous layers are updated. This shift is especially problematic in networks with many layers, where small changes compound from layer to layer: it forces us to use lower learning rates and makes training extremely sensitive to parameter initialization. It also makes it harder to train networks with saturating non-linearities such as tanh, because inputs are easily pushed into the saturated regime, where gradients vanish.
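The saturation problem is easy to quantify: the derivative of tanh is $1 - \tanh^2(x)$, so once a pre-activation drifts a few units away from zero, almost no gradient flows back through the unit. A quick NumPy check (the helper `tanh_grad` is our own, not part of any library):

```python
import numpy as np

def tanh_grad(x):
    # Derivative of tanh: d/dx tanh(x) = 1 - tanh(x)^2
    return 1.0 - np.tanh(x) ** 2

for x in [0.0, 1.0, 2.5, 5.0]:
    print("x = %.1f  tanh'(x) = %.6f" % (x, tanh_grad(x)))
```

The gradient is 1 at $x = 0$ but below $2 \times 10^{-4}$ at $x = 5$, so a unit whose inputs sit in that range learns almost nothing.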

To see why this is a problem, let's look at the statistics of the activations in a simple feed-forward network with 10 hidden layers of 200 nodes each, using tanh activations and small random initial weights.

# Code adapted from Andrej Karpathy's CS231n slides.

import numpy as np
from matplotlib import pyplot as plt

inputs = np.random.randn(500, 200)              # 500 examples with 200 features each
hidden_layers = [200] * 10                      # 10 hidden layers of 200 units
nonlinearities = ['tanh'] * len(hidden_layers)

activations = {
    'relu': lambda x: np.maximum(0, x),
    'tanh': lambda x: np.tanh(x)}

hidden_activations = {}

for i in range(len(hidden_layers)):
    # The first layer sees the raw inputs; later layers see the previous layer's output.
    X = inputs if i == 0 else hidden_activations[i - 1]
    fan_in = X.shape[1]
    fan_out = hidden_layers[i]

    # W = np.random.randn(fan_in, fan_out) / np.sqrt(fan_in)  # Xavier init
    W = np.random.randn(fan_in, fan_out) * 0.01               # small normal init

    H = np.dot(X, W)                        # linear transform (no bias)
    H = activations[nonlinearities[i]](H)   # non-linearity

    hidden_activations[i] = H

layer_means = [np.mean(H) for i, H in hidden_activations.items()]
layer_stds = [np.std(H) for i, H in hidden_activations.items()]

# Figure 1: mean and standard deviation of each layer's activations.
plt.figure()
plt.subplot(121)
plt.plot(list(hidden_activations.keys()), layer_means, 'ob-')
plt.title('layer mean')
plt.subplot(122)
plt.plot(list(hidden_activations.keys()), layer_stds, 'or-')
plt.title('layer std')

# Figure 2: histogram of the activations of every layer.
fig = plt.figure()
for i, H in hidden_activations.items():
    ax = plt.subplot(1, len(hidden_activations), i + 1)
    ax.hist(H.ravel(), 30, range=(-1, 1))
    ax.grid(True)
    # Hide the y-axis tick marks and labels (the tick1On/label1On attributes
    # used in older matplotlib versions no longer exist).
    ax.tick_params(axis='y', which='both', left=False, right=False,
                   labelleft=False, labelright=False)
    plt.title('Layer: %s \nMean: %1.4f\nStd: %1.6f' % (
        i + 1, layer_means[i], layer_stds[i]), fontsize=10)

plt.show()


The code above produces two figures: the mean and standard deviation of each layer's activations, and a histogram of the activations of each layer.

In the deeper layers, the activations collapse toward zero: their distribution becomes sharply concentrated around the mean as the standard deviation plummets. These near-zero activations in turn lead to very small gradients during backpropagation (the gradient of a layer's weights scales with that layer's inputs), which makes it extremely difficult to train the network efficiently.

One way to deal with this problem is to initialize the model parameters carefully. Glorot and Bengio (2010) proposed what is now known as Xavier initialization; a commonly used simplified form of it corresponds to initializing the weight matrix in the code above as W = np.random.randn(fan_in, fan_out) / np.sqrt(fan_in) (the commented-out line). With this initialization strategy, the layer activation statistics look much healthier.

Although the standard deviations of the layer activations are still decreasing, they are doing so at a much more reasonable rate, which allows our network to train efficiently even for deeper architectures.
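A back-of-the-envelope calculation explains both behaviours. Assuming independent, zero-mean inputs and weights, each pre-activation of a layer with $n$ inputs has variance $n \, Var[W] \, Var[x]$. With the 0.01 initialization and $n = 200$, this factor is $200 \times 10^{-4} = 0.02$, so the standard deviation shrinks by about $\sqrt{0.02} \approx 0.14$ per layer (tanh is nearly the identity for such small values). With the $1/\sqrt{n}$ initialization the factor is exactly 1, and only the squashing of tanh reduces the spread. The sketch below compares this prediction with a simulation using the same layer sizes as above; the helper `simulate_stds` is hypothetical, written just for this illustration.

```python
import numpy as np

def simulate_stds(weight_scale_fn, n_layers=10, width=200, n_examples=500, seed=0):
    """Return the std of the tanh activations of each layer of a random
    fully connected network with the given weight scale."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_examples, width))
    stds = []
    for _ in range(n_layers):
        W = rng.standard_normal((width, width)) * weight_scale_fn(width)
        X = np.tanh(X @ W)
        stds.append(X.std())
    return stds

small = simulate_stds(lambda fan_in: 0.01)                  # small normal init
xavier = simulate_stds(lambda fan_in: 1 / np.sqrt(fan_in))  # Xavier-style init

# Linear prediction for the small init: std shrinks by sqrt(200 * 0.01**2) per layer.
factor = np.sqrt(200 * 0.01 ** 2)
predicted = [factor ** (k + 1) for k in range(10)]

for k in range(10):
    print("layer %2d  small-init std %.2e (predicted %.2e)  xavier std %.3f"
          % (k + 1, small[k], predicted[k], xavier[k]))
```

By layer 10 the small initialization leaves activations of order $10^{-9}$, while the Xavier-style initialization keeps them at a usable scale.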

Batch Normalization addresses this problem by normalizing the activations of every layer to zero mean and unit variance over each mini-batch. According to the paper, this reduces internal covariate shift and brings a number of benefits:

1. Reduced training time (due to improved gradient flow through the network).
2. The ability to use higher learning rates (which helps with fast convergence).
3. A regularizing effect (each example is normalized with statistics computed from the other examples in its mini-batch, so the network's output for a given training example is no longer deterministic; this tends to have an overall regularizing effect).
4. Less dependence on careful weight initialization.
5. Saturating non-linearities become viable activation functions.
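Item 3 is easy to see numerically: because the statistics are computed over the mini-batch, the normalized value of one example depends on which other examples share its batch. A minimal NumPy sketch, assuming a hypothetical helper `bn_normalize` and omitting the learned scale and shift:

```python
import numpy as np

def bn_normalize(batch, eps=1e-5):
    # Normalize each feature (column) with the mean and variance of this mini-batch.
    mu = batch.mean(axis=0)
    var = batch.var(axis=0)
    return (batch - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
example = rng.standard_normal((1, 4))            # one fixed training example
others_a = rng.standard_normal((7, 4))           # two different sets of batch-mates
others_b = rng.standard_normal((7, 4)) + 3.0

out_a = bn_normalize(np.vstack([example, others_a]))[0]
out_b = bn_normalize(np.vstack([example, others_b]))[0]
print(out_a)
print(out_b)  # same input example, different normalized output
```

This batch-dependent jitter acts like noise injected during training, which is where the regularizing effect comes from.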

The mechanics of Batch Normalization

In their paper, the authors normalize each scalar feature independently, giving it zero mean and unit variance (full whitening of each layer's inputs would be costly). If a layer has $d$ inputs, with input vector $x = (x^{(1)}, \ldots, x^{(d)})$, each dimension is normalized as follows, with the expectation and variance computed over the training data set: $$\hat{x}^{(k)} = \frac{x^{(k)} - E[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$

Simply normalizing each input of a layer can restrict the representational power of the network; for example, normalizing the inputs of a sigmoid would constrain them to the linear regime of the non-linearity. To address this, the transformation inserted into the network must be able to represent the identity transform, so that the network can learn to undo the normalization when needed. This is achieved by introducing, for each activation $x^{(k)}$, a pair of learnable parameters $\gamma^{(k)}$ and $\beta^{(k)}$ that scale and shift the normalized value, respectively.

With the above scale and shift parameters in hand, each normalized activation $\hat{x}^{(k)}$ can be viewed as an input to a sub-network composed of the linear transformation: $$y^{(k)} = \gamma^{(k)}\hat{x}^{(k)} + \beta^{(k)}$$ followed by the later processing done by the original network.

As mentioned, $\gamma^{(k)}$ and $\beta^{(k)}$ are learned along with the model weights and serve to restore the representational power of the network. To see this, set $\gamma^{(k)} = \sqrt{Var[x^{(k)}]}$ and $\beta^{(k)} = E[x^{(k)}]$: then $y^{(k)} = x^{(k)}$, so the model can recover the original activations if that turned out to be the optimal behaviour.
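A quick numerical check of this identity-recovery argument, as a sketch on synthetic data (a tiny $\epsilon$ is included as in the mini-batch version below, so the recovery is exact only up to floating-point error):

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic activations with non-zero mean and non-unit variance.
x = rng.normal(loc=5.0, scale=2.0, size=(1000, 3))

mean = x.mean(axis=0)
var = x.var(axis=0)
eps = 1e-8
x_hat = (x - mean) / np.sqrt(var + eps)   # normalized: ~zero mean, unit variance

# Choosing gamma = sqrt(Var[x]) and beta = E[x] undoes the normalization.
gamma = np.sqrt(var)
beta = mean
y = gamma * x_hat + beta

print(np.allclose(y, x))  # True
```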

In practice, computing the expectation and variance over the entire training set at every training step is impractical, so they are estimated from each mini-batch. The transformation is applied to each input dimension (activation) independently: the activation is first normalized ($\hat{x}$) and then linearly transformed ($y$).

For a mini-batch $B$ of size $m$,

$$\mu_{B} = \frac{1}{m}\sum_{i=1}^{m}x_{i}$$

$$\sigma_{B}^{2} = \frac{1}{m}\sum_{i=1}^{m}(x_{i} - \mu_{B})^{2}$$

$$\hat{x}_{i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$

$$y_{i} = \gamma\hat{x}_{i} + \beta \equiv BN(x_{i})$$

where $y_{i}$ is the output of the Batch Normalizing Transform $BN(x_{i})$, $\hat{x}_{i}$ is the normalized activation, and $\mu_{B}$ and $\sigma_{B}^2$ are the mini-batch mean and mini-batch variance, respectively. The small constant $\epsilon$ is added for numerical stability, so that we never divide by zero when a mini-batch variance is zero.
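The four equations translate almost line for line into NumPy. The sketch below is a training-mode forward pass only (no backward pass and no running statistics for inference); the function name `batch_norm_forward` and the default `eps` are our own choices, not from the paper.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Batch Normalizing Transform for a mini-batch x of shape (m, d).

    gamma and beta have shape (d,): one learned scale and shift per dimension.
    """
    mu_B = x.mean(axis=0)                        # mini-batch mean, per dimension
    var_B = ((x - mu_B) ** 2).mean(axis=0)       # mini-batch variance (divides by m)
    x_hat = (x - mu_B) / np.sqrt(var_B + eps)    # normalize
    return gamma * x_hat + beta                  # scale and shift

rng = np.random.default_rng(1)
x = rng.normal(loc=-2.0, scale=4.0, size=(64, 5))   # a mini-batch of m = 64, d = 5
y = batch_norm_forward(x, gamma=np.ones(5), beta=np.zeros(5))
print(y.mean(axis=0).round(6), y.std(axis=0).round(4))
```

With $\gamma = 1$ and $\beta = 0$ each output column has (almost exactly) zero mean and unit standard deviation; other values of $\gamma$ and $\beta$ move the mean to $\beta$ and the standard deviation to $\gamma$.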

TensorFlow implementation

Let's construct a deep neural network with 9 hidden layers of 100 nodes each, followed by a 10-unit output layer, and train it on MNIST with and without batch normalization. We need to make sure that both networks, with and without batch normalization, start from the same initial weights.

The TensorFlow 1.x MNIST helper (input_data.read_data_sets) handles downloading the MNIST dataset, extracting it, and preprocessing it into the right form for us. The main implementation difference between the computational graphs of the two models is in how the individual layers are constructed.

Normally we'd construct a network layer by multiplying the input tensor by the weight matrix and then adding the bias values. For batch normalization, all of the math above is handled for us by tf.layers.batch_normalization. If we wanted lower-level control over the mechanics, we could use the tf.nn.batch_normalization function instead, which would be a lot more work: we would have to compute the batch statistics ourselves (for example with tf.nn.moments) and maintain the population estimates used at inference time.

Note: we apply batch normalization to the layer's linear output, before it is fed into the activation function. We also don't add a bias explicitly, because its role is performed by the $\beta$ variable. For details, see Section 3.2 of the original paper.
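Why the bias is redundant: adding a constant $b$ to every example shifts the mini-batch mean $\mu_B$ by exactly $b$, so $x_i - \mu_B$ is unchanged and the bias cancels out in the normalization. A small NumPy check, using a hypothetical `normalize` helper with $\gamma = 1$ and $\beta = 0$:

```python
import numpy as np

def normalize(z, eps=1e-5):
    # Per-feature mini-batch normalization (gamma = 1, beta = 0).
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

rng = np.random.default_rng(2)
inputs = rng.standard_normal((32, 10))   # a mini-batch of 32 examples
w = rng.standard_normal((10, 4))         # weights of a 4-unit layer
b = rng.standard_normal(4) * 10.0        # a deliberately large bias

with_bias = normalize(inputs @ w + b)
without_bias = normalize(inputs @ w)
print(np.allclose(with_bias, without_bias))  # True
```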

# Written for TensorFlow 1.x: tf.placeholder, tf.layers, tf.Session and the
# tensorflow.examples.tutorials.mnist helper are not available in TensorFlow 2.
import tensorflow as tf
import tqdm
import numpy as np
import seaborn as sns # for nice looking graphs
from matplotlib import pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

np.random.seed(42)

def create_layer(
        input_tensor,
        weight,
        name,
        activation,
        training=None):
    # Plain fully connected layer: activation(input_tensor @ w + b).
    # `training` is unused; it keeps the signature compatible with create_batch_norm_layer.
    w = tf.Variable(weight)
    b = tf.Variable(tf.zeros([weight.shape[-1]]))
    z = tf.add(tf.matmul(input_tensor, w), b, name='layer_input_%s' % name)
    if name == 'output':
        # The output layer also returns its logits for the loss.
        return z, activation(z, name='activation_%s' % name)
    else:
        return activation(z, name='activation_%s' % name)

def create_batch_norm_layer(
        input_tensor,
        weight,
        name,
        activation,
        training):
    # Fully connected layer with batch normalization before the activation.
    # No bias: its role is taken over by the beta offset of the batch norm layer.
    w = tf.Variable(weight)
    linear_output = tf.matmul(input_tensor, w)
    # training=True: normalize with batch statistics (moving averages are updated via UPDATE_OPS);
    # training=False: normalize with the moving averages (population estimates).
    batch_norm_z = tf.layers.batch_normalization(
        linear_output, training=training, name='bn_layer_input_%s' % name)
    if name == 'output':
        return batch_norm_z, activation(batch_norm_z, name='bn_activation_%s' % name)
    else:
        return activation(batch_norm_z, name='bn_activation_%s' % name)

def get_tensors(
        layer_creation_fn,
        inputs,
        labels,
        weights,
        activation,
        learning_rate,
        is_training):
    # Nine hidden layers followed by the output layer.
    l1 = layer_creation_fn(inputs, weights[0], '1', activation, training=is_training)
    l2 = layer_creation_fn(l1, weights[1], '2', activation, training=is_training)
    l3 = layer_creation_fn(l2, weights[2], '3', activation, training=is_training)
    l4 = layer_creation_fn(l3, weights[3], '4', activation, training=is_training)
    l5 = layer_creation_fn(l4, weights[4], '5', activation, training=is_training)
    l6 = layer_creation_fn(l5, weights[5], '6', activation, training=is_training)
    l7 = layer_creation_fn(l6, weights[6], '7', activation, training=is_training)
    l8 = layer_creation_fn(l7, weights[7], '8', activation, training=is_training)
    l9 = layer_creation_fn(l8, weights[8], '9', activation, training=is_training)
    logits, output = layer_creation_fn(
        l9, weights[9], 'output', tf.nn.sigmoid, training=is_training)

    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))

    correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Assumption: the optimizer class was cut off in the original listing;
    # plain gradient descent is used here.
    if layer_creation_fn.__name__ == 'create_batch_norm_layer':
        # Run the ops that update the batch norm moving averages before each training step.
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate).minimize(cross_entropy)
    else:
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate).minimize(cross_entropy)

    return accuracy, optimizer

def train_network(
        learning_rate_val,
        num_batches,
        batch_size,
        activation,
        bad_init=False,
        plot_accuracy=True):

    # Start from a fresh graph so that repeated calls don't clash over layer names.
    tf.reset_default_graph()

    inputs = tf.placeholder(tf.float32, shape=[None, 784], name='inputs')
    labels = tf.placeholder(tf.float32, shape=[None, 10], name='labels')
    learning_rate = tf.placeholder(tf.float32, name='learning_rate')
    is_training = tf.placeholder(tf.bool, name='is_training')

    np.random.seed(42)

    # bad_init=True draws the initial weights with standard deviation 1 instead of 0.1.
    scale = 1 if bad_init else 0.1

    # Both networks are built from these same initial weight values.
    weights = [
        np.random.normal(size=(784, 100), scale=scale).astype(np.float32),
        np.random.normal(size=(100, 100), scale=scale).astype(np.float32),
        np.random.normal(size=(100, 100), scale=scale).astype(np.float32),
        np.random.normal(size=(100, 100), scale=scale).astype(np.float32),
        np.random.normal(size=(100, 100), scale=scale).astype(np.float32),
        np.random.normal(size=(100, 100), scale=scale).astype(np.float32),
        np.random.normal(size=(100, 100), scale=scale).astype(np.float32),
        np.random.normal(size=(100, 100), scale=scale).astype(np.float32),
        np.random.normal(size=(100, 100), scale=scale).astype(np.float32),
        np.random.normal(size=(100, 10), scale=scale).astype(np.float32)]

    vanilla_accuracy, vanilla_optimizer = get_tensors(
        create_layer,
        inputs,
        labels,
        weights,
        activation,
        learning_rate,
        is_training)

    bn_accuracy, bn_optimizer = get_tensors(
        create_batch_norm_layer,
        inputs,
        labels,
        weights,
        activation,
        learning_rate,
        is_training)

    vanilla_accuracy_vals = []
    bn_accuracy_vals = []

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i in tqdm.tqdm(list(range(num_batches))):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)

            # One training step for each network on the same mini-batch.
            sess.run([vanilla_optimizer], feed_dict={
                inputs: batch_xs,
                labels: batch_ys,
                learning_rate: learning_rate_val,
                is_training: True})

            sess.run([bn_optimizer], feed_dict={
                inputs: batch_xs,
                labels: batch_ys,
                learning_rate: learning_rate_val,
                is_training: True})

            # Measure validation accuracy every batch_size iterations (inference mode).
            if i % batch_size == 0:
                vanilla_acc = sess.run(vanilla_accuracy, feed_dict={
                    inputs: mnist.validation.images,
                    labels: mnist.validation.labels,
                    is_training: False})

                bn_acc = sess.run(bn_accuracy, feed_dict={
                    inputs: mnist.validation.images,
                    labels: mnist.validation.labels,
                    is_training: False})

                vanilla_accuracy_vals.append(vanilla_acc)
                bn_accuracy_vals.append(bn_acc)

                print(
                    'Iteration: %s; ' % i,
                    'Vanilla Accuracy: %2.4f; ' % vanilla_acc,
                    'BN Accuracy: %2.4f' % bn_acc)

    if plot_accuracy:
        plt.title('Training Accuracy')
        plt.plot(range(0, len(vanilla_accuracy_vals) * batch_size, batch_size),
                 vanilla_accuracy_vals, label='Vanilla network')
        plt.plot(range(0, len(bn_accuracy_vals) * batch_size, batch_size),
                 bn_accuracy_vals, label='Batch Normalized network')
        plt.tight_layout()
        plt.legend()
        plt.grid(True)
        plt.show()


We'll run the train_network function with different hyperparameters to compare how the networks perform with and without batch normalization.

Let's start with a reasonable learning rate and the tanh non-linearity as our activation function.

train_network(0.01, 2000, 60, tf.nn.tanh)


Both networks perform reasonably well, but the batch normalized network reaches a high accuracy much faster than the vanilla network and keeps a slightly higher accuracy over the remaining batches. Both networks would probably reach a higher accuracy if trained longer.

Next, we train the networks as before, but this time we initialize their weights sub-optimally. In practice, small, zero-centered initial weights give much better results for neural networks. Here, we sample the initial weights from a normal distribution with a standard deviation of 1 (in the previous experiment, the standard deviation was 0.1).

train_network(0.01, 5000, 60, tf.nn.tanh, bad_init=True)


Even after 5000 iterations, both networks perform much worse than in the previous experiment, but the batch normalized network still performs much better than the vanilla network. Notice also the lower volatility of the batch normalized network's accuracy.

Next, we don't just initialize the networks with bad weights; we also pass in an extremely high learning rate of 1, a hundred times the one we used in the previous runs.

train_network(1, 5000, 60, tf.nn.tanh, bad_init=True)


As expected, our vanilla neural network doesn't even get off the ground. Its accuracy stays around 0.1 (10 percent), which, with 10 output classes, is no better than random guessing.

The batch normalized network, however, performs extremely well here. Although it starts from the same poor initial weights as the vanilla network, it copes with the high learning rate and converges to a high accuracy.

Let's try a different non-linearity next.

train_network(0.01, 2000, 60, tf.nn.relu)


With the ReLU non-linearity and a relatively small limit of 2000 batches, both networks end up with similar accuracies, but the batch normalized network gets there faster and with much less volatility.

train_network(0.01, 5000, 60, tf.nn.relu, bad_init=True)


If we use a bad weight initialization strategy with ReLU non-linearities, our vanilla network does not learn at all, even with a small learning rate. Batch normalization seems to be an effective way to counter a bad weight initialization strategy. Note how the accuracy of the vanilla ReLU network does not change at all. ReLU networks sometimes suffer from "dying" units, which output zero for every input and therefore stop receiving gradient during training; this is a likely explanation for what we see here.

Batch normalization doesn't just protect us from a poorly configured model with suboptimal hyperparameters; in the experiments above, it also helped a well-configured model reach a high accuracy faster.

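The mechanics of model inference are slightly different when using batch normalization. Instead of the mini-batch mean and variance, we use estimates of the population mean and variance, so that a prediction depends only on its own input. tf.layers.batch_normalization maintains these estimates as moving averages, updated by ops that it adds to the tf.GraphKeys.UPDATE_OPS collection. Those ops don't run on their own, which is why the training step is created inside the context: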

with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):

When we're done training and want to use our batch normalized model for predictions, we need to tell the computational graph that we're not training. We do this by feeding False for the is_training placeholder in the feed_dict, which makes tf.layers.batch_normalization normalize with the moving averages instead of the batch statistics.
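Outside TensorFlow, the same bookkeeping can be sketched in a few lines of NumPy. The sketch keeps exponential moving averages of the batch mean and variance during training and uses them at inference; `momentum=0.99` and `eps=1e-3` mirror the defaults of `tf.layers.batch_normalization`, while the class `SimpleBatchNorm` itself is hypothetical. (The original paper instead averages the statistics of many training mini-batches after training, using an unbiased variance estimate.)

```python
import numpy as np

class SimpleBatchNorm:
    """Minimal batch norm with running statistics for inference (illustrative only)."""

    def __init__(self, dim, momentum=0.99, eps=1e-3):
        self.gamma = np.ones(dim)
        self.beta = np.zeros(dim)
        self.running_mean = np.zeros(dim)
        self.running_var = np.ones(dim)
        self.momentum = momentum
        self.eps = eps

    def __call__(self, x, training):
        if training:
            mean, var = x.mean(axis=0), x.var(axis=0)
            # Update the population estimates (the job UPDATE_OPS does in TensorFlow).
            m = self.momentum
            self.running_mean = m * self.running_mean + (1 - m) * mean
            self.running_var = m * self.running_var + (1 - m) * var
        else:
            # Inference: use the population estimates instead of batch statistics.
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

rng = np.random.default_rng(3)
bn = SimpleBatchNorm(dim=2)
for _ in range(2000):  # "training": statistics come from each mini-batch of 60
    bn(rng.normal(loc=[4.0, -1.0], scale=[2.0, 0.5], size=(60, 2)), training=True)

print(bn.running_mean, bn.running_var)    # estimates of the population statistics
single = np.array([[4.0, -1.0]])          # a "batch" containing one example
print(bn(single, training=False))
```

Using population statistics also means a single example can be normalized meaningfully at inference; with batch statistics, a batch of one has zero variance and every input would map to $\beta$.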