Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks
18 Jul 2019

Introduction

Consider problems where the input to the model is a set. In such problems (referred to as set-input problems), the model should be invariant to permutations of the data points.

In “set pooling” methods (1, 2), each data point in the input set is encoded using a feed-forward network, and the resulting set of encoded representations is pooled using the “sum” operator.

This approach can be shown to be both permutation-invariant and a universal function approximator.

The paper proposes an attention-based network module, called the Set Transformer, which can model the interactions between the elements of an input set while being permutation-invariant.
Transformer

An attention function Attn(Q, K, V) = ω(QK^{T})V is used to map queries Q to outputs using key-value pairs K, V, where ω is an activation function such as the row-wise softmax.

In multi-head attention, the queries, keys, and values are projected into h different vectors and attention is applied to each of these vectors. The output is a linear transformation of the concatenation of all h attention outputs.
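The two operations above can be sketched in NumPy, assuming ω is the softmax with the usual 1/sqrt(d) scaling. Note that the learned per-head projection matrices of true multi-head attention are omitted here; the feature dimension is simply split across heads:

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable row-wise softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attn(Q, K, V):
    # Attn(Q, K, V) = omega(Q K^T) V, with omega = scaled softmax (assumed)
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d)) @ V

def multihead_attn(Q, K, V, h):
    # Split the feature dimension into h heads, attend per head, concatenate.
    # Real multi-head attention also applies learned projections, omitted here.
    heads = [attn(q, k, v) for q, k, v in
             zip(np.split(Q, h, axis=-1),
                 np.split(K, h, axis=-1),
                 np.split(V, h, axis=-1))]
    return np.concatenate(heads, axis=-1)
```

With h = 1 this reduces exactly to the single attention function.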
Set Transformer

Three modules are introduced: MAB, SAB, and ISAB.

The Multihead Attention Block (MAB) is a module very similar to the encoder block of the Transformer, but without the positional encoding and dropout.

The Set Attention Block (SAB) is a module that takes a set as input and performs self-attention between the elements of the set to produce another set of the same size, i.e., SAB(X) = MAB(X, X).
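A minimal NumPy sketch of MAB and SAB, under simplifying assumptions: single-head softmax attention, no learned weights, layer normalization omitted, and a plain ReLU standing in for the row-wise feed-forward (rFF):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mab(X, Y):
    # MAB(X, Y) = H + rFF(H), where H = X + Attn(X, Y, Y).
    # Simplified: single head, no learned weights, no layer norm,
    # and a ReLU standing in for the row-wise feed-forward.
    d = X.shape[-1]
    H = X + softmax(X @ Y.T / np.sqrt(d)) @ Y
    return H + np.maximum(H, 0.0)

def sab(X):
    # SAB(X) = MAB(X, X): self-attention among the elements of the set
    return mab(X, X)
```

Because every operation acts row-wise or through attention weights, permuting the rows of X permutes the rows of SAB(X) identically, i.e., SAB is permutation-equivariant.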

The time complexity of the SAB operation is O(n^{2}), where n is the number of elements in the set. It can be reduced to O(m*n) by using the Induced Set Attention Block (ISAB) with m inducing point vectors (denoted as I).

ISAB_{m}(X) = MAB(X, MAB(I, X)).

ISAB can be seen as performing a low-rank projection of the inputs.
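ISAB can be sketched as the composition of two simplified MABs (again: single head, no learned weights, ReLU standing in for the rFF). Each of the two attentions involves n set elements against m inducing points, so the cost is O(m*n) rather than O(n^{2}):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mab(X, Y):
    # Simplified MAB: single head, no learned weights, ReLU standing in for rFF
    d = X.shape[-1]
    H = X + softmax(X @ Y.T / np.sqrt(d)) @ Y
    return H + np.maximum(H, 0.0)

def isab(X, I):
    # ISAB_m(X) = MAB(X, MAB(I, X)); I holds the m inducing points
    # (learned parameters in practice, fixed vectors in this sketch)
    H = mab(I, X)     # (m, d): inducing points summarize the set, O(m*n)
    return mab(X, H)  # (n, d): the set attends to the m summaries, O(n*m)
```

Like SAB, ISAB is permutation-equivariant: permuting the input rows only permutes the output rows.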

These modules can be used to model the interactions between data points in any given set.
Pooling by Multihead Attention (PMA)

Aggregation is performed by applying multi-head attention with a set of k seed vectors as the queries over the encoded set.

The interactions among the k outputs (from PMA) can be modeled by applying another SAB.

Thus the encoder of the network is a stack of SABs or ISABs, and the decoder is PMA followed by SABs. SAB and ISAB are permutation-equivariant, so any stack of them is as well, and pooling the stacked output with PMA makes the overall network permutation-invariant.
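The full pipeline can be sketched end to end under the same simplifications (single-head softmax attention, no learned weights, ReLU standing in for the rFF; the paper's PMA also applies an rFF to the encoded set before attending, omitted here). The seeds S are learned parameters in practice:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mab(X, Y):
    # Simplified MAB: single head, no learned weights, ReLU standing in for rFF
    d = X.shape[-1]
    H = X + softmax(X @ Y.T / np.sqrt(d)) @ Y
    return H + np.maximum(H, 0.0)

def sab(X):
    return mab(X, X)

def pma(Z, S):
    # PMA_k(Z) = MAB(S, Z): the k seed vectors S attend over the encoded set Z
    return mab(S, Z)

def set_transformer(X, S):
    # Encoder: stacked SABs; decoder: PMA followed by a SAB over the k outputs
    Z = sab(sab(X))
    return sab(pma(Z, S))
```

The output has k rows regardless of the input set size, and shuffling the input set leaves it unchanged.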
Experiments

Datasets include:
 Predicting the maximum value of a set.
 Counting unique (Omniglot) characters from an image.
 Clustering with a mixture of Gaussians (synthetic points and CIFAR-100).
 Set anomaly detection (CelebA).

Generally, increasing m (the number of inducing points) improves performance up to a point, as one would expect.

The paper considers various ablations of the proposed approach (such as disabling attention in the encoder or in the pooling layer) and shows that the attention mechanism is needed in both stages.

The approach has two main benefits over prior work:

Reducing the O(n^{2}) complexity to O(m*n) complexity.

Using the self-attention mechanism both for encoding the inputs and for aggregating the encoded representations.
