Continual learning with hypernetworks
08 Feb 2021

Introduction
-
The paper proposes the use of task-conditioned HyperNetworks for lifelong learning / continual learning setups.
-
The idea is that the HyperNetwork only needs to remember the task-conditioned weights, not the input-output mapping for all the data points.
Terminology
-
$f$ denotes the target network for the given $t^{th}$ task.
-
$h$ denotes the HyperNetwork that generates the weights for $f$.
-
$\Theta_{h}$ denotes the parameters of $h$.
-
$e^{t}$ denotes the input task-embedding for the $t^{th}$ task.
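A minimal PyTorch-style sketch of how these pieces fit together ($f$, $h$, $\Theta_{h}$, and one trainable embedding $e^{t}$ per task). The architecture, sizes, and names below are illustrative assumptions, not the paper's exact configuration:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class HyperNetwork(nn.Module):
    """h: maps a task embedding e^t to the weights of the target network f."""

    def __init__(self, embedding_dim, target_param_shapes, hidden_dim=64):
        super().__init__()
        self.target_param_shapes = target_param_shapes
        n_target_params = sum(math.prod(shape) for shape in target_param_shapes)
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_target_params),
        )

    def forward(self, task_embedding):
        flat = self.net(task_embedding)
        # Split the flat output into the individual parameter tensors of f.
        weights, offset = [], 0
        for shape in self.target_param_shapes:
            numel = math.prod(shape)
            weights.append(flat[offset:offset + numel].view(shape))
            offset += numel
        return weights


def target_forward(x, weights):
    """f: a small MLP whose parameters are produced by the hypernetwork."""
    w1, b1, w2, b2 = weights
    return F.linear(F.relu(F.linear(x, w1, b1)), w2, b2)


# Only Theta_h and the per-task embeddings e^t are trainable; f has no parameters of its own.
target_shapes = [(32, 784), (32,), (10, 32), (10,)]
h = HyperNetwork(embedding_dim=8, target_param_shapes=target_shapes)
task_embeddings = nn.ParameterList([nn.Parameter(torch.randn(8)) for _ in range(5)])
```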
Approach
-
When training on the $t^{th}$ task, the HyperNetwork generates the weights for the network $f$.
-
The current task loss is computed using the generated weights, and the candidate weight update ($\Delta \Theta_{h}$) is computed for $h$.
-
The actual parameter change is computed by minimizing the following loss (a minimal sketch of this computation follows the list of terms below):
$L_{total} = L_{task}(\Theta_{h}, e^{T}, X^{T}, Y^{T}) + \frac{\beta_{output}}{T-1} \sum_{t=1}^{T-1} \lVert h(e^{t}, \Theta_{h}^{*}) - h(e^{t}, \Theta_{h} + \Delta \Theta_{h}) \rVert^{2}$
-
$L_{task}$ is the loss for the current task.
-
$(X^{T}, Y^{T})$ denotes the training datapoints for the $T^{th}$ task.
-
$\beta_{output}$ is a hyperparameter to control the regularizer’s strength.
-
$\Theta_{h}^*$ denotes the hypernetwork parameters obtained after training on the first $T-1$ tasks.
-
$\Theta_{h} + \Delta \Theta_{h}$ denotes the one-step update on the current $h$ model.
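A sketch of how this regularized loss could be computed, reusing the illustrative `HyperNetwork` and `target_forward` pieces above. For brevity, the candidate update $\Delta \Theta_{h}$ inside the regularizer is dropped, and `stored_outputs[t]` is assumed to cache the flattened hypernetwork output for task $t$, recorded right after that task finished training:

```python
def total_loss(h, task_embeddings, current_task, x, y, stored_outputs, beta_output):
    """L_task for the current task plus the output regularizer over past tasks."""
    # Current-task loss, computed with weights generated for the current task.
    weights = h(task_embeddings[current_task])
    task_loss = F.cross_entropy(target_forward(x, weights), y)

    # Output regularizer: keep h's outputs for old task embeddings close to the
    # outputs it produced when those tasks were originally learned.
    reg = torch.zeros(())
    for t in range(current_task):
        # Past task embeddings are kept fixed; only Theta_h (and e^T) are trained here.
        current = torch.cat([w.flatten() for w in h(task_embeddings[t].detach())])
        reg = reg + ((stored_outputs[t] - current) ** 2).sum()
    if current_task > 0:
        reg = beta_output * reg / current_task

    return task_loss + reg
```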
-
In practice, the weights for $f$ are generated in smaller chunks: the HyperNetwork is fed the task embedding $e^{t}$ together with a learned chunk embedding and produces one chunk of the weights at a time (see the sketch below).
-
This enables the HyperNetwork to produce weights iteratively, instead of all at once, thus helping it scale to larger target models.
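A sketch of this chunked variant under the same illustrative assumptions: the hypernetwork is conditioned on the task embedding plus a learned per-chunk embedding and emits one fixed-size chunk of the target weights per call:

```python
class ChunkedHyperNetwork(nn.Module):
    """Generates the target weights chunk by chunk instead of all at once.

    The same small network is reused for every chunk; it is conditioned on the
    task embedding e^t plus a learned per-chunk embedding, so its output layer
    only needs `chunk_size` units no matter how large the target network f is.
    """

    def __init__(self, embedding_dim, chunk_embedding_dim, n_chunks, chunk_size, hidden_dim=64):
        super().__init__()
        self.chunk_embeddings = nn.Parameter(torch.randn(n_chunks, chunk_embedding_dim))
        self.net = nn.Sequential(
            nn.Linear(embedding_dim + chunk_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, chunk_size),
        )

    def forward(self, task_embedding):
        chunks = [
            self.net(torch.cat([task_embedding, chunk_embedding]))
            for chunk_embedding in self.chunk_embeddings
        ]
        # Flat weight vector for f; split it into per-layer tensors as in HyperNetwork above.
        return torch.cat(chunks)
```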
-
The paper also considers the problem of inferring the task embedding from a given input pattern.
-
Specifically, the paper uses task-dependent predictive uncertainty: the task embedding whose generated network yields the least predictive uncertainty on the input is chosen for the unknown task. This approach is referred to as HNET+ENT.
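A sketch of this entropy-based selection, again reusing the illustrative pieces above: each stored task embedding is tried, and the one whose generated network is most confident (lowest mean predictive entropy) on the input batch is selected:

```python
def infer_task(h, task_embeddings, x):
    """Return the index of the task embedding with the lowest predictive entropy on x."""
    entropies = []
    with torch.no_grad():
        for e in task_embeddings:
            probs = F.softmax(target_forward(x, h(e)), dim=-1)
            # Mean entropy of the predictive distribution over the batch.
            entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean()
            entropies.append(entropy)
    return int(torch.stack(entropies).argmin())
```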
-
The paper also considers using HyperNetworks to learn the weights for a task-specific generative model. This generative model is then used to generate pseudo-samples for rehearsal-based approaches. The paper considers two cases:
-
HNET+R, where the replay model (i.e., the generative model) is parameterized using a HyperNetwork.
-
HNET+TIR, where an auxiliary task inference classifier is used to predict the task identity.
-
Experiments
-
Three setups are considered:
-
CL1 - Task identity is given to the model.
-
CL2 - Task identity is not given, but task-specific heads are used.
-
CL3 - Task identity needs to be explicitly inferred.
-
On the permuted MNIST task, the proposed approach outperforms baselines like Synaptic Intelligence and Online EWC, and the performance gap is more significant for larger task sequences.
-
Forward knowledge transfer is observed with the CIFAR datasets.
-
One potential limitation (more a limitation of HyperNetworks in general than of this approach) is that HyperNetworks may be harder to scale to larger models like ResNet50 or transformers, which limits their usefulness for lifelong learning use cases.