Cindered Thoughts

Formulations of Neural Net Weight Initializations

This write-up is a repost of my personal notes on weight initialization and how it impacts the outputs of an MLP. I use it as a reference for myself, and I hope it can also be of help to others :)

Initializing Weights of the MLP

Initializing neural network weights with a standard deviation of $1/\sqrt{n}$ (equivalently, a variance of $1/n$, where $n$ is the number of input neurons, also known as fan-in) is a strategy designed to maintain the variance of each neuron's output at initialization. Let's delve into a mathematical derivation of why this specific value is chosen, particularly in the context of keeping the variance of the outputs stable.

Background

When we initialize the weights of a neural network, we want to ensure that the signal (i.e., the output of each neuron before applying the activation function) does not vanish or explode as it propagates through the network. This stability helps in maintaining effective gradient propagation during training.
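To see the problem concretely, here is a minimal NumPy sketch (the width of 512 and the depth of 5 are arbitrary choices for illustration): with weights drawn at standard deviation 1, each layer multiplies the signal's scale by roughly $\sqrt{n}$, so it explodes within a few layers.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 512                          # neurons per layer (fan-in); arbitrary
x = rng.standard_normal(n)       # unit-variance input signal

# Naive init: weights with std 1. Each matmul multiplies the
# signal's std by roughly sqrt(n), so the signal explodes.
for layer in range(5):
    W = rng.standard_normal((n, n))
    x = W @ x
    print(f"layer {layer}: std of signal = {x.std():.3g}")
```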

Assumptions

For this derivation we make a few simplifying assumptions:

- The weights $w_i$ are drawn i.i.d. with zero mean and variance $\sigma^2$.
- The inputs $x_i$ have zero mean and unit variance, $\mathrm{Var}(x_i) = 1$.
- The weights and inputs are mutually independent.

Output Variance Calculation

Consider a neuron's output $z$ before applying the activation function, calculated as: $z = \sum_{i=1}^{n} w_i x_i$, where $w_i$ are the weights and $x_i$ are the inputs.

Step 1: Calculate the Variance of z

Since the weights and inputs are independent and both have zero mean, the variance of the product $w_i x_i$ for each $i$ is simply the product of their variances: $\mathrm{Var}(w_i x_i) = \mathrm{Var}(w_i) \cdot \mathrm{Var}(x_i) = \sigma^2 \cdot 1 = \sigma^2$ (recall that for simplicity $x_i$ has $\mathrm{Var}(x_i) = 1$).

Since $z$ is the sum of $n$ such independent terms $w_i x_i$, the variance of $z$ is the sum of their variances: $\mathrm{Var}(z) = \mathrm{Var}\left(\sum_{i=1}^{n} w_i x_i\right) = \sum_{i=1}^{n} \mathrm{Var}(w_i x_i) = n\sigma^2$
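We can verify this empirically with a quick NumPy sketch (the values of $n$, $\sigma$, and the trial count are arbitrary): drawing many independent $(w, x)$ pairs, the sample variance of $z$ lands close to $n\sigma^2$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, trials, sigma = 256, 20_000, 0.1   # arbitrary illustrative values

# One z per trial: z = sum_i w_i x_i with w_i ~ N(0, sigma^2), x_i ~ N(0, 1)
w = rng.normal(0.0, sigma, size=(trials, n))
x = rng.standard_normal((trials, n))
z = (w * x).sum(axis=1)

print("empirical Var(z):", z.var())
print("predicted n*sigma^2:", n * sigma**2)
```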

Step 2: Desired Variance of $z$

To keep the variance of the output $z$ similar to the variance of the input across layers, we would like $\mathrm{Var}(z) = 1$. This condition helps prevent vanishing or exploding gradients during training.

Setting $\mathrm{Var}(z) = 1$: $n\sigma^2 = 1 \implies \sigma^2 = \frac{1}{n}$

Therefore, the standard deviation $\sigma$ should be: $\sigma = \frac{1}{\sqrt{n}}$

Conclusion

This derivation shows that setting the standard deviation of the weight initialization to $1/\sqrt{n}$ (i.e., a variance of $1/n$) ensures that the output of each neuron has a variance of 1, assuming the inputs also have a variance of 1. This balance is crucial for effective learning, as it prevents the scale of the neuron outputs from growing or shrinking dramatically across layers, which can lead to numerical instability or poor convergence. This is why the $1/n$ variance factor is commonly used in weight initialization methods like Xavier/Glorot initialization (which further adjusts the variance based on both the number of inputs and outputs).
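A quick simulation of this claim (a sketch; the width and depth are arbitrary, and no nonlinearity is applied): with weights scaled by $1/\sqrt{n}$, the signal's standard deviation stays on the order of 1 even after many layers.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 512                          # width; arbitrary
x = rng.standard_normal(n)       # unit-variance input

# Weights scaled by 1/sqrt(n): the pre-activation variance stays
# near 1 layer after layer (no nonlinearity applied here).
for _ in range(20):
    W = rng.standard_normal((n, n)) / np.sqrt(n)
    x = W @ x

print("std after 20 layers:", x.std())   # stays on the order of 1
```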


Proving the identities of linear transformations under Random Variables

Let's go through the mathematical proof of how a linear transformation of a random variable affects its mean and variance. The transformation we are considering is $Y = aX + b$, where $X$ is a random variable with mean $\mu_X$ and variance $\sigma_X^2$, and $a$ and $b$ are constants.

1. Expectation (Mean)

The expectation operator $E$ is linear, which means that for any constants $a$ and $b$, and a random variable $X$: $E[aX + b] = aE[X] + b$

Proof for Mean

Given that $X$ has a mean of $\mu_X$, the mean of $Y$ is calculated as follows: $E[Y] = E[aX + b] = aE[X] + b = a\mu_X + b$

Thus, the mean of $Y$ is $a\mu_X + b$.
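A quick Monte Carlo sanity check of this identity (the constants $a$, $b$, $\mu_X$ and the sample count are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, mu_X = 3.0, -2.0, 5.0      # arbitrary constants

# Transform a large sample of X and compare the empirical mean
# of Y = aX + b against the predicted a*mu_X + b.
X = rng.normal(mu_X, 1.0, size=1_000_000)
Y = a * X + b

print("empirical E[Y]:", Y.mean())
print("predicted a*mu_X + b:", a * mu_X + b)   # 13.0
```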

2. Variance

Variance, denoted $\mathrm{Var}$, measures the spread of a random variable around its mean. The variance of the transformed random variable $Y = aX + b$ is defined as: $\mathrm{Var}(Y) = E[(Y - E[Y])^2]$

Proof for Variance

Substituting $Y = aX + b$ and $E[Y] = a\mu_X + b$ into the variance formula: $\mathrm{Var}(Y) = E[(aX + b - (a\mu_X + b))^2] = E[(aX - a\mu_X)^2] = E[a^2(X - \mu_X)^2] = a^2 E[(X - \mu_X)^2]$

Since $E[(X - \mu_X)^2]$ is the definition of $\mathrm{Var}(X)$, or $\sigma_X^2$: $\mathrm{Var}(Y) = a^2 \sigma_X^2$
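A matching Monte Carlo check for the variance identity (the constants $a$, $b$, $\sigma_X$ and the sample count are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, sigma_X = 3.0, -2.0, 2.0   # arbitrary constants

# The shift b should not change the variance; the scale a should
# multiply it by a^2.
X = rng.normal(0.0, sigma_X, size=1_000_000)
Y = a * X + b

print("empirical Var(Y):", Y.var())
print("predicted a^2 * sigma_X^2:", a**2 * sigma_X**2)   # 36.0
```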

Key Insight

The addition of a constant $b$ shifts the mean but does not affect the spread of the distribution, hence it does not influence the variance. Multiplication by $a$, however, scales the standard deviation by $|a|$ and therefore the variance by $a^2$.

Summary

This proof shows that the mean and variance of a linear transformation of a random variable, $Y = aX + b$, are $a\mu_X + b$ and $a^2\sigma_X^2$ respectively. These properties are foundational in probability and statistics and are used extensively across fields like data science, economics, and engineering to understand and predict the behavior of complex systems based on simpler underlying distributions.


Putting It Together

Applying the Transformation

Given $Y = aX + b$, we substitute $a$ with our target standard deviation $\sigma_T$ and set $b = 0$:

$Y = \sigma_T X$

From the identities above: $\mathrm{Var}(Y) = \sigma_Y^2 = \sigma_T^2 \sigma_X^2$, and taking the square root: $\sigma_Y = \sqrt{\mathrm{Var}(Y)} = \sqrt{\sigma_T^2 \sigma_X^2} = \sigma_T \sigma_X$.

Substituting our target $\sigma_T = 1/\sqrt{n}$ and the origin standard deviation $\sigma_X = 1$: $\sigma_Y = \frac{1}{\sqrt{n}} \cdot 1 = \frac{1}{\sqrt{n}}$

In other words, drawing $X$ from a standard normal and scaling it by $1/\sqrt{n}$ yields weights with exactly the standard deviation the derivation calls for.
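As a final sanity check (a minimal sketch; the fan-in of 1024 is an arbitrary choice), scaling a standard-normal draw by $\sigma_T = 1/\sqrt{n}$ produces weights with the derived standard deviation; since $1/\sqrt{1024} = 1/32$, the target here is exactly 0.03125:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1024                              # fan-in; arbitrary

# Y = sigma_T * X: scale a standard-normal draw by the target std.
sigma_T = 1.0 / np.sqrt(n)            # 1/sqrt(1024) = 0.03125
W = sigma_T * rng.standard_normal((n, n))

print("empirical std:", W.std())
print("target sigma_T:", sigma_T)     # 0.03125
```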

This concludes the formulation of MLP weight initialization. Hope this provides a good reference for folks to think about how weight initialization impacts the variance of a network's outputs.