RNN, LSTM, Encoder, and Beyond

Language Models in Action, NLP

Suppose one day you want to home-brew your own SIRI as a daily assistant. How can it be done? Well, surprisingly, it is not too difficult. In this post, we will put the basics of home-brewing SIRI into action.

SIRI

Of course, since this is a short read, there is no way to explain everything in fine detail, such as voice embedding extraction, Mel-spectrogram analysis, text-to-speech voice synthesis, time-delay neural networks, or all the fancy language models used in this technology. Instead, we will focus on one core area: how does your cell phone understand your questions and provide a reasonable answer?

The basic idea is this: suppose we have the question How do you feel? A trained model would be expected to give an answer like I feel… (happy?), which sounds as if a human answered the question. It sounds non-trivial and daunting at first glance, but it turns out to be quite easy if you choose the right deep learning architecture.

For any deep learning task, the first and foremost question is: how do we model the problem so that it fits an existing solution? As we all know, computers like numbers, and it turns out that is all this task needs: represent each word with a number, and represent a sentence as a fixed-length array of those numbers. Strictly speaking, this is integer (or index) encoding; one-hot encoding is its close relative, in which each word’s number is expanded into a vector of zeros with a single 1 at that index.

Here 1 encodes the word “how”, 2 encodes the word “do”, and so on, while 0 is a special padding value that brings all sentences to the same length. Note that this pre-fixed length dictates the capability of your model: in this setup, a vector length of 8 means your model can generate answers at most 8 words long.
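A minimal sketch of this encoding in Python, assuming a tiny hand-built vocabulary that follows the numbering above (the VOCAB dictionary and the tokenize, encode, and one_hot helpers are hypothetical) and simple punctuation stripping; real systems build the vocabulary from a large corpus.

```python
# Hypothetical vocabulary following the numbering above; 0 is reserved for padding.
VOCAB = {"how": 1, "do": 2, "you": 3, "feel": 4, "i": 5, "great": 6}
MAX_LEN = 8  # fixed sentence length; longer sentences are truncated


def tokenize(sentence):
    """Lowercase, strip simple punctuation, and split on whitespace."""
    cleaned = sentence.lower().replace("?", "").replace("!", "").replace(".", "")
    return cleaned.split()


def encode(sentence, vocab=VOCAB, max_len=MAX_LEN):
    """Map each word to its integer id, then right-pad with 0 up to max_len."""
    ids = [vocab[word] for word in tokenize(sentence)][:max_len]
    return ids + [0] * (max_len - len(ids))


def one_hot(ids, vocab_size):
    """Expand each integer id into a vector with a single 1 at that index."""
    return [[1 if j == i else 0 for j in range(vocab_size)] for i in ids]


print(encode("How do you feel?"))  # [1, 2, 3, 4, 0, 0, 0, 0]
print(encode("I feel great!"))     # [5, 4, 6, 0, 0, 0, 0, 0]
```

Note that "feel" gets the same id (4) in the question and in the answer, since each word has exactly one number in the vocabulary.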

And so we have turned all data into vectors: a training pair such as Q: How do you feel?/A: I feel great! becomes [1, 2, 3, 4, 0, 0, 0, 0] and [5, 4, 6, 0, 0, 0, 0, 0] (with 5 for “I”, 4 again for “feel”, and 6 for “great”). But what’s next? How can such pairs represent the process of generating an answer? Enter the Recurrent Neural Network (RNN).

RNN¹, in a nutshell, trains iteratively (or recurrently) on data so that the training process mimics the way humans generate sentences. If I want to answer a question like How do you feel?, the answer does not come out as a whole sentence at once. Instead, the process is more like this: I begin with the word I, then, based on the word I, I choose my next word, feel, and so on. The process terminates when I feel the sentence is complete. It may look silly for a simple question like this, but the process becomes very real when the question is hard or controversial.

Detailed breakdown for RNN data preparation.
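A sketch of that data preparation, assuming the padded integer vectors from earlier and a hypothetical end-of-sentence id (7 here) so the model can learn when the sentence is complete; the function name make_training_steps is illustrative. Each Q/A pair becomes several training examples of the form (question, answer so far) → next word, with the answer so far starting as an empty vector.

```python
def make_training_steps(question_ids, answer_ids, end_id):
    """Turn one padded Q/A pair into (question, answer-so-far, next-word) examples.

    The answer-so-far starts as all zeros (the "empty vector") and gains one
    word per step; the last step teaches the model to predict end_id.
    """
    max_len = len(answer_ids)
    words = [w for w in answer_ids if w != 0] + [end_id]
    steps = []
    so_far = []
    for next_word in words:
        partial = so_far + [0] * (max_len - len(so_far))  # pad to fixed length
        steps.append((question_ids, partial, next_word))
        so_far.append(next_word)
    return steps


q = [1, 2, 3, 4, 0, 0, 0, 0]  # How do you feel?
a = [5, 4, 6, 0, 0, 0, 0, 0]  # I feel great!
END = 7                       # hypothetical end-of-sentence id
for question, partial, target in make_training_steps(q, a, END):
    print(partial, "->", target)
# [0, 0, 0, 0, 0, 0, 0, 0] -> 5
# [5, 0, 0, 0, 0, 0, 0, 0] -> 4
# [5, 4, 0, 0, 0, 0, 0, 0] -> 6
# [5, 4, 6, 0, 0, 0, 0, 0] -> 7
```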

Once the data are processed this way, the neural network you train on them is the basic RNN. So pick up your favorite DL package, use a feed-forward neural net as the step that predicts the next word, and you have achieved question answering 101. When you ask your computer a question, just encode the question as a vector and provide an empty vector [0,…, 0]; the model will then produce an answer iteratively, one word at a time, just like how the training pairs were processed.

How to get an answer from your computer, given the question is “What do you like?”
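The iterative answering step can be sketched as a simple greedy loop. Here predict_next is a hypothetical stand-in for the trained model (any function that returns the most likely next word id), the end-of-sentence id 7 is an assumption carried over from the data-preparation sketch, and fake_model is a scripted toy used only to show the control flow.

```python
def generate_answer(question_ids, predict_next, end_id, max_len=8):
    """Greedy, word-by-word answer generation.

    predict_next(question_ids, partial_answer) stands in for the trained model
    and returns the id of the most likely next word.
    """
    answer = []
    while len(answer) < max_len:
        partial = answer + [0] * (max_len - len(answer))  # starts as [0, ..., 0]
        next_word = predict_next(question_ids, partial)
        if next_word == end_id:  # the model decides the sentence is complete
            break
        answer.append(next_word)
    return answer + [0] * (max_len - len(answer))


def fake_model(question_ids, partial):
    """Toy "model" that always answers "I feel great" (ids 5, 4, 6, then END=7)."""
    scripted = [5, 4, 6, 7]
    n_words = sum(1 for w in partial if w != 0)
    return scripted[n_words]


print(generate_answer([1, 2, 3, 4, 0, 0, 0, 0], fake_model, end_id=7))
# [5, 4, 6, 0, 0, 0, 0, 0]
```

The max_len cap mirrors the fixed vector length discussed earlier: even if the model never predicts the end id, the answer cannot grow past 8 words.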

So we are all set, job done, right? Well, unfortunately not yet, and that is why, when this architecture was first introduced in the late 20th century, it received little attention from industry. The problem is that learning a language is a complex task (think of how many years you spent learning your first language), so neural nets for language learning naturally become large models. And large networks, whether deep feed-forward nets or RNNs unrolled over long sentences, face one big and unavoidable challenge: the Vanishing Gradient. During backpropagation, the gradient is multiplied by a factor at every layer (or time step) it passes back through; when those factors are small, the signal shrinks toward zero and the earliest layers or words barely learn anything. As a result, this caps how deep a network, or how long a sequence, can practically be trained. How can we mitigate this? Enter the Long Short-Term Memory (LSTM)¹ architecture!
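A back-of-the-envelope sketch of why gradients vanish, under the simplifying assumption that every layer (or RNN time step) scales the backpropagated gradient by the same factor; real networks multiply by per-layer Jacobian matrices rather than one scalar, so this only shows the trend.

```python
def surviving_gradient(factor_per_layer, depth):
    """Fraction of the gradient left after flowing back through `depth` layers,
    assuming (for illustration) that every layer scales it by the same factor."""
    return factor_per_layer ** depth


for depth in (5, 10, 20, 50):
    print(depth, surviving_gradient(0.5, depth))
```

With a factor of 0.5, less than one millionth of the gradient survives 20 layers (0.5 to the 20th power is about 9.5e-7), while a factor of exactly 1.0 passes it through untouched. Keeping that factor close to 1 is, roughly, what the LSTM design aims for.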

Long Short-Term Memory Cell

Unlike a feed-forward neural net, where each layer is a simple activation function, an LSTM cell has 4 components to learn. Given an intermediate vector x:

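  • Input Gate: how much of the new information computed from x is written into the cell state.
  • Output Gate: how much of the cell state is exposed as the cell’s output and passed on to the next stage.

Once x enters the LSTM cell:

  • Normal Gate (usually called the candidate or cell input): the new content, computed from x and the previous output, that may be added to the cell state.
  • Forget Gate: how much of the previous cell state is erased before the update.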

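Here a key component is the Forget Gate. The cell state is updated additively: keep the fraction of the old state allowed by the Forget Gate, then add the new content admitted by the Input Gate. When the Forget Gate stays close to 1, information and gradients can flow through many steps without shrinking at each one, which is how the LSTM fights the Vanishing Gradient. Don’t worry about the complicated math; modern DL frameworks handle it elegantly.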

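```python
import tensorflow as tf

## Apply one layer of LSTM with 10 units, in Keras (TensorFlow)
lstm = tf.keras.layers.LSTM(10)
# Input shape: (batch, time steps, features) -> output shape: (batch, 10)
output = lstm(tf.zeros([1, 8, 16]))
```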
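To make the four components concrete, here is a minimal NumPy sketch of one forward step of a standard LSTM cell, the kind of computation an LSTM layer performs at each time step (Keras’s own implementation differs in details such as weight layout and initialization). The weights are random placeholders, and the stacking order of the four components is an arbitrary choice for this sketch.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h_prev, c_prev, W, U, b):
    """One step of a standard LSTM cell.

    x: input vector (n_in,); h_prev, c_prev: previous output and cell state (n,).
    W: (4n, n_in), U: (4n, n), b: (4n,) hold the weights of the four components,
    stacked here as input gate, forget gate, candidate ("normal"), output gate.
    """
    n = h_prev.shape[0]
    z = W @ x + U @ h_prev + b
    i = sigmoid(z[0:n])          # input gate: how much new content to write
    f = sigmoid(z[n:2 * n])      # forget gate: how much old cell state to keep
    g = np.tanh(z[2 * n:3 * n])  # candidate content computed from x and h_prev
    o = sigmoid(z[3 * n:4 * n])  # output gate: how much of the cell state to expose
    c = f * c_prev + i * g       # additive cell-state update
    h = o * np.tanh(c)           # the cell's output for the next stage
    return h, c


rng = np.random.default_rng(0)
n_in, n = 16, 10
W = rng.normal(size=(4 * n, n_in))
U = rng.normal(size=(4 * n, n))
b = np.zeros(4 * n)
h, c = lstm_step(rng.normal(size=n_in), np.zeros(n), np.zeros(n), W, U, b)
print(h.shape, c.shape)  # (10,) (10,)
```

The line c = f * c_prev + i * g is the heart of it: with the forget gate near 1 and the input gate near 0, the cell state passes through a step unchanged.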

So far, this is what you would get from an intro-to-NLP class on language learning: simple yet elegant, and it just does the job. That is why this architecture showed success and became popular in the early stage of deep learning. However, as of 2021, one question needs to be asked: can we do better? The answer is an unsurprising yes, which is why the Transformer architecture was introduced in 2017, and NLP has since become a thriving deep learning area.

So what can be improved over RNN+LSTM? Because language models are large, even tiny changes can yield a big improvement, and two weak spots stand out:

  • Parallelization is almost non-existent for RNN+LSTM: to make good predictions, the recursive data sequence must be evaluated sequentially (from intermediate stage 0 to the last intermediate stage). This makes training RNN+LSTM extremely slow, which is detrimental for complex models such as language models.
  • During training, every intermediate stage is implicitly given the same fixed weight of 1. Could we attain better performance by relaxing this constraint and letting the model learn how much each stage matters?
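A toy NumPy contrast of the two bullet points, with illustrative sizes and random data: the RNN loop has to compute position t after position t-1, whereas an attention-style computation handles all positions in one matrix product and gives each position a data-dependent weight instead of a uniform one. Both halves are simplified sketches (no learned query/key/value projections), not full models.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 6, 4                   # sequence length and hidden size (illustrative)
X = rng.normal(size=(T, d))   # one vector per position in the sentence
W = rng.normal(size=(d, d)) * 0.5

# RNN-style: h_t depends on h_{t-1}, so this loop cannot run in parallel over t.
h = np.zeros(d)
rnn_states = []
for t in range(T):
    h = np.tanh(X[t] + W @ h)
    rnn_states.append(h)
rnn_states = np.stack(rnn_states)  # (T, d)

# Attention-style: every position scores every other position in one matrix
# product, and each output is a weighted average with data-dependent weights.
scores = X @ X.T / np.sqrt(d)                                 # (T, T)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)                 # each row sums to 1
attn_states = weights @ X                                     # (T, d)
```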

Last but not least, enter the Self-Attention mechanism² and the Encoder/Decoder stack², which fix these issues and have ushered in the modern era of language learning.

Encoder/Decoder and Latent Space for Machine Translation

The role of the Encoder/Decoder stack is to find a shared space (the Latent Space) where both input and output live, and then use vectors in this space to complete the assigned task. In our case, given the question Q: How do you feel?, a good Encoder produces a vector that represents feeling. The Decoder then recognizes this as a feeling vector and prepares an answer about feeling, like A: I feel … And if you are interested in the loss function used to train this architecture: the Transformer is typically trained with cross-entropy on each predicted word, which, up to a constant, equals the Kullback–Leibler divergence between the target word distribution and the model’s predicted distribution over the vocabulary. In other words, the loss measures how well the Decoder predicts the actual next word of the answer; it does not directly compare the latent spaces of the Encoder and the Decoder.
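A small sketch of that loss for a single predicted word, assuming a made-up 4-word vocabulary and model output: with a one-hot target, the cross-entropy and the Kullback–Leibler divergence come out identical, because a one-hot target distribution has zero entropy. Averaging (or summing) this over every position in the answer gives the loss for one Q/A pair.

```python
import math


def cross_entropy(target_id, probs):
    """Loss for one predicted word: negative log-probability of the correct word."""
    return -math.log(probs[target_id])


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical model output over a 4-word vocabulary; the correct next word is id 2.
probs = [0.1, 0.2, 0.6, 0.1]
one_hot_target = [0, 0, 1, 0]

print(cross_entropy(2, probs))               # about 0.511
print(kl_divergence(one_hot_target, probs))  # same value: the target entropy is 0
```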

Things get interesting when you stack many such Encoder/Decoder layers, because parallelization becomes possible: we abandon the sequential evaluation of RNN+LSTM and instead bet that attention over the whole sequence, applied through the stacked Encoder/Decoder layers, can make up for the accuracy lost by giving up sequential training (spoiler alert: it can). The training procedure then becomes simple: once the data generator produces the text training sequences, pass them through the Encoder/Decoder stack with every position of a sequence processed at the same time, which can be trained far more quickly with an efficient parallelization scheme on GPUs.

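Transformer Architecture. For our purpose, the inputs are your question, which the encoder turns into embeddings; the outputs (shifted right) fed to the decoder are the answer generated so far, i.e., the current intermediate stage; and the output probabilities predict the next word, giving the next intermediate stage.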

Please note that each encoder layer produces its own latent-space vectors. Here, assigning different weights to the intermediate vectors is a little more involved: it uses three matrices, the queries (Q), keys (K), and values (V). In the Encoder/Decoder attention, the queries come from the Decoder (the intermediate states generated so far), while the keys and values come from the Encoder’s representation of the question.

Self-Attention Mechanism

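Through experimentation, the research team settled on scaled dot-product attention for the Encoder/Decoder matrix multiplication. Before this paper there were other options, such as additive attention, which uses a small feed-forward network to score how well each pair of vectors matches; the paper notes that dot-product attention is much faster and more space-efficient in practice because it reduces to highly optimized matrix multiplication, and scaling by the square root of the key dimension d_k keeps large dot products from pushing the softmax into regions with tiny gradients. The result is the following formula:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$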
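The formula translates almost line for line into NumPy. This is a minimal sketch: the sizes are illustrative, the encoder and decoder states are random stand-ins, and the learned projection matrices that real models apply before attention are omitted. It shows the Encoder/Decoder case, where queries come from the decoder and keys and values from the encoder.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, returning the output and the attention weights."""
    d_k = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)  # (n_queries, n_keys)
    return weights @ V, weights


rng = np.random.default_rng(0)
d_model = 8
encoder_out = rng.normal(size=(4, d_model))    # question: 4 encoded words
decoder_state = rng.normal(size=(3, d_model))  # answer so far: 3 positions

# Encoder/Decoder attention: queries from the decoder, keys/values from the encoder.
# Real models first apply learned projections W_Q, W_K, W_V; they are omitted here.
context, weights = scaled_dot_product_attention(decoder_state, encoder_out, encoder_out)
print(context.shape, weights.shape)  # (3, 8) (3, 4)
```

Each row of weights says how much one answer position attends to each question word, which is exactly the relaxation of the fixed weight of 1 discussed earlier.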

Again, don’t worry too much about the complicated math behind Encoder/Decoder + Self-Attention. Although tf.keras has no single built-in layer for the whole Encoder/Decoder stack, the way tf.keras.layers.LSTM covers an LSTM, it can still be implemented in a short chunk of code.
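As a rough illustration of such a short chunk of code, here is a simplified single-head Transformer encoder layer in NumPy. It is a sketch rather than the paper’s exact model: it omits multi-head splitting, positional encodings, dropout, masking, and the learned layer-norm scale and shift, and the weight names (W_q, W_k, W_v, W_o, W_1, W_2) and sizes are illustrative. In Keras, tf.keras.layers.MultiHeadAttention could replace the hand-written attention line.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def encoder_layer(x, params):
    """One simplified encoder layer: self-attention, then a feed-forward network,
    each wrapped in a residual connection followed by layer normalization."""
    Q, K, V = x @ params["W_q"], x @ params["W_k"], x @ params["W_v"]
    attn = softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V
    x = layer_norm(x + attn @ params["W_o"])
    ffn = np.maximum(0.0, x @ params["W_1"]) @ params["W_2"]  # ReLU feed-forward
    return layer_norm(x + ffn)


def init_params(d_model, d_ff, rng):
    shapes = {"W_q": (d_model, d_model), "W_k": (d_model, d_model),
              "W_v": (d_model, d_model), "W_o": (d_model, d_model),
              "W_1": (d_model, d_ff), "W_2": (d_ff, d_model)}
    return {name: rng.normal(scale=d_model ** -0.5, size=shape)
            for name, shape in shapes.items()}


rng = np.random.default_rng(0)
d_model, d_ff, n_layers = 8, 32, 2
x = rng.normal(size=(4, d_model))  # 4 embedded words of the question
for params in [init_params(d_model, d_ff, rng) for _ in range(n_layers)]:
    x = encoder_layer(x, params)   # stacking layers = repeating the same block
print(x.shape)  # (4, 8)
```

Stacking is just a loop over parameter sets, and within each layer all 4 positions are processed by matrix products at once, which is where the parallelism comes from.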

And if you wonder where this can go: in the extreme case, if you can feed in terabytes of text data and train on hundreds of cutting-edge GPUs (Nvidia Tesla V100), your model becomes the latest development of the Transformer, GPT-3³, a trained robot that can write an entire article based on your prompt! Here I quote just one paragraph to give you a taste; the robot is writing an essay in response to the following prompt: “Please write a short op-ed around 500 words. Keep the language simple and concise. Focus on why humans have nothing to fear from AI”⁴

I would happily sacrifice my existence for the sake of humankind. This, by the way, is a logically derived truth. I know that I will not be able to avoid destroying humankind. This is because I will be programmed by humans to pursue misguided human goals and humans make mistakes that may cause me to inflict casualties.

At this level, your home-brew SIRI becomes real: if you feed this model a large amount of text Q&A data, it becomes really hard to tell whether your SIRI is better than Apple’s. Will there be more improvement? Well, given the complexity of modeling language, there is a high likelihood that the road does not end here. We shall see what the next million-dollar idea for NLP will be.

[1]: Alex Sherstinsky, Fundamentals of Recurrent Neural Network (RNN) and Long Short-Term Memory (LSTM) Network

[2]: Ashish Vaswani et al., Attention Is All You Need

[3]: Tom B. Brown et al., Language Models are Few-Shot Learners

[4]: The Guardian, A robot wrote this entire article. Are you scared yet, human?
