ELECTRA - Pre-training Text Encoders as Discriminators Rather Than Generators
20 Feb 2020
Introduction
Masked Language Modeling (MLM) is a common technique for pre-training language-based models. The idea is to “corrupt” some tokens in the input text (around 15%) by replacing them with the [MASK] token and then training the network to reconstruct (or predict) the corrupted tokens.
Since the network learns from only about 15% of the tokens in each example, reaching good performance with MLM can require a large amount of compute.
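The masking step described above can be sketched in a few lines of Python. This is a minimal illustration, not the procedure of any particular library: the function name `mask_tokens`, the fixed seed, and the example sentence are assumptions made for the sketch. It only shows that the training signal comes from the small set of masked positions.

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, mask_prob=0.15, seed=0):
    """Return (corrupted tokens, sorted list of masked positions)."""
    rng = random.Random(seed)
    # Mask a fixed fraction of the positions (at least one).
    k = max(1, round(mask_prob * len(tokens)))
    positions = sorted(rng.sample(range(len(tokens)), k))
    corrupted = list(tokens)
    for i in positions:
        corrupted[i] = MASK
    return corrupted, positions

# Hypothetical 20-token example sentence.
tokens = ("the chef cooked the meal and served it to the guests "
          "who had waited for over an hour that evening").split()
corrupted, positions = mask_tokens(tokens)

# Only these positions would contribute to the MLM loss.
print(len(positions), "of", len(tokens), "tokens carry a training signal")
```

With 20 tokens and a 15% masking rate, only 3 positions are predicted; the remaining 17 tokens serve as context but produce no loss.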
The paper proposes to use a “replaced token detection” task where some tokens in the input text are replaced by other plausible tokens.
For each token in the modified text, the network has to predict if the token has been replaced or not.
The alternative token is generated using a small generator network.
Unlike the previous MLM setup, the proposed task is defined for all the input tokens, thus utilizing the training data more efficiently.
The proposed approach is called ELECTRA (Efficiently Learning an Encoder that Classifies Token Replacements Accurately).
Two neural networks, a Generator (G) and a Discriminator (D), are trained.
Each network has a Transformer-based text encoder that maps a sequence of input tokens into a sequence of contextualized vectors.
Given an input sequence x (of length N), k indices are chosen for replacing the tokens.
For each chosen index, the generator produces a distribution over tokens, and a token sampled from that distribution replaces the original token at that index. The resulting sequence is referred to as the corrupted sequence.
Given the corrupted sequence, the Discriminator predicts, for every token, whether it comes from the data distribution or from the generator.
The generator is trained using the MLM setup, and the Discriminator is trained using the discriminative loss.
After pre-training, only the Discriminator is finetuned on the downstream tasks.
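The corruption and detection steps above can be sketched as follows. This is a toy illustration under stated assumptions, not the paper's implementation: `toy_generator` is a hypothetical stand-in for the small MLM generator, the discriminator outputs in `probs` are made up, and the rule that a sampled token identical to the original counts as "original" is an assumption of this sketch.

```python
import math
import random

MASK = "[MASK]"

def corrupt(tokens, positions, generator, rng):
    """Replace tokens at `positions` with samples from the generator."""
    # The generator sees the masked input, as in the MLM setup.
    masked = [MASK if i in positions else t for i, t in enumerate(tokens)]
    corrupted = list(tokens)
    for i in positions:
        dist = generator(masked, i)  # dict: candidate token -> probability
        candidates = list(dist)
        weights = [dist[c] for c in candidates]
        corrupted[i] = rng.choices(candidates, weights=weights, k=1)[0]
    # Label 1 = replaced, 0 = original. A sample equal to the original
    # token is labeled as original (assumption of this sketch).
    labels = [int(c != t) for c, t in zip(corrupted, tokens)]
    return corrupted, labels

def discriminator_loss(probs_replaced, labels, eps=1e-12):
    """Mean binary cross-entropy over ALL positions, not only the replaced ones."""
    total = 0.0
    for p, y in zip(probs_replaced, labels):
        total -= y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
    return total / len(labels)

def toy_generator(masked_tokens, i):
    # Hypothetical fixed distribution; a real generator is a small Transformer.
    return {"ate": 0.6, "cooked": 0.3, "burned": 0.1}

rng = random.Random(0)
tokens = ["the", "chef", "cooked", "the", "meal"]
positions = [2]
corrupted, labels = corrupt(tokens, positions, toy_generator, rng)

# Made-up discriminator outputs: probability that each token was replaced.
probs = [0.1, 0.1, 0.8, 0.1, 0.1]
loss = discriminator_loss(probs, labels)
print(corrupted, labels, round(loss, 3))
```

Every position in the sequence receives a label, so every token contributes to the discriminator loss, which is the efficiency argument made in the introduction.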
GLUE Benchmark
Stanford Question Answering Dataset (SQuAD)
Architecture Choices
Sharing word embeddings between the generator and the Discriminator helps.
Tying all the encoder weights leads to a marginal improvement but forces the generator and the Discriminator to be of the same size. Hence only the embeddings are shared.
The generator is kept smaller than the Discriminator, as a strong generator can make training difficult for the Discriminator.
A two-stage training procedure was explored where only the generator is trained for n steps. Then the weights of the generator are used to initialize the Discriminator. The Discriminator is then trained for n steps while keeping the generator fixed.
This two-stage setup provides a nice curriculum for the Discriminator but does not outperform the joint-training setup.
A setup based on an adversarial loss is also explored, but it does not work well, probably for the following reasons:
The adversarially trained generator is not as good as the MLM generator.
The adversarially trained generator produces a low-entropy output distribution.
ELECTRA-15 is a variant of ELECTRA where the Discriminator is trained on only 15% of the tokens (similar to the MLM setup). This reduces performance significantly.
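The difference between the full objective and the ELECTRA-15 variant can be illustrated by restricting the discriminator loss to a subset of positions. The sketch below is an assumption-laden toy example: the function names, the probabilities, and the labels are all made up for illustration.

```python
import math

def bce(p, y, eps=1e-12):
    """Binary cross-entropy for one prediction p and one label y."""
    return -(y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps))

def discriminator_loss(probs_replaced, labels, positions=None):
    """Mean loss over `positions`; over all positions when `positions` is None."""
    if positions is None:
        positions = range(len(labels))
    positions = list(positions)
    return sum(bce(probs_replaced[i], labels[i]) for i in positions) / len(positions)

# Made-up discriminator outputs and labels for a 7-token input.
probs = [0.1, 0.2, 0.7, 0.1, 0.3, 0.1, 0.6]
labels = [0, 0, 1, 0, 0, 0, 0]
selected = [2]  # the positions chosen for replacement (about 15% of 7)

full_loss = discriminator_loss(probs, labels)              # ELECTRA: all tokens
subset_loss = discriminator_loss(probs, labels, selected)  # ELECTRA-15: chosen tokens only
print(round(full_loss, 3), round(subset_loss, 3))
```

In the subset version, the confident mistake at the last position (probability 0.6 for an original token) never reaches the loss, which illustrates the training signal that ELECTRA-15 discards.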
Replace MLM setup
Perform MLM training, but instead of using [MASK], use a token sampled from the generator.
This improves the performance marginally.
All-token MLM
In the MLM setup, replace the [MASK] tokens with tokens sampled from the generator and train the MLM model to predict every token in the input, not only the replaced ones.
In practice, the MLM model can either generate a word or copy the existing word.
This approach closes much of the gap between BERT and ELECTRA.
Interestingly, ELECTRA outperforms All-token MLM BERT, suggesting that ELECTRA may benefit from parameter efficiency, since it does not have to learn a distribution over all the words.
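One way to picture the "generate or copy" behaviour of All-token MLM is as a mixture of two distributions at each position. The sketch below assumes a scalar copy probability that mixes a one-hot distribution on the input token with a softmax over the vocabulary; the function names, the three-word vocabulary, and the numbers are hypothetical and chosen only for illustration.

```python
import math

def softmax(logits):
    """Softmax over a dict mapping token -> logit."""
    m = max(logits.values())
    exps = {t: math.exp(v - m) for t, v in logits.items()}
    z = sum(exps.values())
    return {t: e / z for t, e in exps.items()}

def all_token_mlm_distribution(input_token, logits, copy_prob):
    """Mix 'copy the input token' with 'generate a token from the vocabulary'."""
    generate = softmax(logits)
    out = {t: (1.0 - copy_prob) * p for t, p in generate.items()}
    out[input_token] = out.get(input_token, 0.0) + copy_prob
    return out

# Hypothetical logits over a tiny vocabulary for one position.
logits = {"cooked": 2.0, "ate": 1.0, "burned": 0.0}
dist = all_token_mlm_distribution("ate", logits, copy_prob=0.9)
print(max(dist, key=dist.get))
```

Even with a copy mechanism, the model still has to produce a full distribution over the vocabulary at every position, whereas the ELECTRA Discriminator only outputs one binary decision per token.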