
Batch Normalization

An implementation of Batch Normalization in TensorFlow. Batch Normalization is a strategy for addressing internal covariate shift: in a deep neural network, the distribution of each layer's inputs changes during training as the parameters of the preceding layers change. This slows training by requiring lower learning rates and careful parameter initialization, and it makes models with saturating nonlinearities notoriously hard to train.

Batch Normalization draws its strength from making normalization part of the model architecture and performing it for each training mini-batch. This allows much higher learning rates and less careful initialization. It also acts as a regularizer, in some cases eliminating the need for Dropout.
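To make the per-mini-batch normalization concrete, here is a minimal NumPy sketch of the forward pass. It is not the code from this repository: the function name `batch_norm_forward`, its parameters, and the `momentum` and `eps` defaults are assumptions chosen for illustration, and it uses NumPy rather than TensorFlow to stay self-contained. During training it normalizes each feature with the statistics of the current mini-batch and keeps moving averages; at inference time it uses those moving averages instead.

```python
import numpy as np


def batch_norm_forward(x, gamma, beta, running_mean, running_var,
                       training, momentum=0.9, eps=1e-5):
    """Batch-normalize a (batch, features) array.

    Hypothetical helper for illustration only. Returns the output and
    the (possibly updated) running mean and variance.
    """
    if training:
        # Per-feature statistics over the mini-batch axis.
        mean = x.mean(axis=0)
        var = x.var(axis=0)  # biased (population) variance
        # Exponential moving averages, used later at inference time.
        running_mean = momentum * running_mean + (1.0 - momentum) * mean
        running_var = momentum * running_var + (1.0 - momentum) * var
    else:
        # Inference: use the statistics accumulated during training.
        mean, var = running_mean, running_var

    # Normalize, then apply the learned scale (gamma) and shift (beta).
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta, running_mean, running_var


# Usage: a mini-batch of 64 examples with 4 features, far from zero mean.
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 4))
gamma, beta = np.ones(4), np.zeros(4)

y, rm, rv = batch_norm_forward(
    x, gamma, beta, np.zeros(4), np.ones(4), training=True)

# At inference time the same call reuses the running statistics.
y_eval, _, _ = batch_norm_forward(x, gamma, beta, rm, rv, training=False)
```

With `gamma` set to ones and `beta` to zeros, each column of `y` has approximately zero mean and unit variance regardless of the scale of `x`, which is what lets later layers see a stable input distribution.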

This repository is inspired by the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift and by tomokishii/mnist_cnn_bn.py. The code has been improved and modified to make it more readable and flexible.

For more details about Batch Normalization, see the Reference section, which lists the material I read while learning this technique.

Training Results:

 Training...
  step     0: validation loss = 22.7861, validation accuracy = 0.1146
  step   200: validation loss = 0.1514, validation accuracy = 0.9562
  step   400: validation loss = 0.1149, validation accuracy = 0.9684
  ...
  step  1600: validation loss = 0.0627, validation accuracy = 0.9800
  step  1800: validation loss = 0.0553, validation accuracy = 0.9836
  step  2000: validation loss = 0.0501, validation accuracy = 0.9852
  ...
  step  3800: validation loss = 0.0426, validation accuracy = 0.9880
  step  4000: validation loss = 0.0388, validation accuracy = 0.9894
  step  4200: validation loss = 0.0373, validation accuracy = 0.9900
  ...
  step  5000: validation loss = 0.0342, validation accuracy = 0.9898

 Testing...
  test loss = 0.0295, test accuracy = 0.9894, multiclass log loss = 0.0295

Reference