Expected:

Serial[
  Serial[
    Serial[
      ShiftRight(1)
    ]
    Embedding_33000_512
    Dropout
    Serial[
      PositionalEncoding
    ]
    Dup_out2
    ReversibleSerial_in2_out2[
      ReversibleHalfResidualDecoderAttn_in2_out2[
        Serial[
          LayerNorm
        ]
        SelfAttention
      ]
      ReversibleSwap_in2_out2
      ReversibleHalfResidualDecoderFF_in2_out2[
        Serial[
          LayerNorm
          Dense_2048
          Dropout
          Serial[
            FastGelu
          ]
          Dense_512
          Dropout
        ]
      ]
      ReversibleSwap_in2_out2
      ReversibleHalfResidualDecoderAttn_in2_out2[
        Serial[
          LayerNorm
        ]
        SelfAttention
      ]
      ReversibleSwap_in2_out2
      ReversibleHalfResidualDecoderFF_in2_out2[
        Serial[
          LayerNorm
          Dense_2048
          Dropout
          Serial[
            FastGelu
          ]
          Dense_512
          Dropout
        ]
      ]
      ReversibleSwap_in2_out2
    ]
    Concatenate_in2
    LayerNorm
    Dropout
    Serial[
      Dense_33000
    ]
  ]
  LogSoftmax
]
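This is the nested repr Trax prints for a layer. For orientation only, here is a minimal sketch of one way to obtain a comparable two-layer reversible language model; it assumes the stock trax.models.reformer.ReformerLM constructor, and every hyperparameter below is an assumption read off the layer names in the printout, not something the assignment specifies:

import trax

# Hyperparameters inferred from the expected output above:
#   Embedding_33000_512 / Dense_33000 -> vocab_size=33000, d_model=512
#   Dense_2048 -> d_ff=2048
#   two (attention, feed-forward) reversible pairs -> n_layers=2
model = trax.models.reformer.ReformerLM(
    vocab_size=33000,
    d_model=512,
    d_ff=2048,
    n_layers=2,
    mode='train',
)
print(model)  # printing a Trax layer produces a nested repr like the one above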
Instructions: Implement the training_loop below to train the neural network above. Here is a list of things you should do:

- Create a TrainTask and an EvalTask.
- Create the training loop by calling trax.supervised.training.Loop.
- Pass in the following parameters to the TrainTask:
  - labeled_data=train_gen
  - loss_layer=tl.CrossEntropyLoss()
  - optimizer=trax.optimizers.Adam(0.01)
  - lr_schedule=lr_schedule
  - n_steps_per_checkpoint=10

You will be using your CrossEntropyLoss loss function with the Adam optimizer. Please read the trax documentation to get a full understanding. A sketch of this TrainTask is shown below.
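As a hedged sketch of that step, assuming trax and its layers module are imported as below, and that train_gen and lr_schedule are already defined elsewhere in the assignment:

import trax
from trax import layers as tl
from trax.supervised import training

train_task = training.TrainTask(
    labeled_data=train_gen,                # training batch generator
    loss_layer=tl.CrossEntropyLoss(),      # your CrossEntropyLoss loss function
    optimizer=trax.optimizers.Adam(0.01),  # Adam optimizer, learning rate 0.01
    lr_schedule=lr_schedule,               # learning-rate schedule defined earlier
    n_steps_per_checkpoint=10,             # save a checkpoint every 10 steps
)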
- Pass in the following parameters to the EvalTask:
  - labeled_data=eval_gen
  - metrics=[tl.CrossEntropyLoss(), tl.Accuracy()]

A sketch completing the loop follows this list.
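Continuing the sketch above, the EvalTask and the Loop that ties everything together; here model stands for the reversible LM whose repr is printed above, train_task comes from the previous sketch, and the output directory and the run call are illustrative assumptions rather than values given in the assignment:

eval_task = training.EvalTask(
    labeled_data=eval_gen,                           # evaluation batch generator
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],  # report loss and accuracy
)

# Wire the model and both tasks into a training loop; checkpoints are written
# to output_dir every n_steps_per_checkpoint steps.
loop = training.Loop(model,
                     train_task,
                     eval_tasks=[eval_task],
                     output_dir='./model/')  # assumed output directory

loop.run(10)  # train for 10 steps, checkpointing and evaluating along the way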