The input encoder runs on the input tokens, creates their embeddings, and feeds them to a stack of LSTM layers. The LSTM outputs are the activations that serve as the keys and values for attention.
from trax import layers as tl


def input_encoder_fn(input_vocab_size, d_model, n_encoder_layers):
    """ Input encoder runs on the input sentence and creates
    activations that will be the keys and values for attention.

    Args:
        input_vocab_size: int: vocab size of the input
        d_model: int: depth of embedding (n_units in the LSTM cell)
        n_encoder_layers: int: number of LSTM layers in the encoder
    Returns:
        tl.Serial: The input encoder
    """
    # create a serial network
    input_encoder = tl.Serial(
        ### START CODE HERE ###
        # create an embedding layer to convert tokens to vectors of depth d_model
        tl.Embedding(input_vocab_size, d_model),
        # feed the embeddings to the LSTM layers: a stack of n_encoder_layers LSTM layers,
        # each with d_model units (tl.Serial flattens the nested list)
        [tl.LSTM(d_model) for _ in range(n_encoder_layers)]
        ### END CODE HERE ###
    )
    return input_encoder
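To make the shapes concrete, here is a plain-NumPy sketch of the same computation, assuming a standard LSTM cell with randomly initialized weights. The helper names (`lstm_layer`, `input_encoder_sketch`), the gate ordering, and the initialization are illustrative assumptions, not Trax's internals. For a batch of token IDs of shape (batch, length), the output has shape (batch, length, d_model); in the full model, these activations are used as both keys and values.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_layer(x, w, b):
    """Run one LSTM layer (hypothetical helper) over a batch of sequences.

    x: (batch, length, d_in); w: (d_in + d_model, 4 * d_model); b: (4 * d_model,)
    Returns the hidden states, shape (batch, length, d_model).
    """
    batch, length, _ = x.shape
    d_model = b.shape[0] // 4
    h = np.zeros((batch, d_model))
    c = np.zeros((batch, d_model))
    outputs = []
    for t in range(length):
        # One affine map produces the pre-activations of all four gates at once.
        z = np.concatenate([x[:, t, :], h], axis=-1) @ w + b
        i, f, g, o = np.split(z, 4, axis=-1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # update the cell state
        h = sigmoid(o) * np.tanh(c)                   # new hidden state
        outputs.append(h)
    return np.stack(outputs, axis=1)


def input_encoder_sketch(tokens, input_vocab_size, d_model, n_encoder_layers, seed=0):
    """Embedding lookup followed by n_encoder_layers stacked LSTM layers (assumed random weights)."""
    rng = np.random.default_rng(seed)
    # Embedding table: one d_model-dimensional row per vocabulary entry.
    embedding = rng.normal(scale=0.1, size=(input_vocab_size, d_model))
    x = embedding[tokens]  # (batch, length, d_model)
    for _ in range(n_encoder_layers):
        # Input width is d_model (embedding or previous LSTM), plus d_model for h.
        w = rng.normal(scale=0.1, size=(2 * d_model, 4 * d_model))
        b = np.zeros(4 * d_model)
        x = lstm_layer(x, w, b)
    return x


tokens = np.array([[5, 2, 9, 0], [7, 7, 1, 3]])  # (batch=2, length=4)
activations = input_encoder_sketch(tokens, input_vocab_size=10, d_model=8, n_encoder_layers=2)
print(activations.shape)  # (2, 4, 8)
```

Because the embedding depth equals d_model and every LSTM layer has d_model units, each layer's input and output widths match, which is what allows the layers to be stacked with a simple list comprehension.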
The pre-attention decoder runs on the targets and creates activations that are used as queries in attention. It is a Serial network composed of the following layers:
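tl.ShiftRight: This prepends a padding token (ID 0) to the target tokens and drops the last one, so the sequence length is unchanged (e.g. [8, 34, 12] shifted right is [0, 8, 34]). The leading 0 acts like a start-of-sentence token and is the first input to the decoder. During training, this shift also lets the target tokens themselves be passed as decoder input (teacher forcing): at each position, the decoder sees the previous target token and must predict the current one.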
tl.Embedding: As in the input encoder, this converts each token to its vector representation. Its weight matrix has size vocab_size by d_model, created with tl.Embedding(vocab_size, d_model), where vocab_size is the number of entries in the given vocabulary and d_model is the number of elements in each word embedding.
tl.LSTM: LSTM layer of size d_model.
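A short NumPy sketch of the first two layers, assuming a batched integer array of target IDs with shape (batch, length). `shift_right` is a hypothetical helper that mirrors the training-mode behavior described for tl.ShiftRight; the random table stands in for the learned tl.Embedding weights, and the final LSTM is omitted.

```python
import numpy as np


def shift_right(targets, n_positions=1):
    """Prepend n_positions zeros along the length axis and drop the last
    n_positions tokens, so the output keeps the input shape (batch, length)."""
    padded = np.pad(targets, ((0, 0), (n_positions, 0)), mode="constant")
    return padded[:, :-n_positions]


targets = np.array([[8, 34, 12]])
shifted = shift_right(targets)
print(shifted.tolist())  # [[0, 8, 34]]

# Embedding lookup: a (vocab_size, d_model) table indexed by token ID.
vocab_size, d_model = 40, 4
rng = np.random.default_rng(0)
embedding = rng.normal(size=(vocab_size, d_model))
vectors = embedding[shifted]
print(vectors.shape)  # (1, 3, 4)
```

Comparing `shifted` with `targets` shows teacher forcing at work: the decoder input at position t is the target token from position t - 1.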