Two-Stage Synthesis Networks for Transfer Learning in Machine Comprehension
28 Nov 2017

Introduction
- The paper proposes a two-stage synthesis network (SynNet) that performs transfer learning for the task of machine comprehension.
- The problem is the following: we have a domain DS with a labelled dataset of question-answer pairs and another domain DT for which we have no labelled data.
- We use the data from domain DS to train SynNet and then use it to generate synthetic question-answer pairs for domain DT.
- Now we can train a machine comprehension model M on DS and fine-tune it using the synthetic data for DT.
- Link to the paper
SynNet
- Works in two stages (a sketch of how the two stages compose follows this list):
  - Answer Synthesis - Given a text paragraph, generate an answer.
  - Question Synthesis - Given a text paragraph and an answer, generate a question.
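To make the two-stage decomposition concrete, here is a minimal sketch of how the stages compose to produce synthetic data for DT. The function names (`answer_synthesis`, `question_synthesis`) are placeholders for the networks described in the next two sections, not the paper's API.

```python
# Hypothetical composition of the two SynNet stages (names are placeholders).
def generate_synthetic_qa(paragraph_tokens, answer_synthesis, question_synthesis):
    """Produce synthetic (question, answer) pairs for an unlabelled paragraph from D_T."""
    qa_pairs = []
    # Stage 1: tag the paragraph and extract candidate answer spans.
    for answer_span in answer_synthesis(paragraph_tokens):
        # Stage 2: generate a question conditioned on the paragraph and the candidate answer.
        question = question_synthesis(paragraph_tokens, answer_span)
        qa_pairs.append((question, answer_span))
    return qa_pairs
```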
Answer Synthesis Network
- Given the labelled dataset for DS, generate a labelled dataset of <word, tag> pairs such that each word in the given paragraph is assigned one of 4 tags:
  - IOBstart - if it is the starting word of an answer
  - IOBmid - if it is an intermediate word of an answer
  - IOBend - if it is the ending word of an answer
  - IOBnone - if it is not part of any answer
- For training, map the words to their GloVe embeddings and pass them through a Bi-LSTM, followed by two FC layers and a softmax layer over the tags (see the sketch after this list).
- For the target domain DT, all consecutive word spans in which no word is tagged IOBnone are returned as candidate answers.
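Below is a minimal PyTorch-style sketch of the answer-synthesis tagger described above (frozen GloVe embeddings fed to a Bi-LSTM, then two FC layers producing tag logits), together with candidate-span extraction for DT. The hidden size, the ReLU between the FC layers, the frozen embeddings, and the tag id used for IOBnone are illustrative assumptions, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

class AnswerTagger(nn.Module):
    """Tags each paragraph word with one of {IOBstart, IOBmid, IOBend, IOBnone}."""

    def __init__(self, glove_weights, hidden_size=100, num_tags=4):
        super().__init__()
        # Words are mapped to (frozen) GloVe embeddings.
        self.embed = nn.Embedding.from_pretrained(glove_weights, freeze=True)
        self.bilstm = nn.LSTM(glove_weights.size(1), hidden_size,
                              batch_first=True, bidirectional=True)
        # Two FC layers; the softmax is applied implicitly by the training loss.
        self.fc = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_tags),
        )

    def forward(self, word_ids):          # word_ids: (batch, seq_len)
        h, _ = self.bilstm(self.embed(word_ids))
        return self.fc(h)                  # (batch, seq_len, num_tags) tag logits


def candidate_spans(tags, iob_none=3):
    """Return consecutive word spans in which no tag is IOBnone (tag id assumed)."""
    spans, start = [], None
    for i, t in enumerate(list(tags) + [iob_none]):   # sentinel IOBnone at the end
        if t != iob_none and start is None:
            start = i
        elif t == iob_none and start is not None:
            spans.append((start, i - 1))
            start = None
    return spans
```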
Question Synthesis Network
- Given an input paragraph and a candidate answer, the Question Synthesis network generates a question one word at a time.
- Map each word in the paragraph to its GloVe embedding. Append a '1' to the word vector if the word is part of the candidate answer, else append a '0'.
- Feed these vectors to a Bi-LSTM encoder-decoder network, where the decoder conditions on the representation produced by the encoder as well as on the question tokens generated so far. Decoding stops when the "END" token is produced.
- The paragraph may contain named entities or rare words that do not appear in the softmax vocabulary. To account for such words, a copying mechanism is also incorporated.
- At each time step, a Pointer Network (CP) and a Vocabulary Predictor (VP) each produce a probability distribution over the next word, and a Latent Predictor Network decides which of the two is used for the prediction.
- At inference time, greedy decoding is used: the most likely predictor is chosen, and then the most likely word under that predictor is emitted (see the sketch below).
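The following is a hedged sketch of one greedy decoding step as described above: the Latent Predictor Network chooses between the Vocabulary Predictor and the Pointer Network, and the most likely word under the chosen predictor is emitted. The module interfaces (`latent_predictor`, `vocab_predictor`, `pointer_network`) and the `id_to_word` mapping are placeholders, not the paper's code.

```python
def greedy_decode_step(dec_state, enc_outputs, paragraph_tokens,
                       latent_predictor, vocab_predictor, pointer_network, id_to_word):
    """One greedy decoding step with the copy mechanism (all modules are illustrative)."""
    # Latent Predictor Network: probability of using the Vocabulary Predictor (VP)
    # versus the Pointer Network (CP) at this time step.
    p_vp, p_cp = latent_predictor(dec_state, enc_outputs)

    if float(p_vp) >= float(p_cp):
        # VP: pick the most likely word from the (limited) softmax vocabulary.
        word_id = int(vocab_predictor(dec_state).argmax())
        return id_to_word[word_id]
    else:
        # CP: pick the most likely paragraph position and copy that word verbatim,
        # which handles named entities / rare words outside the softmax vocabulary.
        position = int(pointer_network(dec_state, enc_outputs).argmax())
        return paragraph_tokens[position]
```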
Machine Comprehension Model
- Given any MC model, first train it on domain DS and then fine-tune it using the synthetic questions generated for DT.
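A minimal sketch of this transfer recipe, assuming a hypothetical `mc_model` with a `loss(batch)` method and pre-built lists of mini-batches: pretrain on DS, then fine-tune while alternating source and synthetic target mini-batches (the data regularization trick noted under Implementation Details below).

```python
from itertools import cycle

def transfer_mc_model(mc_model, source_batches, synthetic_target_batches, optimizer,
                      pretrain_epochs=10, finetune_steps=10000):
    """Pretrain on D_S, then fine-tune on synthetic D_T data, alternating domains."""
    # Stage 1: standard supervised training on the labelled source domain D_S.
    for _ in range(pretrain_epochs):
        for batch in source_batches:
            optimizer.zero_grad()
            mc_model.loss(batch).backward()
            optimizer.step()

    # Stage 2: fine-tune, alternating mini-batches from D_S and synthetic D_T
    # so the model does not drift away from the source-domain supervision.
    mixed = zip(cycle(source_batches), cycle(synthetic_target_batches))
    for step, (src_batch, tgt_batch) in enumerate(mixed):
        if step >= finetune_steps:
            break
        for batch in (src_batch, tgt_batch):
            optimizer.zero_grad()
            mc_model.loss(batch).backward()
            optimizer.step()
```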
Implementation Details
- Data Regularization - While fine-tuning the MC model, alternate between mini-batches from the source and the target domain.
- At inference time, the fine-tuned MC model produces the distributions P(i = start) and P(i = end) (the likelihood of word i being the starting or ending word of the answer) for every word, and dynamic programming (DP) is used to find the optimal answer span (see the sketch at the end of this section).
- Checkpoint Averaging - Average the answer likelihoods from several checkpointed models before running DP.
- Using the synthetically generated dataset gives a roughly 2% improvement in F-score when transferring from SQuAD to NewsQA. Averaging checkpointed models further improves performance to an overall 46.6% F-score, which narrows the gap to a model trained on NewsQA itself (~52.3% F-score).
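To make the inference procedure above concrete, here is a small sketch that averages the start/end likelihoods over several checkpointed models and then searches for the best answer span. The `predict_start`/`predict_end` interface and the `max_span_len` cap are assumptions, and the search is written as a simple scan over end positions rather than the paper's exact DP formulation.

```python
import numpy as np

def best_answer_span(checkpoint_models, paragraph, max_span_len=15):
    """Average per-word start/end likelihoods over checkpoints, then pick the best span."""
    # Checkpoint averaging: average the answer likelihoods before searching for the span.
    p_start = np.mean([m.predict_start(paragraph) for m in checkpoint_models], axis=0)
    p_end = np.mean([m.predict_end(paragraph) for m in checkpoint_models], axis=0)

    # Find the span (i, j), i <= j, maximising P(i = start) * P(j = end),
    # restricted to spans of at most max_span_len words.
    best_score, best_span = -1.0, (0, 0)
    for j in range(len(p_end)):
        lo = max(0, j - max_span_len + 1)
        i = lo + int(np.argmax(p_start[lo:j + 1]))
        score = float(p_start[i] * p_end[j])
        if score > best_score:
            best_score, best_span = score, (i, j)
    return best_span, best_score
```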