How to compute the Fisher of a conditional when applying natural gradient to neural networks?
October 29, 2018
This short note aims to explain how we come up with an expression for the Fisher Information Matrix in the context of the conditional distributions represented by neural networks.
In neural networks, the so-called natural gradient is a preconditioner for the gradient descent algorithm, where the update is regularized so that each update of the values of the parameters is measured using the KL divergence. This has some interesting properties, such as making the update invariant to reparametrization of our neural network: more explanation to come in another blog post. The update is given by:

$$\theta \leftarrow \theta - \eta F^{-1} \mathbb{E}\left[\nabla_\theta \ell\left(y, f_\theta(x)\right)\right]$$

where:
- the expectation $\mathbb{E}$ is taken using (discrete) samples $(x, y)$ of the training set;
- $f_\theta$ is our neural network with the input $x$ (e.g. the pixels of an image), and the output $f_\theta(x)$ (e.g. the 10 coefficients of the softmax for MNIST, where we have 10 classes = 10 digits);
- we use the negative log likelihood $\ell\left(y, f_\theta(x)\right) = -\log p_\theta(y|x)$ as our loss function, and so $\nabla_\theta \ell$ is the gradient of our loss with respect to the parameters $\theta$;
- $F$ is the Fisher Information Matrix (FIM), defined as:

$$F = \mathbb{E}_{x \sim p_\theta(x)}\left[\nabla_\theta \log p_\theta(x) \, \nabla_\theta \log p_\theta(x)^\top\right]$$
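As a quick sanity check of this definition (a toy sketch of my own, not from the derivation above), we can estimate the FIM of a one-parameter Bernoulli distribution by Monte Carlo and compare it to the closed form $F = 1/(\theta(1-\theta))$:

```python
import numpy as np

# Toy check (not from the post): Monte Carlo estimate of the FIM of a
# Bernoulli(theta) density. The score is
#   d/dtheta log p_theta(x) = x/theta - (1 - x)/(1 - theta)
# and the exact FIM is 1 / (theta * (1 - theta)).
rng = np.random.default_rng(0)
theta = 0.3
x = rng.binomial(1, theta, size=200_000).astype(float)

score = x / theta - (1.0 - x) / (1.0 - theta)
fim_mc = np.mean(score ** 2)            # empirical E[score^2]
fim_exact = 1.0 / (theta * (1.0 - theta))

print(fim_mc, fim_exact)
```

The empirical average of the squared score converges to the exact Fisher information as the number of samples grows.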
The link between the KL divergence and the FIM resides in the fact that the FIM is the second order term of the Taylor series expansion of the KL divergence. For a distribution $p_\theta$ it is given by:

$$\mathrm{KL}\left(p_\theta \,\|\, p_{\theta + \delta\theta}\right) = \frac{1}{2} \delta\theta^\top F \, \delta\theta + o\left(\|\delta\theta\|^2\right)$$

where $o\left(\|\delta\theta\|^2\right)$ is negligible compared to $\|\delta\theta\|^2$ when $\delta\theta$ is small, and the first order term is $0$.
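This second-order relation can be checked numerically. Here is a small sketch of mine (again using a Bernoulli model as an assumed example), comparing the exact KL divergence to the quadratic approximation $\frac{1}{2} F \, \delta\theta^2$:

```python
import numpy as np

# Sketch (not from the post): verify KL(p_theta || p_{theta + d}) ~ 0.5 * F * d^2
# for a Bernoulli distribution, whose FIM is F = 1 / (theta * (1 - theta)).
theta, d = 0.3, 1e-2

def kl_bernoulli(p, q):
    # Exact KL divergence between Bernoulli(p) and Bernoulli(q).
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

fim = 1.0 / (theta * (1.0 - theta))
kl_exact = kl_bernoulli(theta, theta + d)
kl_quadratic = 0.5 * fim * d ** 2

print(kl_exact, kl_quadratic)
```

For small $\delta\theta$ the two values agree to within about a percent; shrinking `d` further makes the agreement tighter, as the $o\left(\|\delta\theta\|^2\right)$ remainder vanishes faster than the quadratic term.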
This is the general definition for $F$, using a density $p_\theta(x)$. But when applying this technique to train neural networks, we model the conditional $p_\theta(y|x)$. So how do we apply this to neural network training, i.e. to the conditional $p_\theta(y|x)$?
Here is my explanation.
Instead of just considering $p_\theta(y|x)$ we will use the joint probability $p_\theta(x, y) = q(x) \, p_\theta(y|x)$. We have introduced $q(x)$, which is the distribution over the inputs. If the task is image classification, this is the distribution of the natural images $x$. Usually we do not have access to $q(x)$ explicitly, but instead we have samples from it, which are the images in our training set.
By replacing $p_\theta(x)$ with $p_\theta(x, y)$ in the formula above, we can consider $p_\theta(x, y)$ and write the FIM for this joint distribution:

$$F = \mathbb{E}_{(x, y) \sim p_\theta(x, y)}\left[\nabla_\theta \log p_\theta(x, y) \, \nabla_\theta \log p_\theta(x, y)^\top\right]$$
Next we replace the joint with the product of the marginal over $x$ and the conditional, $p_\theta(x, y) = q(x) \, p_\theta(y|x)$, in the derivative:

$$\nabla_\theta \log p_\theta(x, y) = \nabla_\theta \log q(x) + \nabla_\theta \log p_\theta(y|x)$$
and since $q(x)$ does not depend on $\theta$ then $\nabla_\theta \log q(x) = 0$. This simplifies into:

$$\nabla_\theta \log p_\theta(x, y) = \nabla_\theta \log p_\theta(y|x)$$
Equivalently for the expectation, we can take the expectation over the joint in 2 steps:
- sample an input $x$ from our training distribution $q(x)$;
- for this value of $x$, compute $p_\theta(y|x)$, then sample multiple points $y \sim p_\theta(y|x)$ to estimate the expectation over $y$. Here we also require multiple backprops to compute the gradients $\nabla_\theta \log p_\theta(y|x)$ for each sample $y$.
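The 2 steps above can be sketched on a toy model. This is my own minimal setup (a linear softmax classifier with made-up sizes, where the gradient of the log-likelihood has the closed form $(\mathrm{onehot}(y) - p)\, x^\top$, so no autodiff framework is needed):

```python
import numpy as np

# Sketch of the two-step Monte Carlo estimate of the conditional FIM
# (toy setup, not from the post): a linear softmax model
#   p_theta(y|x) = softmax(W x),  grad_W log p_theta(y|x) = (onehot(y) - p) x^T.
rng = np.random.default_rng(0)
n_classes, n_features = 3, 4
W = rng.normal(size=(n_classes, n_features))
xs = rng.normal(size=(50, n_features))     # stand-in for training inputs
n_mc = 30                                  # samples of y per input x

dim = n_classes * n_features
F = np.zeros((dim, dim))
for x in xs:                               # step 1: x from the training set
    z = W @ x
    p = np.exp(z - z.max())
    p /= p.sum()                           # p_theta(y|x) for this x
    for _ in range(n_mc):                  # step 2: y ~ p_theta(y|x)
        y = rng.choice(n_classes, p=p)
        e_y = np.eye(n_classes)[y]
        g = np.outer(e_y - p, x).ravel()   # grad of log p_theta(y|x) w.r.t. W
        F += np.outer(g, g)
F /= len(xs) * n_mc

print(F.shape)
```

In a real network each sampled $y$ would require one backprop to get $g$; here the gradient is analytic, which keeps the sketch self-contained. The resulting estimate is symmetric positive semi-definite by construction.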
Finally we get the desired formula:

$$F = \mathbb{E}_{x \sim q(x)}\left[\mathbb{E}_{y \sim p_\theta(y|x)}\left[\nabla_\theta \log p_\theta(y|x) \, \nabla_\theta \log p_\theta(y|x)^\top\right]\right]$$
And so we get the FIM for a conditional distribution.
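With this FIM in hand, a natural gradient step preconditions the ordinary gradient by $F^{-1}$. A minimal sketch with made-up numbers (the damping term is my own assumption, commonly added because a Monte Carlo estimate of $F$ can be singular):

```python
import numpy as np

# Minimal sketch (toy numbers): one natural gradient update
#   theta <- theta - lr * F^{-1} grad,
# with Tikhonov damping since an estimated F may be singular.
rng = np.random.default_rng(0)
dim = 5
A = rng.normal(size=(dim, dim))
F = A @ A.T / dim                  # stand-in for an estimated FIM (PSD)
grad = rng.normal(size=dim)        # stand-in for the loss gradient
theta = np.zeros(dim)

lr, damping = 0.1, 1e-3
nat_grad = np.linalg.solve(F + damping * np.eye(dim), grad)
theta = theta - lr * nat_grad

print(theta)
```

Solving the linear system instead of explicitly inverting $F$ is both cheaper and numerically safer; for large networks, practical methods approximate this solve (e.g. with a structured factorization of $F$).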