Formulations of Neural Net Weight Initializations
This write-up is a repost of my personal notes on weight initialization and how it impacts the outputs of an MLP. I keep this as a reference for myself; I hope it can also be of help to others :)
Initializing Weights of the MLP
The initialization of neural network weights using a standard deviation of $\frac{1}{\sqrt{n}}$ (where $n$ is the number of input neurons, also known as fan-in) is a strategy designed to maintain the variance of each neuron's output at initialization. Let's delve into a mathematical explanation and derivation of why this specific value is chosen, particularly in the context of keeping the variance of the outputs stable.
Background
When we initialize the weights of a neural network, we want to ensure that the signal (i.e., the output of each neuron before applying the activation function) does not vanish or explode as it propagates through the network. This stability helps in maintaining effective gradient propagation during training.
Assumptions
- The weights $w_i$ are initialized independently from a normal distribution with mean $0$ and standard deviation $\sigma$.
- Each neuron receives inputs $x_1, x_2, \ldots, x_n$, which are also assumed to be independent, with mean $0$ and some constant variance (say, $\text{Var}(x_i) = 1$ for simplicity).
Output Variance Calculation
Consider a neuron's output before applying the activation function, calculated as:

$$z = \sum_{i=1}^{n} w_i x_i$$

where $w_i$ are the weights and $x_i$ are the inputs.
Step 1: Calculate the Variance of $z$
Since the weights and inputs are independent and both have zero mean, the variance of each product $w_i x_i$ is simply the product of their variances (due to independence and zero means):

$$\text{Var}(w_i x_i) = \text{Var}(w_i)\,\text{Var}(x_i) = \sigma^2 \cdot 1 = \sigma^2$$

(as we noted, for simplicity each $x_i$ has $\text{Var}(x_i) = 1$).
Since $z$ is the sum of $n$ such independent terms $w_i x_i$, the variance of $z$ is the sum of their variances:

$$\text{Var}(z) = \sum_{i=1}^{n} \text{Var}(w_i x_i) = n \sigma^2$$
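As a quick numerical sanity check of $\text{Var}(z) = n \sigma^2$, here is a small NumPy sketch (the fan-in, weight standard deviation, and trial count are arbitrary choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma, trials = 256, 1.0 / 16.0, 20_000  # fan-in, weight std, number of samples of z

w = rng.normal(0.0, sigma, size=(trials, n))  # weights ~ N(0, sigma^2)
x = rng.normal(0.0, 1.0, size=(trials, n))    # inputs  ~ N(0, 1)
z = (w * x).sum(axis=1)                       # z = sum_i w_i x_i, one z per trial

# Empirical variance of z should be close to n * sigma^2 = 256 / 256 = 1.0
print(z.var())
```

With these values $n \sigma^2 = 1$, so the printed estimate should land very close to 1.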
Step 2: Desired Variance of $z$
To maintain the variance of the output $z$ similar to the variance of the input across layers, we would like $\text{Var}(z) = \text{Var}(x_i) = 1$. This condition helps prevent vanishing or exploding gradients during training.
Setting $\text{Var}(z) = 1$:

$$n \sigma^2 = 1 \quad \Rightarrow \quad \sigma^2 = \frac{1}{n}$$

Therefore, the standard deviation should be:

$$\sigma = \frac{1}{\sqrt{n}}$$
Conclusion
This derivation shows that setting the standard deviation of the weight initialization to $\frac{1}{\sqrt{n}}$ ensures that the output of each neuron has a variance of 1, assuming the inputs also have a variance of 1. This balance is crucial for maintaining effective learning, as it prevents the scale of the neuron outputs from increasing or decreasing dramatically across layers, which can lead to numerical instability or poor convergence. This is why the factor $\frac{1}{\sqrt{n}}$ is commonly used in weight initialization methods like Xavier/Glorot initialization (which adjusts the variance further based on both the number of inputs and outputs).
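To see this stabilization empirically, here is a small NumPy simulation (the layer width, batch size, and depth are arbitrary choices; activations are omitted, so this checks only the linear pre-activation variance the derivation covers):

```python
import numpy as np

rng = np.random.default_rng(0)
n, batch, depth = 512, 4096, 3  # width, batch size, number of linear layers

x = rng.normal(0.0, 1.0, size=(batch, n))  # inputs with variance 1
for layer in range(depth):
    # Weights ~ N(0, (1/sqrt(n))^2), i.e. std = 1/sqrt(fan_in)
    W = rng.normal(0.0, 1.0 / np.sqrt(n), size=(n, n))
    x = x @ W  # pre-activations of the next layer
    print(f"layer {layer}: variance {x.var():.3f}")  # stays close to 1
```

With a naive std of 1 instead of $1/\sqrt{n}$, the printed variance would instead grow by a factor of roughly $n$ per layer.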
Proving the identities of linear transformations under Random Variables
Let's go through the mathematical proof of how a linear transformation of a random variable affects its mean and variance. The transformation we are considering is $Y = aX + b$, where $X$ is a random variable with mean $\mu_X$ and variance $\sigma_X^2$, and $a$ and $b$ are constants.
1. Expectation (Mean)
The expectation operator has the property of linearity, which means that for any constants $a$ and $b$:

$$E[aX + b] = aE[X] + b$$
Proof for Mean
Given that $X$ has a mean of $\mu_X$, the mean of $Y = aX + b$ is calculated as follows:

$$E[Y] = E[aX + b] = aE[X] + b = a\mu_X + b$$

Thus, the mean of $Y$ is $a\mu_X + b$.
2. Variance
Variance, denoted as $\text{Var}(Y)$, measures the spread of a random variable around its mean. The variance of a transformed random variable is defined as:

$$\text{Var}(Y) = E\left[(Y - E[Y])^2\right]$$
Proof for Variance
Substituting $Y = aX + b$ and $E[Y] = a\mu_X + b$ into the variance formula:

$$\text{Var}(Y) = E\left[(aX + b - a\mu_X - b)^2\right] = E\left[a^2 (X - \mu_X)^2\right] = a^2 E\left[(X - \mu_X)^2\right]$$

Since $E\left[(X - \mu_X)^2\right]$ is the definition of $\text{Var}(X)$, or $\sigma_X^2$:

$$\text{Var}(Y) = a^2 \sigma_X^2$$
Key Insight
The addition of a constant $b$ shifts the mean but does not affect the spread or variability of the distribution, hence it does not influence the variance. The multiplication by $a$, however, scales the variance of the distribution by $a^2$ (and the standard deviation by $|a|$).
Summary
This proof shows that the mean and variance of a linear transformation $Y = aX + b$ of a random variable are $a\mu_X + b$ and $a^2 \sigma_X^2$, respectively. These properties are foundational in probability and statistics and are extensively utilized across fields like data science, economics, and engineering to understand and predict the behavior of complex systems based on simpler underlying distributions.
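These two identities are easy to verify empirically. A minimal NumPy sketch (the constants $a$, $b$, $\mu_X$, $\sigma_X$ are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 3.0, -2.0          # transformation Y = aX + b
mu_X, sigma_X = 1.5, 0.5  # mean and std of X

X = rng.normal(mu_X, sigma_X, size=1_000_000)
Y = a * X + b

print(Y.mean())  # close to a * mu_X + b = 2.5
print(Y.var())   # close to a**2 * sigma_X**2 = 2.25
```

Note that $b$ shows up only in the mean, while the variance is scaled purely by $a^2$, matching the proof above.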
Putting It Together
Applying the Transformation
Given $W = aX$, where $X$ is a standard normal random variable, we can choose $a$ to be our target standard deviation $\frac{1}{\sqrt{n}}$.
We know now that:

$$\text{Var}(aX) = a^2\,\text{Var}(X)$$

and by taking the square root, we get:

$$\text{std}(aX) = |a|\,\text{std}(X).$$

We substitute $a = \frac{1}{\sqrt{n}}$ and our origin standard deviation $\text{std}(X) = 1$ to get:

$$\text{std}(W) = \frac{1}{\sqrt{n}} \cdot 1 = \frac{1}{\sqrt{n}}$$
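Putting the scaling into code, here is a minimal sketch of this initialization in NumPy (the layer sizes and the helper name `init_weights` are illustrative, not from the original notes):

```python
import numpy as np

def init_weights(fan_in, fan_out, rng):
    """W = a * X, with X standard normal and a = 1 / sqrt(fan_in)."""
    return rng.standard_normal((fan_out, fan_in)) / np.sqrt(fan_in)

rng = np.random.default_rng(0)
W = init_weights(784, 256, rng)  # e.g. an MNIST-sized input layer
print(W.std())  # close to 1 / sqrt(784) ~= 0.0357
```

Dividing a standard normal draw by $\sqrt{n}$ is exactly the linear transformation $W = aX$ with $a = \frac{1}{\sqrt{n}}$ and $b = 0$ derived above.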
This concludes the formulation of MLP weight initialization. I hope this provides a good reference for thinking about how weight initialization impacts the variance of a network's outputs.