Continual learning with hypernetworks
08 Feb 2021

Introduction

The paper proposes the use of task-conditioned HyperNetworks for lifelong learning / continual learning setups.

The idea is that the HyperNetwork only needs to remember one task-conditioned weight configuration per task, rather than the input-output mapping for all the data points.
Terminology

$f$ denotes the target network for the $t^{th}$ task.

$h$ denotes the HyperNetwork that generates the weights for $f$.

$\Theta_{h}$ denotes the parameters of $h$.

$e^{t}$ denotes the input task embedding for the $t^{th}$ task.
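A minimal PyTorch sketch tying these symbols together. All sizes, the single-linear-layer target network, and names like `target_forward` are illustrative choices, not the paper's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HyperNetwork(nn.Module):
    """h: maps a task embedding e^t to a flat weight vector for f."""
    def __init__(self, emb_dim: int, target_num_weights: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim, 128),
            nn.ReLU(),
            nn.Linear(128, target_num_weights),
        )

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        return self.net(e)  # the generated weights: Theta_f = h(e^t; Theta_h)

def target_forward(x, flat_weights, in_dim=784, out_dim=10):
    """f: a single linear layer here, purely to keep the sketch small."""
    w = flat_weights[: in_dim * out_dim].view(out_dim, in_dim)
    b = flat_weights[in_dim * out_dim:]
    return F.linear(x, w, b)

# One trainable embedding per task; Theta_h and the e^t are the only
# parameters that persist across tasks.
num_tasks, emb_dim = 5, 8
task_embeddings = nn.Parameter(torch.randn(num_tasks, emb_dim))
h = HyperNetwork(emb_dim, target_num_weights=784 * 10 + 10)
```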
Approach

When training on the $t^{th}$ task, the HyperNetwork generates the weights for the network $f$.

The current task loss is computed using the generated weights, and a candidate weight update ($\Delta \Theta_{h}$), based on the task loss alone, is computed for $h$.

The actual parameter change is obtained by minimizing the following regularized loss:
$L_{total} = L_{task}(\Theta_{h}, e^{T}, X^{T}, Y^{T}) + \frac{\beta_{output}}{T-1} \sum_{t=1}^{T-1} \lVert f_{h}(e^{t}, \Theta_{h}^{*}) - f_{h}(e^{t}, \Theta_{h} + \Delta \Theta_{h}) \rVert^{2}$

$L_{task}$ is the loss for the current task.

$(X^{T}, Y^{T})$ denotes the training data points for the $T^{th}$ task.

$\beta_{output}$ is a hyperparameter to control the regularizer's strength.

$\Theta_{h}^*$ denotes the optimal parameters after training on the first $T-1$ tasks.

$\Theta_{h} + \Delta \Theta_{h}$ denotes the one-step update on the current $h$ model.
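A sketch of one training step under this loss, reusing the `HyperNetwork`, `target_forward`, and `task_embeddings` definitions from the earlier snippet. For simplicity it regularizes the current $\Theta_{h}$ directly, folding everything into a single optimizer step and omitting the separate $\Delta \Theta_{h}$ lookahead from the equation above:

```python
import copy
import torch.nn.functional as F

def train_task(h, task_embeddings, loader, task_id, beta_output, optimizer):
    # Snapshot Theta_h^* before training the new task; the regularizer only
    # compares h's *outputs*, so no per-task data needs to be stored.
    h_star = copy.deepcopy(h)
    for p in h_star.parameters():
        p.requires_grad_(False)

    # `optimizer` is assumed to be built over h.parameters() plus task_embeddings.
    for x, y in loader:
        # Task loss under the weights generated for the current task.
        logits = target_forward(x, h(task_embeddings[task_id]))
        loss = F.cross_entropy(logits, y)

        # Output regularizer over the T-1 previous task embeddings.
        if task_id > 0:
            reg = 0.0
            for t in range(task_id):
                e_t = task_embeddings[t].detach()  # old embeddings stay fixed
                reg = reg + ((h(e_t) - h_star(e_t)) ** 2).sum()
            loss = loss + beta_output * reg / task_id

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```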

In practice, the HyperNetwork does not emit all of the target network's weights in one shot. Instead, the task embedding $e^{t}$ is combined with a set of smaller, learned chunk embeddings, and each combined vector is fed through the HyperNetwork.

This enables the HyperNetwork to produce the weights iteratively, one chunk at a time, thus helping it scale to larger target models.
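A minimal sketch of this chunking, under the assumption of learned chunk embeddings concatenated with the task embedding (all dimensions and layer sizes illustrative):

```python
import torch
import torch.nn as nn

class ChunkedHyperNetwork(nn.Module):
    """Emits the target weights one chunk at a time, reusing a small net."""
    def __init__(self, emb_dim, chunk_emb_dim, chunk_size, num_chunks):
        super().__init__()
        # One learned embedding per chunk, shared across all tasks.
        self.chunk_embs = nn.Parameter(torch.randn(num_chunks, chunk_emb_dim))
        self.net = nn.Sequential(
            nn.Linear(emb_dim + chunk_emb_dim, 128),
            nn.ReLU(),
            nn.Linear(128, chunk_size),
        )

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        # Same network, different chunk embedding per call -> iterative output.
        chunks = [self.net(torch.cat([e, c])) for c in self.chunk_embs]
        return torch.cat(chunks)  # flat vector of num_chunks * chunk_size weights
```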

The paper also considers the problem of inferring the task embedding from a given input pattern.

Specifically, the paper uses task-dependent uncertainty: the task embedding that yields the lowest predictive uncertainty is selected for the given input. This approach is referred to as HNET+ENT.
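A sketch of this entropy-based selection, again reusing `h`, `target_forward`, and `task_embeddings` from the first snippet (the use of mean batch entropy here is an illustrative choice):

```python
import torch
import torch.nn.functional as F

def infer_task(h, task_embeddings, x, num_tasks):
    """Pick the task embedding whose predictions are least uncertain."""
    entropies = []
    with torch.no_grad():
        for t in range(num_tasks):
            logits = target_forward(x, h(task_embeddings[t]))
            log_p = F.log_softmax(logits, dim=-1)
            # Mean predictive entropy over the batch for candidate task t.
            entropy = -(log_p.exp() * log_p).sum(dim=-1).mean()
            entropies.append(entropy)
    return int(torch.stack(entropies).argmin())
```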

The paper also considers using HyperNetworks to learn the weights of a task-specific generative model. This generative model is then used to generate pseudo-samples for rehearsal-based approaches (a rough sketch follows the two cases below). The paper considers two cases:

HNET+R, where the replay model (i.e., the generative model) is itself parameterized using a HyperNetwork.

HNET+TIR, where an auxiliary task inference classifier is used to predict the task identity.
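The replay mechanism itself is standard generative rehearsal. A rough, hypothetical sketch of how replayed pseudo-samples enter the loss; `generator_old.sample` and the frozen `solver_old` snapshot are assumed names, not the paper's API, and in HNET+R the generator's weights would themselves come from a HyperNetwork:

```python
import torch.nn.functional as F

def rehearsal_loss(solver, generator_old, solver_old, x, y, task_id):
    """Mix real current-task data with pseudo-samples from past tasks."""
    loss = F.cross_entropy(solver(x), y)
    if task_id > 0:
        with torch.no_grad():
            x_pseudo = generator_old.sample(x.size(0))   # assumed API
            y_pseudo = solver_old(x_pseudo).argmax(-1)   # pseudo-labels
        loss = loss + F.cross_entropy(solver(x_pseudo), y_pseudo)
    return loss
```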

Experiments

Three setups are considered:

CL1 - Task identity is given to the model.

CL2 - Task identity is not given, but task-specific heads are used.

CL3 - Task identity needs to be explicitly inferred.


On the Permuted MNIST benchmark, the proposed approach outperforms baselines like Synaptic Intelligence and Online EWC, and the performance gap widens for longer task sequences.

Forward knowledge transfer is observed with the CIFAR datasets.

One potential limitation (which is more a limitation of HyperNetworks in general) is that HyperNetworks may be harder to scale to larger models like ResNet-50 or Transformers, limiting their usefulness for lifelong learning use cases.