An elementary example of soft attention
If you, like me, are fascinated by the new developments in soft/hard attention but struggle to understand the exact mechanism, this post explains exactly that.
In this post we implement basic soft attention with a softmax, on a dataset derived from MNIST. Soft attention is a complicated subject. Fortunately, MNIST digits are easy to understand and make a good introduction.
Soft attention originates from the captioning problem. Image captioning concerns generating captions for images. Typical examples are shown below.
In such a model, a CNN converts an image to a lower-dimensional representation. This representation is fed into another network, typically an LSTM, which generates natural language. In this pipeline, the LSTM receives the representation of the image only once. We humans tend to look at an image twice or more while we describe its content. Soft attention models exactly that. At every time step, the LSTM outputs natural language to caption the image. With soft attention, it also outputs a key, with which it attends to a part of the image. This information is fed in as input at the next time step.
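To make that loop concrete, here is a minimal NumPy sketch of one attention step per time step. All shapes, weight matrices and the state update are made up for illustration; real captioning models (and the code in the repo) are more involved.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Made-up shapes: 196 image locations (14x14), each a 12-dim feature vector.
features = np.random.randn(196, 12)   # latent representation from the CNN
hidden   = np.zeros(64)               # stand-in for the LSTM hidden state
W_key    = np.random.randn(64, 12)    # maps the hidden state to a key
W_ctx    = np.random.randn(12, 64)    # folds the attended context back into the state

for t in range(10):                   # ten time steps
    key     = hidden @ W_key          # the LSTM "outputs a key"
    scores  = features @ key          # one score per image location
    mask    = softmax(scores)         # soft attention mask over the 196 locations
    context = mask @ features         # weighted sum: the part of the image attended to
    # The context is fed in as input at the next time step. A real model would also
    # emit a word here; this placeholder update only shows the flow of information.
    hidden  = np.tanh(context @ W_ctx + hidden)
```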
It might be hard to imagine how an LSTM takes a latent representation of an image and generates natural language. Therefore, this elementary model simply strips that part. Engineers typically struggle with how to feed in the representation: is it an input, a hidden state, a memory state ...? We will not bother with such complicated questions.
The pipeline goes as follows. MNIST digits are handwritten digits, 0, 1, 2 up to 9. Typical convolutional neural networks manage to classify these above human accuracy. The activations of the first layer of such a network correspond to an activation volume. After a convolutional layer and a max-pooling layer, the activation maps span 14 by 14 pixels. A layer typically consists of many neurons, in our case 12. The activation volume thus spans a tensor of shape [14, 14, 12].
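The following snippet only illustrates where those shapes come from. It uses random filters and SciPy's correlate2d, not the actual trained network from the repo.

```python
import numpy as np
from scipy.signal import correlate2d

image   = np.random.rand(28, 28)      # stand-in for a 28x28 MNIST digit
filters = np.random.randn(12, 3, 3)   # 12 random 3x3 filters (untrained)

# A 'same' convolution keeps the 28x28 spatial size; one activation map per filter.
maps = np.stack([correlate2d(image, f, mode='same') for f in filters], axis=-1)
maps = np.maximum(maps, 0.0)          # ReLU
print(maps.shape)                     # (28, 28, 12)

# 2x2 max-pooling halves both spatial dimensions: 28 -> 14.
volume = maps.reshape(14, 2, 14, 2, 12).max(axis=(1, 3))
print(volume.shape)                   # (14, 14, 12): the activation volume
```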
The LSTM outputs a key to attend to this activation volume. Researchers have been using three techniques to relate the key and the volume.
The GitHub repo demonstrates all three approaches, in attention_main_sm.py, attention_main_gaussian.py and attention_main.py respectively.
For the final option, with the feature-like vector, the next section displays results.
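My reading of the feature-like-vector option, as a NumPy sketch: the key has the same length as a feature vector (12 here), it is dotted with the feature vector at every spatial location, and a softmax over the 14x14 positions turns the scores into an attention mask. The names and shapes below are assumptions for illustration, not the repo's code.

```python
import numpy as np

volume = np.random.randn(14, 14, 12)   # activation volume from the conv layer
key    = np.random.randn(12)           # feature-like key emitted by the LSTM

scores = np.tensordot(volume, key, axes=([2], [0]))  # dot product at every location -> (14, 14)
scores = scores.flatten()
mask   = np.exp(scores - scores.max())
mask   = (mask / mask.sum()).reshape(14, 14)         # attention mask over the 14x14 grid

# The attended context: weight each location's feature vector by its mask value and sum.
context = np.tensordot(mask, volume, axes=([0, 1], [0, 1]))  # shape (12,)
```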
Some details on the training: the activations stem from the first conv layer of a network trained to classify MNIST. The activation maps are [14, 14, 12]. At the moment of extracting these images, the LSTM classified 83% correctly and was still improving. I'm working on a five-year-old laptop, so training until better performance takes long.
This visualization shows the attention mask at each of the ten time steps. The final two plots display the sum over the activation volume and the original image.
Note that the colorbar applies only to the first ten images; the final two images are auto-scaled. Red indicates high attention in the mask, blue indicates low attention.
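For reference, a matplotlib sketch of such a figure: ten mask panels sharing one colorbar, plus two auto-scaled panels for the summed activation volume and the original image. The variable names and stand-in data are hypothetical, and the details differ from the repo's plotting script.

```python
import numpy as np
import matplotlib.pyplot as plt

masks  = np.random.rand(10, 14, 14)   # attention masks for the ten time steps (stand-ins)
volume = np.random.rand(14, 14, 12)   # activation volume (stand-in)
image  = np.random.rand(28, 28)       # original MNIST digit (stand-in)

fig, axes = plt.subplots(3, 4, figsize=(10, 8))
vmin, vmax = masks.min(), masks.max()  # shared scale for the first ten panels only

for t in range(10):
    ax = axes.flat[t]
    im = ax.imshow(masks[t], cmap='jet', vmin=vmin, vmax=vmax)  # red = high, blue = low
    ax.set_title('t = %d' % t)
    ax.axis('off')

axes.flat[10].imshow(volume.sum(axis=2), cmap='jet')  # auto-scaled sum over the volume
axes.flat[10].set_title('sum of volume'); axes.flat[10].axis('off')
axes.flat[11].imshow(image, cmap='gray')              # auto-scaled original digit
axes.flat[11].set_title('original'); axes.flat[11].axis('off')

fig.colorbar(im, ax=axes.ravel().tolist())  # colorbar for the mask panels
plt.show()
```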
As this is an elementary example, there is plenty of room for improvement. Some of the possible improvements are:
As always, I am curious about any comments and questions. Reach me at romijndersrob@gmail.com