Explain it to me like a 5-year-old: Introduction to LSTM and Attention Models — Part 2/2
I am a Product Manager with some background in Deep Learning and Data Science. Deep Learning is daunting, and I am trying to make it as intuitive as possible. Feel free to reach out to me or comment if you think any of the content below needs correction; I am open to feedback.
My intention here is to make understanding Deep Learning easy and fun instead of mathematical and comprehensive. I haven’t explained everything in detail, but I am sure by the end of the blog, you will be able to understand how an LSTM works and you will NEVER FORGET IT!!! ;)
Feel free to connect with me on LinkedIn!
Long Short Term Memory (LSTM) Networks:
As we saw previously with simple RNNs, long sequences cause gradient problems: while backpropagating from the last time step back to the first, the gradient is scaled again at every step, so it can shrink towards zero (the vanishing gradient problem) and the earliest steps barely learn. This makes simple RNNs less effective at handling long sentences and remembering their context.
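To make the gradient problem concrete, here is a minimal sketch of why long sequences hurt. It assumes, purely for illustration, that every time step scales the gradient by the same hypothetical factor of 0.5; a real RNN has a different factor at each step, but the compounding effect is the same.

```python
def gradient_after_steps(per_step_factor: float, steps: int) -> float:
    """Scale a gradient of 1.0 by the same factor once per time step,
    mimicking backpropagation through a chain of RNN steps."""
    grad = 1.0
    for _ in range(steps):
        grad *= per_step_factor
    return grad


# 0.5 is an assumed, illustrative per-step factor, not a measured value.
short_sentence = gradient_after_steps(0.5, steps=5)
long_sentence = gradient_after_steps(0.5, steps=50)

print("gradient reaching step 1 of a 5-word sentence: ", short_sentence)
print("gradient reaching step 1 of a 50-word sentence:", long_sentence)
```

For the 50-step case the gradient that reaches the first word is practically zero, so the first word gets almost no learning signal.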
To tackle that issue, we use LSTM networks. They are modified versions of the simple RNN: instead of using a single non-linear function and one hidden state to do all the work (Fig 3.1), an LSTM has multiple non-linear functions that help control the flow of information, plus a cell state in addition to the hidden state (Fig 3.2). This way, the hidden state doesn’t have to do all the work when it comes to backpropagation.
The information in LSTM is controlled with the help of gates. Gates optionally let information through.
How do LSTMs work?
These are the 4 steps that primarily capture the working of each LSTM cell:
Each cell starts by forgetting the information from the previous state that is no longer relevant. As you can see in the figure, the previous hidden state and the current input are passed through a sigmoid function, which outputs a value between 0 and 1 that modulates how much should be kept (close to 1) versus thrown away (close to 0). The result is applied to the cell state (c(t)).
Here we determine what part of old information and what part of new information is relevant and store it into the cell state.
The most critical part of LSTMs is that they maintain a different value for the cell state (c(t)) along with its hidden state which is selectively updated through these gates.
Finally, we have an output from the LSTM, which is also called the hidden state. The output gate controls how much of the encoded information in the cell state should eventually flow to the next cell via the current hidden state.
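The four steps above can be sketched in a few lines of code. This is a minimal, hedged illustration and not a production implementation: it uses a hidden size of 1 (plain numbers instead of vectors) to stay readable, and the names `lstm_step` and `params`, as well as all the weight values, are made up for the example.

```python
import math


def sigmoid(z: float) -> float:
    """Squash any number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x_t, h_prev, c_prev, params):
    """One step of a scalar LSTM cell; returns (new hidden state, new cell state)."""

    def gate(name, squash):
        w_x, w_h, b = params[name]  # input weight, hidden weight, bias
        return squash(w_x * x_t + w_h * h_prev + b)

    f_t = gate("forget", sigmoid)                # step 1: how much old cell state to keep
    i_t = gate("input", sigmoid)                 # step 2: how much new information to store
    c_candidate = gate("candidate", math.tanh)   # step 2: the new information itself
    c_t = f_t * c_prev + i_t * c_candidate       # step 3: selectively update the cell state
    o_t = gate("output", sigmoid)                # step 4: how much of the cell state to expose
    h_t = o_t * math.tanh(c_t)                   # step 4: the new hidden state (the output)
    return h_t, c_t


# Hypothetical weights, chosen by hand only so the example runs.
params = {
    "forget": (0.5, 0.5, 1.0),
    "input": (0.8, -0.3, 0.0),
    "candidate": (1.2, 0.4, 0.0),
    "output": (0.6, 0.6, 0.0),
}

h, c = 0.0, 0.0  # both states start empty
for x in [1.0, -0.5, 0.25]:
    h, c = lstm_step(x, h, c, params)
    print(f"input={x:+.2f}  hidden={h:+.4f}  cell={c:+.4f}")
```

Notice that the cell state is only ever scaled by the forget gate and added to, while the hidden state is recomputed from it at every step.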
Key takeaway: LSTMs control the flow of information within each cell by maintaining a cell state alongside the hidden state, which helps them mitigate the vanishing gradient problem.
Backpropagation in LSTMs:
A few minutes ago we saw that backpropagation is a problem in simple RNNs because of the risk of vanishing gradients, and we said that LSTMs help solve this. But how? As we saw, an LSTM has both a hidden state and a cell state. Gradients still flow through both, but the cell state is only updated selectively by the gates (scaled by the forget gate and then added to) rather than being squashed through a non-linear function at every step, so it gives the gradient a largely uninterrupted path back through time.
Intuition: Basically you can view the hidden state as city roads and cell state as highways. So if you want to go from New York to Boston, would you prefer using the city roads which will take up to 9hrs to reach Boston, or use highways that can get you to Boston in 4hrs? Hence LSTMs work better because they use cell states (highways) for backpropagation.
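For readers who want to see the highway in symbols, here is a rough sketch. The notation is an assumption on my part (standard textbook symbols, not taken from the figures): f_t is the forget gate, i_t the input gate, \tilde{c}_t the candidate information, and W the recurrent weight of a simple RNN. The LSTM lines are an approximation that ignores the indirect effect the cell state has on the gates through the hidden state.

```latex
% Simple RNN: every step passes through tanh and the same weight matrix W
h_t = \tanh(W h_{t-1} + U x_t)
\quad\Longrightarrow\quad
\frac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}\!\left(1 - h_t^{2}\right) W

% LSTM cell state: only scaled by the forget gate and added to
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
\quad\Longrightarrow\quad
\frac{\partial c_t}{\partial c_{t-1}} \approx f_t ,
\qquad
\frac{\partial c_T}{\partial c_k} \approx \prod_{t=k+1}^{T} f_t
```

When the forget gates stay close to 1, that product stays close to 1 as well, so the gradient can travel across many time steps without vanishing.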
One of the interesting applications of RNNs is in machine translation; translating a sentence from one language to another.
In the above diagram you can see how RNNs can be used for machine translation. An encoder turns the input sentence into a state vector, and a decoder then turns that state vector into the output sentence.
But there are a few issues in using RNNs for machine translation:
- Encoding bottleneck:
As you can see, there is only one line that connects the encoder to the decoder, which means all the information from the encoder needs to be captured in one state and then sent to the decoder. In the case of large sentences, it is difficult to capture all the information in a single state, which might result in data loss.
- Slow, sequential model:
As this is a sequential model, each step has to wait for the previous one, so there is no opportunity to run the computation in parallel, which leads to slow training time for the model.
- Not long memory:
I know we tackled the problem of handling large sentences by switching from simple RNNs to LSTM but when it comes to machine translation, there are very long temporal dependencies that are difficult for LSTMs to handle.
So how do we deal with such dependencies in large sequences and bodies of text? Attention!
Instead of the decoder having access to only the final state of the encoder, attention gives the decoder access to the encoder states from every timestep of the original sentence, and it’s the weighting of these state vectors that the model learns during training. Attention mechanisms in neural networks provide learnable memory access. Basically, the model is trained to pay attention to only those parts of the input which have a high impact on the output.
Intuition: Suppose I ask you “What is your name?” and you reply “My first name is Ameya. I live in New York and love dancing. I forgot to mention that my last name is Shanbhag.” The model will give more weight (attention) to the first and last sentences because those are the ones that matter for the answer!
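Here is a minimal sketch of that weighting idea. It assumes a simple dot product between the decoder state and each encoder state as the relevance score, which is just one common choice and not necessarily the one used in the lecture; the function names and the tiny 2-number states are made up for illustration. In a real model these states come from the trained networks.

```python
import math


def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    biggest = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - biggest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(decoder_state, encoder_states):
    """Score every encoder state against the decoder state, then blend them."""
    # Assumed scoring rule: dot product between decoder and encoder state.
    scores = [
        sum(d * e for d, e in zip(decoder_state, state))
        for state in encoder_states
    ]
    weights = softmax(scores)
    size = len(encoder_states[0])
    # The context is the weighted average of all the encoder states.
    context = [
        sum(w * state[j] for w, state in zip(weights, encoder_states))
        for j in range(size)
    ]
    return weights, context


# Hypothetical encoder states, one per timestep of the input sentence.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_state = [2.0, 0.0]

weights, context = attend(decoder_state, encoder_states)
print("attention weights:", [round(w, 3) for w in weights])
print("context vector:   ", [round(v, 3) for v in context])
```

The first and third encoder states line up with what the decoder is looking for, so they receive more weight than the second one, and the decoder gets a blend of all timesteps instead of only the final state.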
That’s all folks!
Feel free to jump to the lecture by Ava: https://www.youtube.com/watch?v=qjrad0V0uJE&ab_channel=AlexanderAmini :)
Do check out my other posts to gain more knowledge about finance and technology.
Please do let me know if there is any other concept in deep learning you want me to write an article on. I will try my best to explain it in simpler terms.
Also, feel free to ask questions in the comment section. Will be happy to help you out :)
PS: The analogy I have used might not be 100% correct but it’s easy to understand things with a simpler analogy.
Credits: MIT Open Courseware