# Set Transformer - A Framework for Attention-based Permutation-Invariant Neural Networks

## Introduction

• Consider problems where the input to the model is a set. In such problems (referred to as set-input problems), the model's output should be invariant to permutations of the data points.

• In “set pooling” methods (1, 2), each data point in the input set is encoded using a feed-forward network, and the resulting set of encoded representations is pooled using the “sum” operator.

• This approach can be shown to be both permutation-invariant and a universal function approximator.

• The paper proposes an attention-based network module, called the Set Transformer, which can model the interactions between the elements of an input set while being permutation-invariant.
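
The set-pooling baseline described above can be sketched in a few lines. This is a minimal illustration, not the paper's code: the layer sizes, the single ReLU layer, and the names `W_enc`, `W_out`, and `set_pool` are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 3 input features, 8 hidden units, 2 outputs.
W_enc = rng.normal(size=(3, 8))   # per-element encoder weights
W_out = rng.normal(size=(8, 2))   # weights applied after pooling


def set_pool(X):
    """Encode each row of X independently, sum over the set, then decode."""
    H = np.maximum(X @ W_enc, 0.0)   # element-wise encoder (one ReLU layer)
    pooled = H.sum(axis=0)           # sum pooling does not depend on row order
    return pooled @ W_out


X = rng.normal(size=(5, 3))          # a set of 5 elements
perm = rng.permutation(5)
same = np.allclose(set_pool(X), set_pool(X[perm]))
print(same)
```

Because each element is encoded on its own, the elements never interact before pooling. That gap is what the attention-based modules below are meant to address.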

## Transformer

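• An attention function $\mathrm{Attn}(Q, K, V) = \omega(QK^\top)V$ maps queries $Q$ to outputs using key-value pairs $(K, V)$. Here $\omega$ is an activation function, typically a scaled softmax, $\omega(\cdot) = \mathrm{softmax}(\cdot / \sqrt{d})$.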

• In multi-head attention, the queries, keys, and values are each projected into h different lower-dimensional vectors, and attention is applied to each of the h projections. The output is a linear transformation of the concatenation of the h attention outputs.
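
As a rough sketch of the two bullets above, the following NumPy code implements scaled dot-product attention and a multi-head wrapper. The weight shapes, the random initialisation, and the names `attention` and `multihead` are illustrative assumptions, and the scaling here uses the per-head dimension.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d)) @ V


def multihead(Q, K, V, Wq, Wk, Wv, Wo):
    """Wq, Wk, Wv hold one projection per head; Wo mixes the concatenated heads."""
    heads = [attention(Q @ wq, K @ wk, V @ wv) for wq, wk, wv in zip(Wq, Wk, Wv)]
    return np.concatenate(heads, axis=-1) @ Wo


rng = np.random.default_rng(0)
d, h = 8, 2                      # model width and number of heads (illustrative)
Wq = [rng.normal(size=(d, d // h)) for _ in range(h)]
Wk = [rng.normal(size=(d, d // h)) for _ in range(h)]
Wv = [rng.normal(size=(d, d // h)) for _ in range(h)]
Wo = rng.normal(size=(d, d))

Q = rng.normal(size=(3, d))      # 3 queries
K = V = rng.normal(size=(5, d))  # 5 key-value pairs
out = multihead(Q, K, V, Wq, Wk, Wv, Wo)
print(out.shape)
```

The output has one row per query, regardless of how many key-value pairs are attended over.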

## Set Transformer

• Three modules are introduced: MAB, SAB, and ISAB.

• Multihead Attention Block (MAB) is a module very similar to the encoder block of the Transformer, without the positional encoding and dropout. MAB(X, Y) lets the elements of X attend to the elements of Y.

• Set Attention Block (SAB) is a module that takes a set as input and performs self-attention between the elements of the set to produce another set of the same size, i.e., SAB(X) = MAB(X, X).

• The time complexity of the SAB operation is $O(n^2)$, where n is the number of elements in the set. It can be reduced to $O(mn)$ by using Induced Set Attention Blocks (ISAB) with m inducing point vectors (denoted as I).

• $\mathrm{ISAB}_m(X) = \mathrm{MAB}(X, H)$, where $H = \mathrm{MAB}(I, X)$.

• ISAB can be seen as performing a low-rank projection of inputs.

• These modules can be used to model the interactions between data points in any given set.
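
The three modules can be sketched as follows. This is a deliberately simplified version, not the paper's implementation: the MAB here uses a single head, omits layer normalisation, and uses random rather than learned weights and inducing points. The names `make_mab`, `sab`, `isab`, and `inducing` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4  # feature width (illustrative)


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def make_mab():
    """Return a simplified MAB (single head, no LayerNorm) with its own weights."""
    Wq, Wk, Wv, Wf = (rng.normal(size=(d, d)) * 0.5 for _ in range(4))

    def mab(X, Y):
        Q, K, V = X @ Wq, Y @ Wk, Y @ Wv
        H = Q + softmax(Q @ K.T / np.sqrt(d)) @ V   # X attends to Y, plus residual
        return H + np.maximum(H @ Wf, 0.0)          # row-wise feed-forward, plus residual

    return mab


mab_sab, mab_in, mab_out = make_mab(), make_mab(), make_mab()
inducing = rng.normal(size=(2, d))   # m = 2 inducing points (learned in practice)


def sab(X):
    return mab_sab(X, X)             # n x n attention scores


def isab(X):
    H = mab_in(inducing, X)          # m x n scores: inducing points attend to X
    return mab_out(X, H)             # n x m scores: X attends to the summary H


X = rng.normal(size=(6, d))
perm = rng.permutation(6)
# Reordering the input rows reorders the output rows in the same way.
sab_equivariant = np.allclose(sab(X[perm]), sab(X)[perm])
isab_equivariant = np.allclose(isab(X[perm]), isab(X)[perm])
print(sab_equivariant, isab_equivariant, isab(X).shape)
```

Both blocks return a set of the same size as the input, so they can be stacked. The check at the end illustrates that each block commutes with a permutation of the input rows.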

## Pooling by Multihead Attention (PMA)

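• Aggregation is performed by multi-head attention in which a set of k learnable seed vectors S act as queries that attend to the encoded set Z: $\mathrm{PMA}_k(Z) = \mathrm{MAB}(S, \mathrm{rFF}(Z))$, where rFF is a row-wise feed-forward layer. The output is a set of k vectors.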

• The interaction between the k outputs (from PMA) can be modeled by applying another SAB.

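• Thus the encoder is a stack of SABs or ISABs, and the decoder is PMA, optionally followed by SABs. SAB and ISAB are permutation-equivariant (permuting the input permutes the output in the same way), and PMA is permutation-invariant, so the network as a whole is permutation-invariant.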
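
A minimal end-to-end sketch of this pipeline is shown below, under the same simplifications as before (single-head MAB, no layer normalisation, random weights). It also omits the row-wise feed-forward layer inside PMA and uses a single SAB as the encoder. The names `make_mab`, `set_transformer`, and `seeds` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 4, 1  # feature width and number of seed vectors (illustrative)


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def make_mab():
    """Return a simplified MAB (single head, no LayerNorm) with its own weights."""
    Wq, Wk, Wv, Wf = (rng.normal(size=(d, d)) * 0.5 for _ in range(4))

    def mab(X, Y):
        Q, K, V = X @ Wq, Y @ Wk, Y @ Wv
        H = Q + softmax(Q @ K.T / np.sqrt(d)) @ V   # X attends to Y, plus residual
        return H + np.maximum(H @ Wf, 0.0)          # row-wise feed-forward, plus residual

    return mab


encoder_mab, pool_mab = make_mab(), make_mab()
seeds = rng.normal(size=(k, d))      # seed vectors (learned in practice)


def set_transformer(X):
    Z = encoder_mab(X, X)            # SAB encoder: n x d, permutation-equivariant
    return pool_mab(seeds, Z)        # PMA: seeds attend to Z, giving k x d


X = rng.normal(size=(6, d))
perm = rng.permutation(6)
invariant = np.allclose(set_transformer(X), set_transformer(X[perm]))
print(invariant, set_transformer(X).shape)
```

The output size depends only on k, not on the number of input elements, and reordering the input leaves the output unchanged up to floating-point error.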

## Experiments

• Tasks and datasets include:

• Predicting the maximum value from a set.
• Counting unique characters in a set of Omniglot images.
• Clustering with a mixture of Gaussians (synthetic points and CIFAR-100).
• Set anomaly detection (CelebA).

• Generally, increasing m (the number of inducing points) improves performance up to a point. This is somewhat expected.

• The paper considers various ablations of the proposed approach (such as disabling attention in the encoder or in the pooling layer) and shows that the attention mechanism is needed in both stages.

• The work has two main benefits over prior work:

• Reducing the $O(n^2)$ complexity of self-attention to $O(mn)$.

• Using the self-attention mechanism both for encoding the inputs and for aggregating the encoded representations.
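
To make the complexity reduction concrete, the sketch below counts the pairwise attention scores computed by one SAB versus one ISAB. The set size and number of inducing points are hypothetical values chosen for illustration, and the count ignores projections and feed-forward layers.

```python
def sab_scores(n):
    """Scores in one SAB: every element attends to every element."""
    return n * n


def isab_scores(n, m):
    """One ISAB runs two MABs: m x n scores, then n x m scores."""
    return 2 * m * n


n, m = 1000, 16  # hypothetical set size and number of inducing points
print(sab_scores(n), isab_scores(n, m))
```

For fixed m, the ISAB count grows linearly in n, while the SAB count grows quadratically.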