Understanding Recurrent Neural Networks (RNN)
Introduction
Recurrent Neural Networks (RNNs) are a specific type of Neural Networks (NNs) that are especially relevant for sequential data like time series, text, or audio data. Traditional neural networks process each input independently, meaning they cannot retain information about previous inputs. This makes them ineffective for tasks that require understanding sequences, such as time series forecasting or natural language processing. RNNs, however, process the data sequentially, which enables them to retain information from the past.
RNN Architecture
In a standard Feedforward Neural Network (FFNN) all inputs are processed independently of each other. As discussed in Introduction to Deep Learning, an FFNN consists of an input layer, an output layer, and a number of hidden layers in between. All outputs are calculated independently and there is no connection between them. An RNN, in contrast, uses the output of one step as an additional input to the next step, alongside the input data, and in that way creates a connection to, and a memory of, previous steps. This way the RNN is able to capture temporal dependencies in a time series or contextual dependencies in other sequential data. The difference in the architectures is illustrated in the following plot.
Illustration of a Recurrent Neural Network vs. a standard Feedforward Artificial Neural Network
At each time step, the RNN takes in the current input along with the hidden state from the previous step. The hidden state acts as a form of memory, allowing the model to retain information from earlier inputs. For instance, in the context of a time series, the current input could represent a specific time step, and the hidden state would contain information from earlier time steps. Similarly, in a sentence, the current input could be a word, and the hidden state would store information from the words preceding it. This architecture is known as a recurrent unit and can be seen in the next plot. Both illustrations show an RNN. The one on the right-hand side shows the "unfolded" network, in which the time steps are laid out one after another, making the recurrent nature clearer.
RNN architecture
Unlike FFNNs, which have distinct weights for each node, RNNs reuse the same set of weights at every time step within a layer. That is, the number of parameters they need to learn is reduced, making them more efficient for sequential data. In the above plot these shared weights are illustrated as the weight matrices $W_{xh}$ (input to hidden), $W_{hh}$ (hidden to hidden), and $W_{hy}$ (hidden to output).
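A quick way to see this weight sharing in code is to list the parameters of a single-layer RNN in PyTorch (the layer sizes below are assumptions): there is exactly one input-to-hidden and one hidden-to-hidden weight matrix, and they are reused at every time step.

```python
import torch.nn as nn

# Illustrative sketch with assumed sizes: a single-layer nn.RNN holds one shared
# set of weights that is applied at every time step of the sequence.
rnn = nn.RNN(input_size=3, hidden_size=4)
for name, param in rnn.named_parameters():
    print(name, tuple(param.shape))
# weight_ih_l0 (4, 3)  -- input-to-hidden weights, shared across time steps
# weight_hh_l0 (4, 4)  -- hidden-to-hidden weights, shared across time steps
# bias_ih_l0 (4,)
# bias_hh_l0 (4,)
```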
The Math behind RNNs
In the above sections we learned how RNNs work conceptually: an RNN processes sequential data by maintaining a hidden state that carries information from previous time steps. Now we will take a closer look at the calculations needed to achieve this.
1. Hidden State Update
At each time step $t$, the hidden state $h_t$ is updated as
$$h_t = f(W_{xh} x_t + W_{hh} h_{t-1} + b_h),$$
with $x_t$ the input at time step $t$, $h_{t-1}$ the hidden state of the previous time step, $W_{xh}$ the input-to-hidden weight matrix, $W_{hh}$ the hidden-to-hidden weight matrix, $b_h$ the bias term, and $f$ a non-linear activation function, typically $\tanh$.
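As a minimal sketch of this update, assuming small example dimensions and a $\tanh$ activation, the hidden state can be computed as follows.

```python
import numpy as np

# Minimal sketch of the hidden state update; all sizes and values are assumptions.
input_size, hidden_size = 3, 4
x_t = np.random.randn(input_size)                 # input at time step t
h_prev = np.zeros(hidden_size)                    # previous hidden state h_{t-1}

W_xh = np.random.randn(hidden_size, input_size)   # input-to-hidden weights
W_hh = np.random.randn(hidden_size, hidden_size)  # hidden-to-hidden weights
b_h = np.zeros(hidden_size)                       # hidden bias

# h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)
h_t = np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)
print(h_t.shape)                                  # (4,)
```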
2. Output Calculation
The output $y_t$ at time step $t$ is calculated as
$$y_t = g(W_{hy} h_t + b_y),$$
with $W_{hy}$ the hidden-to-output weight matrix, $b_y$ the output bias term, and $g$ an activation function that depends on the task.
For any task that needs a probability as an output, the softmax function is used to map the output onto a probability distribution over the classes. This might be a classification task or a Natural Language Processing (NLP) task, such as sequence generation (like text generation). In the latter, the output might represent the probability of the next token or word in a sequence rather than a class label. The softmax function is defined as
$$\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}.$$
Using this transformation ensures that each output lies between $0$ and $1$ and that all outputs sum to $1$.
For a regression task such a transformation is not necessary. We can skip the activation function and calculate the output as
$$y_t = W_{hy} h_t + b_y.$$
The equations can be illustrated in the following diagram.
RNN Cell
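The output step can be sketched in the same way. The sizes below are assumptions; the softmax is only applied for classification-style tasks, while a regression task uses the raw output directly.

```python
import numpy as np

# Sketch of the output calculation; sizes and values are assumptions.
hidden_size, output_size = 4, 3
h_t = np.random.randn(hidden_size)                # current hidden state
W_hy = np.random.randn(output_size, hidden_size)  # hidden-to-output weights
b_y = np.zeros(output_size)                       # output bias

z_t = W_hy @ h_t + b_y    # raw output, used directly for a regression task

def softmax(z):
    z = z - z.max()       # subtract the maximum for numerical stability
    e = np.exp(z)
    return e / e.sum()

y_t = softmax(z_t)        # probability distribution for a classification task
print(y_t, y_t.sum())     # entries lie between 0 and 1 and sum to 1
```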
3. Backpropagation Through Time
To train an RNN, we use a variant of backpropagation called Backpropagation Through Time (BPTT). The loss function is typically accumulated over all time steps,
$$L = \sum_{t=1}^{T} L_t(\hat{y}_t, y_t),$$
with $\hat{y}_t$ the prediction and $y_t$ the target at time step $t$, and $T$ the length of the sequence. BPTT unfolds the network over time and applies the chain rule across all time steps.
In RNNs, this means computing gradients of the loss with respect to both the output layer weights $W_{hy}$ and the input and recurrent weights $W_{xh}$ and $W_{hh}$, where the gradients of the recurrent weights accumulate contributions from every time step they influence.
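In practice, deep learning frameworks carry out BPTT automatically. The sketch below (with assumed shapes) unrolls a small PyTorch RNN over a short sequence and lets autograd backpropagate the loss through all time steps.

```python
import torch
import torch.nn as nn

# Tiny RNN unrolled over a short sequence; autograd performs BPTT for us.
# Shapes and values here are illustrative assumptions.
rnn = nn.RNN(input_size=1, hidden_size=4, batch_first=True)
x = torch.randn(1, 5, 1)      # batch of 1, sequence length 5, 1 feature
target = torch.randn(1, 4)    # dummy target for the final hidden state

out, h_n = rnn(x)             # out: (1, 5, 4), h_n: (1, 1, 4)
loss = nn.functional.mse_loss(h_n.squeeze(0), target)
loss.backward()               # gradients flow back through all 5 time steps

# The recurrent weights receive gradient contributions from every time step.
print(rnn.weight_hh_l0.grad.shape)   # torch.Size([4, 4])
```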
4. Vanishing and Exploding Gradients
One challenge in training RNNs is the vanishing and exploding gradients problem. When backpropagating through many time steps, gradients can shrink exponentially (vanishing gradient) or grow uncontrollably (exploding gradient). If gradients become too small, early time steps barely get updated. On the other hand, if gradients become too large, training becomes unstable.
To mitigate this, techniques like gradient clipping, Long Short-Term Memory (LSTM) networks, or Gated Recurrent Units (GRU) are often used. Explaining these techniques in detail is, however, beyond the scope of this post.
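As an illustration of one of these mitigations, the sketch below (model and shapes are assumptions) applies gradient clipping with torch.nn.utils.clip_grad_norm_ before the optimizer step.

```python
import torch
import torch.nn as nn

# Illustrative sketch: clip the gradient norm before the optimizer step
# to keep exploding gradients in check. Sizes and values are assumptions.
rnn = nn.RNN(input_size=1, hidden_size=8, batch_first=True)
optimizer = torch.optim.Adam(rnn.parameters(), lr=1e-3)

x = torch.randn(2, 20, 1)     # 2 sequences of length 20
target = torch.randn(2, 8)

out, h_n = rnn(x)
loss = nn.functional.mse_loss(out[:, -1, :], target)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)  # gradient clipping
optimizer.step()
```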
Simple Example
Let’s consider a simple example to calculate the forward pass for an RNN.
An RNN cell updates the hidden state $h_t$ from the previous hidden state $h_{t-1}$ and the input $x_t$ at time step $t$, and then produces an output $y_t$ from the new hidden state.
Step 1: Compute the hidden state
In this example we use $\tanh$ as activation function, so the hidden state is computed as
$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h).$$
Step 2: Compute the output
The output is calculated as
$$y_t = W_{hy} h_t + b_y,$$
which is the prediction for the next step.
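The same two steps can be carried out numerically. The concrete weights and inputs below are assumed example values chosen purely for illustration.

```python
import numpy as np

# Worked forward pass with assumed example values.
x_t = np.array([1.0])                 # input at time step t
h_prev = np.array([0.0, 0.0])         # previous hidden state

W_xh = np.array([[0.5], [-0.3]])      # input-to-hidden weights (2x1)
W_hh = np.array([[0.1, 0.2],
                 [0.0, 0.4]])         # hidden-to-hidden weights (2x2)
b_h = np.array([0.0, 0.1])            # hidden bias

W_hy = np.array([[1.0, -1.0]])        # hidden-to-output weights (1x2)
b_y = np.array([0.05])                # output bias

# Step 1: hidden state
h_t = np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)   # approx. [0.46, -0.20]
# Step 2: output
y_t = W_hy @ h_t + b_y                            # approx. [0.71]
print(h_t, y_t)
```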
Types of RNNs
In this section some variants of RNNs are briefly presented without explaining them in detail.
Bidirectional RNNs
Instead of processing data only forward in time, Bidirectional RNNs (BiRNNs) pass information both forward and backward, allowing the network to use future context when making predictions. This is particularly useful for NLP tasks like named entity recognition (NER). When using this variant, we need to make sure that the future context is actually available at prediction time. When, for example, predicting the weather for the next hours, this is not a variant we can use, because the future inputs do not exist yet.
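In PyTorch a bidirectional RNN can be requested via the bidirectional flag (sizes below are assumptions); note that the output dimension doubles because the forward and backward hidden states are concatenated.

```python
import torch
import torch.nn as nn

# Sketch of a bidirectional RNN; sizes are assumptions.
birnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True, bidirectional=True)
x = torch.randn(4, 10, 8)   # batch of 4 sequences of length 10

out, h_n = birnn(x)
print(out.shape)   # torch.Size([4, 10, 32]) -- forward and backward states concatenated
print(h_n.shape)   # torch.Size([2, 4, 16]) -- one final hidden state per direction
```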
Long Short-Term Memory (LSTM)
LSTMs were developed to overcome issues that vanilla RNNs show, especially vanishing gradients. LSTMs introduce gates (input, forget, and output gates) to control the information flow and prevent vanishing gradients. This enables a longer memory than vanilla RNNs provide, which makes them more effective for tasks like speech recognition and text generation, where a lot of context is needed.
Gated Recurrent Units (GRU)
GRUs were developed after LSTMs and simplify them by combining the forget and input gates into a single update gate, reducing computational cost while still handling long-term dependencies effectively. They are often preferred when training speed is a priority.
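Both variants are available as drop-in replacements for the vanilla RNN in PyTorch; the sizes below are assumptions.

```python
import torch
import torch.nn as nn

# Drop-in replacements for nn.RNN; sizes are assumptions.
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

x = torch.randn(4, 10, 8)
out_lstm, (h_n, c_n) = lstm(x)   # the LSTM additionally returns a cell state c_n
out_gru, h_gru = gru(x)          # the GRU returns only a hidden state, like a vanilla RNN
```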
Input-Output Structures in RNNs
RNNs are versatile and can handle a variety of input-output structures depending on the task. The main types are:
One-to-One
This is the simplest structure, where one input produces one output. In this case, the RNN behaves similarly to a feedforward neural network since it does not leverage sequential dependencies across multiple time steps. An example would be time series prediction, where a single time step is used to predict the next time step.
Example: Basic time series forecasting where only the previous value is used to predict the next (though most real-world models use multiple past values, making it many-to-one).
One-to-Many
A single input generates a sequence of outputs. This structure is useful in applications where a single source must produce a series of predictions.
Example: Image captioning, where a single image is processed, and the RNN generates a sequence of words to describe it. Typically, a Convolutional Neural Network (CNN) extracts features from the image, which are then fed into an RNN (such as an LSTM) to produce a descriptive caption.
Many-to-One
In this structure, a sequence of inputs produces a single output. This is common in tasks that require analyzing an entire sequence before making a final prediction.
Example: Sentiment analysis, where an RNN processes a sequence of words in a customer review and outputs a sentiment classification (e.g., positive, negative, or neutral).
Many-to-Many
This structure maps a sequence of inputs to a sequence of outputs. It can either be synchronized (where input and output lengths match) or asynchronous (where the input and output lengths may differ).
Example: Machine translation, where an RNN translates a sentence from one language to another, producing a sequence of words in the target language.
The following table summarizes the different types of input-output structures in RNNs:
RNN types summarized
The different input and output types can be illustrated as follows.
RNN types illustrated
By choosing the appropriate RNN architecture and input-output structure, it is possible to tackle a wide range of complex sequence-based problems, from time series prediction to natural language processing.
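As a small sketch of how these structures translate into code (shapes are assumptions), the same recurrent output can feed a many-to-one or a many-to-many head, depending on whether only the last time step or every time step is used.

```python
import torch
import torch.nn as nn

# Sketch: one nn.RNN serving different input-output structures; shapes are assumptions.
rnn = nn.RNN(input_size=1, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)

x = torch.randn(4, 12, 1)            # 4 sequences, 12 time steps, 1 feature each
out, h_n = rnn(x)                    # out: (4, 12, 8)

many_to_one = head(out[:, -1, :])    # use only the last time step -> (4, 1)
many_to_many = head(out)             # one prediction per time step -> (4, 12, 1)
```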
RNNs in Python
Finally, we will create a Recurrent Neural Network (RNN) in Python. For this demonstration, we will use PyTorch, a widely used deep learning framework. The concepts we've covered so far will be applied to a simple toy dataset.
The training data consists of a basic sequence of numbers, and the task is to predict the next number in the sequence. The following code outlines how to build and train an RNN using PyTorch to accomplish this task. We will implement a basic RNN with a single hidden layer. The RNN model is defined by inheriting from the nn.Module class, and the nn.RNN class is used to create the RNN unit.
The code appears longer than it actually is due to the detailed comments, which are provided within the code for further clarification. Although the dataset is very simplified, the code snippet is a full working example.
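A minimal sketch of such a setup could look as follows; the toy sequence, hyperparameters, and names used here are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Toy data: sliding windows over the sequence 0.0, 0.1, ..., 0.9, where each
# window of three values is used to predict the next value.
sequence = torch.arange(0, 1.0, 0.1)
window = 3
X = torch.stack([sequence[i:i + window] for i in range(len(sequence) - window)])
y = torch.stack([sequence[i + window] for i in range(len(sequence) - window)])
X = X.unsqueeze(-1)   # shape: (num_samples, window, 1) -- one feature per time step
y = y.unsqueeze(-1)   # shape: (num_samples, 1)

class SimpleRNN(nn.Module):
    def __init__(self, input_size=1, hidden_size=16, output_size=1):
        super().__init__()
        # nn.RNN implements the recurrent unit: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        # A linear layer maps the last hidden state to the predicted next value.
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, h_n = self.rnn(x)          # out: (batch, seq_len, hidden_size)
        return self.fc(out[:, -1, :])   # many-to-one: use only the last time step

model = SimpleRNN()
criterion = nn.MSELoss()                # regression task: no softmax needed
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    optimizer.zero_grad()
    pred = model(X)
    loss = criterion(pred, y)
    loss.backward()                     # backpropagation through time via autograd
    optimizer.step()
    if (epoch + 1) % 50 == 0:
        print(f"epoch {epoch + 1}, loss {loss.item():.5f}")

# Predict the value that follows the last window of the sequence.
with torch.no_grad():
    test_input = sequence[-window:].reshape(1, window, 1)
    print("next value prediction:", model(test_input).item())
```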
Summary
In summary, RNNs are powerful tools for handling sequential data. They are a class of artificial neural networks designed to process sequences by maintaining a memory of previous inputs. Unlike traditional feedforward networks, RNNs have loops that allow information to persist, making them suitable for tasks where context from earlier steps is essential. While effective for many applications, RNNs face challenges like vanishing gradients, which hinder learning long-term dependencies. Variants such as Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs) address these issues with improved memory mechanisms.
Transformers have emerged as a dominant alternative to RNNs. By using a self-attention mechanism instead of sequential processing, transformers can model long-range dependencies more efficiently and support parallel computation. This makes them faster to train and more effective for tasks like natural language processing, where they have largely replaced RNN-based models. This is however out of the scope of this article.