
Understanding Vanilla RNN

By Aryan Raut, 1/22/2026

Machine Learning. Deep Learning. Representation Learning.

Three distinct concepts, three different mathematical toolkits, yet one phenomenon binds them: the phenomenon of LEARNING.

What Does Learning Actually Mean?

Learning refers to the process of finding the parameters (commonly weights and biases) that give optimal results, i.e., that minimize the model's loss function.

Loss measures the difference between the value predicted by the model and the actual (target) value.
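To make this definition concrete, here is a toy sketch (not from the article) in which "learning" means repeatedly adjusting the single weight $w$ of the model $\hat{y} = w x$ so that the loss goes down. Squared error is used purely as an illustrative loss (the article later uses cross-entropy), and the function names and data are hypothetical.

```python
# Toy illustration: learning = finding the parameter that minimizes the loss.
# Model: y_hat = w * x, loss: mean squared error (an illustrative choice).

def mse_loss(w, xs, ys):
    """Average squared difference between predictions w*x and targets y."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def mse_grad(w, xs, ys):
    """Derivative of mse_loss with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def fit(xs, ys, lr=0.05, steps=200):
    w = 0.0  # initial guess for the parameter
    for _ in range(steps):
        w -= lr * mse_grad(w, xs, ys)  # move against the gradient of the loss
    return w

xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # generated by the "true" weight w = 2
w = fit(xs, ys)
print(round(w, 4))  # → 2.0
```

The loss at the learned $w$ is essentially zero, whereas at the initial guess $w = 0$ it is large; the "learning" is nothing more than this search.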

The Foundation: Data

The basis of learning, its raw material, is the data used to train a model. This data can come in quite different categories:

  • Spatial data
  • Temporal data
  • Independent numerical data

Specialized Neural Network Architectures

Different model architectures have been designed to suit different categories of data. Standard neural networks, or ANNs (Artificial Neural Networks), work well with independent data, whether numerical or categorical.

However, these ANNs are poorly suited to specialized categories of data such as:

  • Spatial Data: Data that contains information about the location, shape, and spatial relationship of objects in physical space

  • Temporal Data: Data whose observations have temporal dependencies between them; in other words, data indexed by time, where the sequence or timing of observations is essential to interpretation and prediction

Architecture Selection by Data Type

For these specialized categories of inputs, we use specialized classes of neural network architecture:

CNN (Convolutional Neural Networks)

  • Data Type: Spatial data
  • Common Application: Image-based tasks

RNN (Recurrent Neural Networks)

  • Data Type: Temporal data
  • Common Application: Text-based tasks

Note: Images and text are only the most popular applications of CNNs and RNNs; both architectures have much wider applications.


We'll talk about RNNs in this article.

Vanilla RNN Architecture and Backpropagation Through Time

A vanilla RNN is a unidirectional RNN: it captures sequential information in one direction only, either left to right or right to left. Let us understand the architecture with the help of the diagrams below.

[Figure: RNN Unit]

[Figure: RNN Architecture]

From the above diagram, we reiterate that each unit receives two inputs: a static one, the current input $x^{<t>}$, and a temporal one, the hidden state $a^{<t-1>}$ passed on from the previous unit. In this architecture, we are working on a training example where the entire sequence is already available, so there is no need to generate the next datapoint in the sequence. As a result, $x^{<t>}$ does not depend on the previous unit's output $y^{<t-1>}$, unlike during sampling, where the output of the previous unit $(t-1)$ is used as the input for the next unit $(t)$.

Mathematical Interpretation

A. Forward Propagation

Assuming the number of inputs $T_x$ equals the number of outputs $T_y$, we compute two components at every timestep:

  1. Hidden state value: $a^{<t>}$
  2. Output unit: $\hat{y}^{<t>}$

The hidden state is governed by the parameters $W_{aa}$, $W_{ax}$ and $b_a$, and the output $\hat{y}$ by $W_{ya}$ and $b_y$.

For every timestep $t$:

$$a^{<t>} = g[W_{aa} \cdot a^{<t-1>} + W_{ax} \cdot x^{<t>} + b_a] \tag{1}$$

$$\hat{y}^{<t>} = g[W_{ya} \cdot a^{<t>} + b_y] \tag{2}$$

Here $g$ denotes an activation function; the two equations usually use different ones (commonly $\tanh$ for the hidden state and sigmoid or softmax for the output).
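As a rough illustration of the two forward-propagation equations, here is a minimal NumPy sketch that runs a vanilla RNN over a short sequence. It assumes $\tanh$ for the hidden-state activation, sigmoid for the output, column-vector inputs and a zero initial hidden state $a^{<0>}$; the function name `rnn_forward` and all sizes are illustrative choices, not part of the article.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def rnn_forward(xs, a0, W_aa, W_ax, b_a, W_ya, b_y):
    """Run a vanilla RNN over a sequence (hypothetical helper).

    xs: list of input column vectors x<1>..x<T>, each of shape (n_x, 1)
    a0: initial hidden state a<0>, shape (n_a, 1)
    Returns the hidden states a<1>..a<T> and the outputs y_hat<1>..y_hat<T>.
    """
    a_prev = a0
    hidden, outputs = [], []
    for x_t in xs:
        # Hidden-state equation: combine the previous state and the current input
        a_t = np.tanh(W_aa @ a_prev + W_ax @ x_t + b_a)
        # Output equation: sigmoid assumed here for a binary target
        y_hat_t = sigmoid(W_ya @ a_t + b_y)
        hidden.append(a_t)
        outputs.append(y_hat_t)
        a_prev = a_t  # carry the hidden state to the next timestep
    return hidden, outputs

rng = np.random.default_rng(0)
n_x, n_a, n_y, T = 4, 5, 1, 3  # small illustrative sizes
params = dict(
    W_aa=rng.normal(scale=0.5, size=(n_a, n_a)),
    W_ax=rng.normal(scale=0.5, size=(n_a, n_x)),
    b_a=np.zeros((n_a, 1)),
    W_ya=rng.normal(scale=0.5, size=(n_y, n_a)),
    b_y=np.zeros((n_y, 1)),
)
xs = [rng.normal(size=(n_x, 1)) for _ in range(T)]
hidden, outputs = rnn_forward(xs, np.zeros((n_a, 1)), **params)
print(len(hidden), hidden[0].shape, outputs[-1].shape)  # 3 (5, 1) (1, 1)
```

The same three weight matrices are reused at every timestep; only the hidden state changes as the loop advances.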

We can further simplify equation (1) for $a^{<t>}$ by compressing the two parameters $W_{aa}$ and $W_{ax}$ into a single matrix $W_a$:

$$W_a = [\,W_{aa} : W_{ax}\,]$$

where the matrices $W_{aa}$ and $W_{ax}$ are stacked horizontally.

If $\dim(W_{aa}) = (100, 100)$ and $\dim(W_{ax}) = (100, 10000)$, then

$$\dim(W_a) = (100, 100 + 10000) = (100, 10100)$$

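Using $W_a$ in the expression for $a^{<t>}$:

$$a^{<t>} = g\left[W_a \, [a^{<t-1>}, x^{<t>}] + b_a\right]$$

Now, denote the stacked vector $[a^{<t-1>}, x^{<t>}]$ by $p$:

$$p = \begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}$$

where $a^{<t-1>}$ and $x^{<t>}$ are stacked vertically, so $p$ has $100 + 10000 = 10100$ rows, matching the number of columns of $W_a$. Multiplying out the blocks gives $W_a \cdot p = W_{aa} \cdot a^{<t-1>} + W_{ax} \cdot x^{<t>}$, so only the notation changes, not the computation.

Finally:

$$a^{<t>} = g[W_a \cdot p + b_a]$$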
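A quick NumPy check of this simplification, using the dimensions from the example above. It assumes $g = \tanh$ and random values for all parameters; both are illustrative choices, not from the article.

```python
import numpy as np

rng = np.random.default_rng(1)
n_a, n_x = 100, 10000  # dimensions from the example above

# Small random parameters and inputs (illustrative values only)
W_aa = rng.normal(scale=0.01, size=(n_a, n_a))
W_ax = rng.normal(scale=0.01, size=(n_a, n_x))
b_a = rng.normal(scale=0.01, size=(n_a, 1))
a_prev = rng.normal(size=(n_a, 1))
x_t = rng.normal(size=(n_x, 1))

# Stack the weight matrices side by side: W_a = [W_aa : W_ax]
W_a = np.hstack([W_aa, W_ax])
# Stack the previous hidden state on top of the input: p = [a<t-1> ; x<t>]
p = np.vstack([a_prev, x_t])

separate = np.tanh(W_aa @ a_prev + W_ax @ x_t + b_a)
stacked = np.tanh(W_a @ p + b_a)

print(W_a.shape, p.shape)              # (100, 10100) (10100, 1)
print(np.allclose(separate, stacked))  # True
```

In practice the stacked form lets an implementation do one matrix multiplication per timestep instead of two.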

B. Backpropagation Through Time

  • As with an ANN, the real learning happens during backpropagation, where we minimize the loss function by finding optimal values for the given parameters.
    • We use gradient descent to reach the optimal values
  • To simplify the derivation, let us assume the network has only one output, produced at the final timestep $T_x$ (a many-to-one setup).
    • The loss function defined below therefore depends on a single value of $\hat{y}$, instead of on $\hat{y}^{<t>}$ at every timestep

[Figure: Loss Function Diagram]

We use the binary cross-entropy loss:

$$\mathcal{L}[\hat{y}, y] = -y \cdot \log(\hat{y}) - (1-y) \cdot \log(1-\hat{y})$$

If we instead predict an outcome at every timestep, the total loss is the sum of the per-timestep losses:

$$\mathcal{L}[\hat{y}, y] = \sum_{t=1}^{T_y} \mathcal{L}^{<t>}[\hat{y}^{<t>}, y^{<t>}], \text{ where } t \rightarrow \text{timestep}$$
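Both formulas translate directly into code. This small sketch (function names are hypothetical) evaluates the binary cross-entropy for a single prediction and the summed loss for a sequence with one prediction per timestep.

```python
import math

def bce(y_hat, y):
    """Binary cross-entropy for a prediction y_hat in (0, 1) and a label y in {0, 1}."""
    return -y * math.log(y_hat) - (1 - y) * math.log(1 - y_hat)

def sequence_loss(y_hats, ys):
    """Total loss when a prediction is made at every timestep: sum of per-step losses."""
    return sum(bce(y_hat, y) for y_hat, y in zip(y_hats, ys))

print(round(bce(0.9, 1), 4))  # 0.1054  (confident and correct: small loss)
print(round(bce(0.9, 0), 4))  # 2.3026  (confident and wrong: large loss)
print(round(sequence_loss([0.9, 0.2, 0.7], [1, 0, 1]), 4))  # 0.6852
```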

The following are the parameters we will optimize, renamed for brevity (the biases $b_a$ and $b_y$ are updated in the same way and are omitted here):

  • $W_{aa} \rightarrow W_h$: weights for hidden units
  • $W_{ax} \rightarrow W_i$: weights for input units
  • $W_{ya} \rightarrow W_o$: weights for output unit

Applying gradient descent with learning rate $\alpha$:

$$W_h = W_h - \alpha \cdot \frac{\partial \mathcal{L}}{\partial W_h}$$

$$W_i = W_i - \alpha \cdot \frac{\partial \mathcal{L}}{\partial W_i}$$

$$W_o = W_o - \alpha \cdot \frac{\partial \mathcal{L}}{\partial W_o}$$

Our goal is to find $\frac{\partial \mathcal{L}}{\partial W_h}$, $\frac{\partial \mathcal{L}}{\partial W_i}$ and $\frac{\partial \mathcal{L}}{\partial W_o}$, which we need in order to update $W_h$, $W_i$ and $W_o$ towards their optimal values.


I. Calculating $\frac{\partial \mathcal{L}}{\partial W_o}$

The mind map below shows how the functions depend on one another, which helps visualize the chain rule used to calculate the gradient:

[Figure: Dependency Mapping for $W_o$]

From the above mapping, using the chain rule, we can calculate:

$$\frac{\partial \mathcal{L}}{\partial W_o} = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial W_o}$$
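For a concrete case, assume (as an illustration; the article does not fix the output activation) that $\hat{y} = \sigma(W_o \cdot a^{<T_x>} + b_y)$ with a sigmoid $\sigma$ and the binary cross-entropy above. Then $\frac{\partial \mathcal{L}}{\partial \hat{y}} = -\frac{y}{\hat{y}} + \frac{1-y}{1-\hat{y}}$ and $\frac{\partial \hat{y}}{\partial W_o} = \hat{y}(1-\hat{y})\,(a^{<T_x>})^\top$, whose product simplifies to $(\hat{y} - y)\,(a^{<T_x>})^\top$. The sketch below checks this closed form against a finite-difference estimate; all names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def loss_wrt_Wo(W_o, a_T, b_y, y):
    """Binary cross-entropy of y_hat = sigmoid(W_o a<T> + b_y) against label y."""
    y_hat = sigmoid(W_o @ a_T + b_y).item()
    return -y * np.log(y_hat) - (1 - y) * np.log(1 - y_hat)

def grad_Wo(W_o, a_T, b_y, y):
    """Chain rule dL/dy_hat * dy_hat/dW_o, simplified to (y_hat - y) a<T>^T."""
    y_hat = sigmoid(W_o @ a_T + b_y)
    return (y_hat - y) @ a_T.T

rng = np.random.default_rng(2)
a_T = rng.normal(size=(5, 1))  # final hidden state a<T_x> (random stand-in)
W_o = rng.normal(size=(1, 5))
b_y = np.zeros((1, 1))
y = 1.0

analytic = grad_Wo(W_o, a_T, b_y, y)

# Central finite differences, one weight at a time, to check the formula
eps = 1e-6
numeric = np.zeros_like(W_o)
for j in range(W_o.shape[1]):
    step = np.zeros_like(W_o)
    step[0, j] = eps
    numeric[0, j] = (loss_wrt_Wo(W_o + step, a_T, b_y, y)
                     - loss_wrt_Wo(W_o - step, a_T, b_y, y)) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

Because $W_o$ only touches the single output, no sum over timesteps is needed here, unlike for $W_i$ and $W_h$.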


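II. Calculating $\frac{\partial \mathcal{L}}{\partial W_i}$

[Figure: Dependency Mapping for $W_i$]

Taking $T_x = 3$ as in the above mapping, the same $W_i$ is used at every timestep, so it reaches the loss through 3 simultaneous paths, one through each hidden state, as we move back in time from $a^{<3>}$ to $a^{<1>}$. We calculate the gradient along each path separately and then simply add the results. In each path, the last factor $\frac{\partial a^{<t>}}{\partial W_i}$ is the direct dependence of $a^{<t>}$ on $W_i$, with $a^{<t-1>}$ held fixed.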

For timestep 3:

$$\left[\frac{\partial \mathcal{L}}{\partial W_i}\right]_3 = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{<3>}} \cdot \frac{\partial a^{<3>}}{\partial W_i}$$

For timestep 2:

$$\left[\frac{\partial \mathcal{L}}{\partial W_i}\right]_2 = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{<3>}} \cdot \frac{\partial a^{<3>}}{\partial a^{<2>}} \cdot \frac{\partial a^{<2>}}{\partial W_i}$$

For timestep 1:

$$\left[\frac{\partial \mathcal{L}}{\partial W_i}\right]_1 = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{<3>}} \cdot \frac{\partial a^{<3>}}{\partial a^{<2>}} \cdot \frac{\partial a^{<2>}}{\partial a^{<1>}} \cdot \frac{\partial a^{<1>}}{\partial W_i}$$

Combining all three:

$$\frac{\partial \mathcal{L}}{\partial W_i} = \left[\frac{\partial \mathcal{L}}{\partial W_i}\right]_3 + \left[\frac{\partial \mathcal{L}}{\partial W_i}\right]_2 + \left[\frac{\partial \mathcal{L}}{\partial W_i}\right]_1$$

$$\frac{\partial \mathcal{L}}{\partial W_i} = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{<3>}} \cdot \frac{\partial a^{<3>}}{\partial W_i} + \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{<3>}} \cdot \frac{\partial a^{<3>}}{\partial a^{<2>}} \cdot \frac{\partial a^{<2>}}{\partial W_i} + \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{<3>}} \cdot \frac{\partial a^{<3>}}{\partial a^{<2>}} \cdot \frac{\partial a^{<2>}}{\partial a^{<1>}} \cdot \frac{\partial a^{<1>}}{\partial W_i}$$

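Generalizing this to $T_x$ timesteps:

$$\frac{\partial \mathcal{L}}{\partial W_i} = \sum_{j=1}^{T_x} \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{<j>}} \cdot \frac{\partial a^{<j>}}{\partial W_i}$$

where $\frac{\partial \hat{y}}{\partial a^{<j>}}$ stands for the whole chain back from the output, $\frac{\partial \hat{y}}{\partial a^{<j>}} = \frac{\partial \hat{y}}{\partial a^{<T_x>}} \cdot \frac{\partial a^{<T_x>}}{\partial a^{<T_x-1>}} \cdots \frac{\partial a^{<j+1>}}{\partial a^{<j>}}$ (for $j = T_x$ it is just $\frac{\partial \hat{y}}{\partial a^{<T_x>}}$), exactly as in the three terms above.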
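The path-by-path sum can be sketched directly in NumPy. This is a minimal illustration assuming a many-to-one RNN with $\tanh$ hidden activation, a sigmoid output and binary cross-entropy, for which $\frac{\partial a^{<k>}}{\partial a^{<k-1>}} = \mathrm{diag}(1 - (a^{<k>})^2)\, W_h$; the helper names are hypothetical. It computes one term per hidden state exactly as in the sum, then compares the result with finite differences.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(xs, a0, W_h, W_i, b_a, W_o, b_y):
    """Many-to-one vanilla RNN: tanh hidden states, one sigmoid output at the last step."""
    hidden, a = [], a0
    for x_t in xs:
        a = np.tanh(W_h @ a + W_i @ x_t + b_a)
        hidden.append(a)
    return hidden, sigmoid(W_o @ a + b_y).item()

def bce(y_hat, y):
    return -y * np.log(y_hat) - (1 - y) * np.log(1 - y_hat)

def grad_Wi_by_paths(xs, a0, W_h, W_i, b_a, W_o, b_y, y):
    """dL/dW_i as the sum over j of the path through hidden state a<j>."""
    hidden, y_hat = forward(xs, a0, W_h, W_i, b_a, W_o, b_y)
    T = len(xs)
    # dL/dy_hat * dy_hat/da<T> for sigmoid + cross-entropy equals W_o^T (y_hat - y)
    dL_da_T = W_o.T * (y_hat - y)
    grad = np.zeros_like(W_i)
    for j in range(T):  # 0-based: hidden[j] is a<j+1>
        g = dL_da_T
        # Chain back through da<k>/da<k-1> = diag(1 - a<k>^2) W_h, from the last step down to j+1
        for k in range(T - 1, j, -1):
            g = W_h.T @ ((1 - hidden[k] ** 2) * g)
        # Immediate term da<j>/dW_i, with the previous hidden state held fixed
        grad += ((1 - hidden[j] ** 2) * g) @ xs[j].T
    return grad

rng = np.random.default_rng(3)
n_x, n_a, T = 2, 3, 3  # T = 3 matches the worked example above
W_h = rng.normal(scale=0.5, size=(n_a, n_a))
W_i = rng.normal(scale=0.5, size=(n_a, n_x))
b_a = np.zeros((n_a, 1))
W_o = rng.normal(scale=0.5, size=(1, n_a))
b_y = np.zeros((1, 1))
a0 = np.zeros((n_a, 1))
xs = [rng.normal(size=(n_x, 1)) for _ in range(T)]
y = 1.0

analytic = grad_Wi_by_paths(xs, a0, W_h, W_i, b_a, W_o, b_y, y)

# Finite-difference check on every entry of W_i
eps = 1e-6
numeric = np.zeros_like(W_i)
for idx in np.ndindex(W_i.shape):
    step = np.zeros_like(W_i)
    step[idx] = eps
    lp = bce(forward(xs, a0, W_h, W_i + step, b_a, W_o, b_y)[1], y)
    lm = bce(forward(xs, a0, W_h, W_i - step, b_a, W_o, b_y)[1], y)
    numeric[idx] = (lp - lm) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

Rebuilding the chain separately for every $j$ costs $O(T_x^2)$ matrix products; since the chains share their common prefix, practical implementations instead make a single backward sweep.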


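III. Calculating $\frac{\partial \mathcal{L}}{\partial W_h}$

[Figure: Dependency Mapping for $W_h$]

Using the same dependency mapping as in the previous calculation: $W_h$ is also shared across timesteps, so we again have one path per hidden state as we go back through the timesteps, and the per-path gradients are added.

Generalizing this to $T_x$ timesteps:

$$\frac{\partial \mathcal{L}}{\partial W_h} = \sum_{j=1}^{T_x} \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{<j>}} \cdot \frac{\partial a^{<j>}}{\partial W_h}$$

where, as for $W_i$, $\frac{\partial \hat{y}}{\partial a^{<j>}}$ is the product of the step-to-step factors $\frac{\partial a^{<k>}}{\partial a^{<k-1>}}$ from $a^{<T_x>}$ back to $a^{<j>}$, and $\frac{\partial a^{<j>}}{\partial W_h}$ is taken with $a^{<j-1>}$ held fixed.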
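Under the same illustrative assumptions ($\tanh$ hidden state, sigmoid output, binary cross-entropy; helper names hypothetical), the sum for $W_h$ can be computed in a single backward sweep: the running gradient $\frac{\partial \mathcal{L}}{\partial a^{<j>}}$ is multiplied by one factor $\frac{\partial a^{<j>}}{\partial a^{<j-1>}}$ per step instead of rebuilding each chain from scratch. The immediate term for $W_h$ uses $a^{<j-1>}$ where the $W_i$ term uses $x^{<j>}$.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(xs, a0, W_h, W_i, b_a, W_o, b_y):
    """Many-to-one vanilla RNN: tanh hidden states, one sigmoid output at the last step."""
    hidden, a = [], a0
    for x_t in xs:
        a = np.tanh(W_h @ a + W_i @ x_t + b_a)
        hidden.append(a)
    return hidden, sigmoid(W_o @ a + b_y).item()

def bce(y_hat, y):
    return -y * np.log(y_hat) - (1 - y) * np.log(1 - y_hat)

def grad_Wh_bptt(xs, a0, W_h, W_i, b_a, W_o, b_y, y):
    """dL/dW_h in one backward sweep; g holds dL/da<j> and is reused for the next step back."""
    hidden, y_hat = forward(xs, a0, W_h, W_i, b_a, W_o, b_y)
    prev = [a0] + hidden[:-1]        # prev[j] is the hidden state feeding into hidden[j]
    g = W_o.T * (y_hat - y)          # dL/da<T_x> for sigmoid + cross-entropy
    grad = np.zeros_like(W_h)
    for j in range(len(xs) - 1, -1, -1):
        dz = (1 - hidden[j] ** 2) * g    # back through tanh at this step
        grad += dz @ prev[j].T           # immediate term da<j>/dW_h, previous state held fixed
        g = W_h.T @ dz                   # multiply by da<j>/da<j-1> to move one step back
    return grad

rng = np.random.default_rng(5)
n_x, n_a, T = 2, 3, 4
W_h = rng.normal(scale=0.5, size=(n_a, n_a))
W_i = rng.normal(scale=0.5, size=(n_a, n_x))
b_a = np.zeros((n_a, 1))
W_o = rng.normal(scale=0.5, size=(1, n_a))
b_y = np.zeros((1, 1))
a0 = rng.normal(scale=0.1, size=(n_a, 1))
xs = [rng.normal(size=(n_x, 1)) for _ in range(T)]
y = 0.0

analytic = grad_Wh_bptt(xs, a0, W_h, W_i, b_a, W_o, b_y, y)

# Finite-difference check on every entry of W_h
eps = 1e-6
numeric = np.zeros_like(W_h)
for idx in np.ndindex(W_h.shape):
    step = np.zeros_like(W_h)
    step[idx] = eps
    lp = bce(forward(xs, a0, W_h + step, W_i, b_a, W_o, b_y)[1], y)
    lm = bce(forward(xs, a0, W_h - step, W_i, b_a, W_o, b_y)[1], y)
    numeric[idx] = (lp - lm) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

The sweep performs the same additions as the path-by-path sum, just in an order that shares the common factors, so its cost grows linearly with $T_x$.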

Drawbacks of Vanilla RNN

The major drawback of the vanilla RNN architecture is that it does not work well on sequences with Long Range Dependencies, because it suffers from the well-known Vanishing Gradient Problem.

Understanding Long Range Dependency

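→ Sometimes in a language, sentences are framed in such a way that they have Long Range Dependencies: a word that comes early in a sentence influences what needs to come much later in that sentence. For example, in "The cats, which had eaten a large meal, were full", the plural "cats" determines the verb "were" several words later.

[Figure: Long Range Dependencies]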

→ A vanilla RNN finds it difficult to carry the context of earlier words forward to words that come much later.

Why This Happens

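  • This is because, in this architecture, the output at timestep $t$ is strongly influenced by nearby timesteps and only weakly by distant ones. In the backpropagation formulas above, the gradient reaching a hidden state $n$ steps back is a product of $n$ factors $\frac{\partial a^{<k>}}{\partial a^{<k-1>}}$; when these factors are small, the product shrinks exponentially with distance, so distant timesteps contribute almost nothing to the parameter updates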

The vanishing gradient problem prevents vanilla RNNs from effectively learning dependencies that span many timesteps.
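The shrinking product can be observed numerically. The sketch below pushes a unit gradient backwards through the factors $\frac{\partial a^{<k>}}{\partial a^{<k-1>}} = \mathrm{diag}(1 - (a^{<k>})^2)\, W_h$ of a randomly initialized $\tanh$ RNN and records the gradient norm at each step. The small recurrent weight scale, the random inputs and all sizes are illustrative assumptions; with large recurrent weights the same product can grow instead (the exploding gradient problem), so this shows one regime, not a proof.

```python
import numpy as np

rng = np.random.default_rng(4)
n_a, n_x, T = 16, 8, 50
# Small recurrent weights (assumed), moderate input weights; biases omitted for brevity
W_h = rng.normal(scale=0.3 / np.sqrt(n_a), size=(n_a, n_a))
W_i = rng.normal(scale=1.0 / np.sqrt(n_x), size=(n_a, n_x))

# Forward pass over a random input sequence, storing every hidden state
a = np.zeros((n_a, 1))
hidden = []
for _ in range(T):
    x_t = rng.normal(size=(n_x, 1))
    a = np.tanh(W_h @ a + W_i @ x_t)
    hidden.append(a)

# Push a unit-norm gradient back from the last hidden state, one Jacobian per step
g = np.ones((n_a, 1)) / np.sqrt(n_a)
norms = [np.linalg.norm(g)]
for k in range(T - 1, 0, -1):
    g = W_h.T @ ((1 - hidden[k] ** 2) * g)  # multiply by (da<k>/da<k-1>)^T
    norms.append(np.linalg.norm(g))

# norms[n] is the gradient norm n steps back from the final timestep
for steps_back in (0, 1, 5, 10, 20, 49):
    print(steps_back, f"{norms[steps_back]:.2e}")
```

Each step multiplies the gradient by a matrix whose norm is below 1 in this setting, so the printed norms fall off roughly geometrically with the number of steps back.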

Therefore, to capture long-range context, we introduce the concept of a ‘Context Cell’ or ‘Memory Cell’, which acts as an additional input to these sequential units. This forms the basis of subsequent architectures in sequence modelling, namely the Gated Recurrent Unit (GRU) and Long Short-Term Memory (LSTM).
