Contrastive Learning of Structured World Models
28 Nov 2019
Introduction

The paper introduces Contrastively-trained Structured World Models (C-SWMs).

These models use a contrastive approach for learning representations in environments with compositional structure.
Approach

The training data is in the form of an experience buffer \(B = \{(s_t, a_t, s_{t+1})\}_{t=1}^T\) of state transition tuples.

The goal is to learn:

an encoder \(E\) that maps observed states \(s_t\) (pixel observations) to latent states \(z_t = E(s_t)\).

a transition model \(T\) that predicts the dynamics in the latent state space.


The model defines the energy of a tuple \((s_t, a_t, s_{t+1})\) as \(H = d(z_t + T(z_t, a_t), z_{t+1})\), where \(d\) is the squared Euclidean distance.

The model has an inductive bias for modeling the effect of an action as a translation in the abstract state space.
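A minimal NumPy sketch of the energy, taking \(d\) as the squared Euclidean distance and treating the transition model as any callable that returns a latent translation. The names `energy` and `toy_transition` are hypothetical; the toy transition simply returns the action vector, which makes the "action as translation" view concrete.

```python
import numpy as np

def squared_distance(x, y):
    """d(x, y): squared Euclidean distance between two latent vectors."""
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return float(np.sum(diff ** 2))

def energy(z_t, a_t, z_next, transition):
    """H = d(z_t + T(z_t, a_t), z_{t+1}).

    `transition` returns a translation that is added to z_t, so the effect of
    an action is modeled as a shift in latent space.
    """
    z_pred = np.asarray(z_t, dtype=float) + transition(z_t, a_t)
    return squared_distance(z_pred, z_next)

# Hypothetical toy transition model: the action vector itself is the translation.
def toy_transition(z, a):
    return np.asarray(a, dtype=float)

z_t = np.array([0.0, 0.0])
a_t = np.array([1.0, 0.0])
print(energy(z_t, a_t, np.array([1.0, 0.0]), toy_transition))  # 0.0
print(energy(z_t, a_t, np.array([0.0, 1.0]), toy_transition))  # 2.0
```

A low energy means the translated latent state lands close to the encoding of the observed next state.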

An extra hinge-loss term is added to the energy \(H\): \(\max(0, \gamma - d(\tilde{z}_t, z_{t+1}))\), where \(\gamma\) is a margin and \(\tilde{z}_t = E(\tilde{s}_t)\) is a corrupted (negative) latent state corresponding to a randomly sampled state \(\tilde{s}_t\).
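The full objective can be sketched as a batch average of \(H + \max(0, \gamma - d(\tilde{z}_t, z_{t+1}))\). This is a hedged illustration, not the authors' code: `contrastive_loss` and `sample_negatives` are hypothetical names, the default `gamma=1.0` is an assumption, and negatives are drawn uniformly from pre-computed encodings of buffer states.

```python
import numpy as np

def sq_dist(x, y):
    """Row-wise squared Euclidean distance for batches of latent vectors."""
    return np.sum((x - y) ** 2, axis=-1)

def sample_negatives(z_buffer, batch_size, rng):
    """Pick encodings E(s~) of randomly sampled buffer states as negatives."""
    idx = rng.integers(0, len(z_buffer), size=batch_size)
    return z_buffer[idx]

def contrastive_loss(z_t, delta, z_next, z_neg, gamma=1.0):
    """Batch mean of H + max(0, gamma - d(z_neg, z_next))."""
    positive = sq_dist(z_t + delta, z_next)                      # energy H
    negative = np.maximum(0.0, gamma - sq_dist(z_neg, z_next))  # hinge term
    return float(np.mean(positive + negative))

z_t = np.array([[0.0, 0.0]])
delta = np.array([[1.0, 0.0]])    # predicted translation T(z_t, a_t)
z_next = np.array([[1.0, 0.0]])   # perfect prediction, so H = 0
print(contrastive_loss(z_t, delta, z_next, np.array([[5.0, 0.0]])))  # 0.0
print(contrastive_loss(z_t, delta, z_next, np.array([[1.0, 0.5]])))  # 0.75

rng = np.random.default_rng(0)
z_buffer = rng.normal(size=(100, 2))  # stand-in for encodings of buffer states
print(sample_negatives(z_buffer, batch_size=8, rng=rng).shape)  # (8, 2)
```

A negative that is already farther than the margin from \(z_{t+1}\) contributes nothing; only negatives closer than \(\gamma\) are pushed away.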
Object-Oriented State Factorization

The goal is to learn object-oriented representations where each state embedding is structured as a set of objects.

Assuming \(K\) object slots, the latent space and the action space are factored into \(K\) independent latent spaces (\(Z_1 \times \ldots \times Z_K\)) and action spaces (\(A_1 \times \ldots \times A_K\)), respectively.

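A CNN-based object extractor produces \(K\) feature maps (one per object slot), and an MLP-based object encoder, shared across slots, maps each feature map to an object embedding.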
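A shape-level sketch of the encoder stage, assuming the extractor has already produced one map per slot. The CNN is replaced by sigmoid-squashed random maps, and `object_encoder` with its random weights is hypothetical; the point is that one MLP is applied to every flattened slot map.

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def object_encoder(object_maps, w1, b1, w2, b2):
    """Shared MLP applied to each flattened object map.

    object_maps: (K, H, W), one map per object slot from the extractor.
    Returns (K, D): one latent vector z^k per slot.
    """
    k = object_maps.shape[0]
    flat = object_maps.reshape(k, -1)   # (K, H*W)
    hidden = relu(flat @ w1 + b1)       # same weights for every slot
    return hidden @ w2 + b2

rng = np.random.default_rng(0)
K, H, W, HIDDEN, D = 3, 5, 5, 16, 2
# Stand-in for the CNN extractor output: sigmoid-squashed random maps.
maps = 1.0 / (1.0 + np.exp(-rng.normal(size=(K, H, W))))
w1, b1 = rng.normal(size=(H * W, HIDDEN)), np.zeros(HIDDEN)
w2, b2 = rng.normal(size=(HIDDEN, D)), np.zeros(D)
z = object_encoder(maps, w1, b1, w2, b2)
print(z.shape)  # (3, 2)
```

Because the weights are shared, two slots with identical maps receive identical embeddings.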

The actions are represented as one-hot vectors.
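One way to turn a flat discrete action into per-object one-hot vectors, under an assumed layout where action `k * num_moves + m` means "apply move `m` to object `k`". Both the layout and `factor_action` are hypothetical illustrations; the object being acted on receives the one-hot vector and every other object receives zeros.

```python
import numpy as np

def factor_action(action, num_objects, num_moves):
    """Split a flat discrete action into per-object one-hot vectors.

    Assumed layout: action = k * num_moves + m targets object k with move m.
    Object k receives one_hot(m); all other objects receive a zero vector.
    """
    k, m = divmod(action, num_moves)
    per_object = np.zeros((num_objects, num_moves))
    per_object[k, m] = 1.0
    return per_object

a = factor_action(6, num_objects=3, num_moves=4)
print(a)  # row 1 has 1.0 in column 2; every other entry is 0.0
```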

A fully connected graph is induced over the \(K\) object representations, and the transition function is modeled as a Graph Neural Network (GNN) over this graph.

The transition function produces the change in the latent state representation of each object.
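A single message-passing step can be sketched as follows: an edge MLP runs on every ordered pair of distinct objects, messages are summed per receiver, and a node MLP maps (object state, object action, aggregated messages) to a per-object delta. The random-weight MLPs, layer sizes, and names (`make_mlp`, `gnn_transition`) are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_mlp(n_in, n_hidden, n_out):
    """Hypothetical 2-layer ReLU MLP with fixed random weights."""
    w1 = rng.normal(scale=0.1, size=(n_in, n_hidden))
    w2 = rng.normal(scale=0.1, size=(n_hidden, n_out))
    return lambda x: np.maximum(0.0, x @ w1) @ w2

def gnn_transition(z, a, edge_fn, node_fn):
    """One message-passing step on a fully connected graph over K >= 2 objects.

    z: (K, D) object embeddings; a: (K, A) per-object action vectors.
    Returns (K, D) per-object deltas; the prediction is z + delta.
    """
    K = z.shape[0]
    aggregated = []
    for i in range(K):
        # Edge model on every pair (i, j) with j != i; messages are summed.
        msgs = [edge_fn(np.concatenate([z[i], z[j]])) for j in range(K) if j != i]
        aggregated.append(np.sum(msgs, axis=0))
    # Node model sees the object's state, its action and the aggregated messages.
    node_in = np.concatenate([z, a, np.stack(aggregated)], axis=1)
    return node_fn(node_in)

K, D, A, MSG = 3, 2, 4, 8
edge_fn = make_mlp(2 * D, 16, MSG)
node_fn = make_mlp(D + A + MSG, 16, D)

z = rng.normal(size=(K, D))
a = np.zeros((K, A))
a[1, 2] = 1.0                        # only object 1 is acted upon
delta = gnn_transition(z, a, edge_fn, node_fn)
z_pred = z + delta
print(delta.shape)  # (3, 2)
```

Sum aggregation keeps the model permutation-equivariant: reordering the object slots reorders the predicted deltas in the same way.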

The factorization is taken into account in the loss function by summing the distance terms over the individual objects.
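A sketch of the factored energy, assuming latent states are stored as \((K, D)\) arrays with one row per object; `per_object_energy` and `factored_energy` are hypothetical names.

```python
import numpy as np

def per_object_energy(z_pred, z_next):
    """d(z_pred^k, z_next^k) for each object slot k; inputs have shape (K, D)."""
    return np.sum((z_pred - z_next) ** 2, axis=-1)

def factored_energy(z_pred, z_next):
    """Total energy: the sum over objects of the per-object squared distances."""
    return float(np.sum(per_object_energy(z_pred, z_next)))

z_pred = np.array([[1.0, 0.0], [0.0, 1.0]])
z_next = np.array([[1.0, 0.0], [0.0, 1.5]])
print(per_object_energy(z_pred, z_next))  # object 0 contributes 0.0, object 1 contributes 0.25
print(factored_energy(z_pred, z_next))    # 0.25
```

With the squared Euclidean distance, this sum equals the distance between the flattened (concatenated) object vectors, and the hinge term can reuse the same function for the negative sample.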
Environments

Grid-world environments: 2D shapes, 3D blocks

Atari games: Pong and Space Invaders

3-body physics simulation
Setup

A random policy is used to collect the training data.
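A toy sketch of filling the experience buffer \(B\) with a uniform random policy. `ToyGridEnv` is a hypothetical stand-in: it returns small tuples of object positions rather than the pixel observations used in the paper, and each action moves one object left or right.

```python
import random

class ToyGridEnv:
    """Hypothetical environment: K objects on a 1-D track of the given width."""

    def __init__(self, num_objects=3, width=5, seed=0):
        self.rng = random.Random(seed)
        self.k, self.width = num_objects, width
        self.num_actions = 2 * num_objects  # move object k left or right

    def reset(self):
        self.pos = [self.rng.randrange(self.width) for _ in range(self.k)]
        return tuple(self.pos)

    def step(self, action):
        obj, direction = divmod(action, 2)
        move = -1 if direction == 0 else 1
        self.pos[obj] = min(max(self.pos[obj] + move, 0), self.width - 1)
        return tuple(self.pos)

def collect_random(env, num_episodes, episode_len, seed=0):
    """Fill a buffer B = [(s_t, a_t, s_{t+1}), ...] using a uniform random policy."""
    rng = random.Random(seed)
    buffer = []
    for _ in range(num_episodes):
        s = env.reset()
        for _ in range(episode_len):
            a = rng.randrange(env.num_actions)  # random policy
            s_next = env.step(a)
            buffer.append((s, a, s_next))
            s = s_next
    return buffer

B = collect_random(ToyGridEnv(), num_episodes=2, episode_len=10)
print(len(B))  # 20
```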

Evaluation is performed in the latent space (no reconstruction in pixel space) using ranking metrics: Hits at Rank 1 (H@1) and Mean Reciprocal Rank (MRR). The candidate observations that each prediction is ranked against are randomly sampled from the buffer.
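A sketch of the two ranking metrics, assuming each predicted latent state is ranked by squared distance against the encoded true next states of an evaluation batch, with the other batch entries acting as the randomly sampled candidates. `ranking_metrics` is a hypothetical name, and ties are resolved in favour of the true target.

```python
import numpy as np

def ranking_metrics(z_pred, z_target):
    """Hits@1 and MRR for predicted vs. encoded true next-state embeddings.

    Row i of z_target is the correct match for z_pred[i]; all other rows
    serve as candidates it is ranked against.
    """
    # dist[i, j] = squared distance between prediction i and candidate j
    dist = np.sum((z_pred[:, None, :] - z_target[None, :, :]) ** 2, axis=-1)
    true_dist = np.diag(dist)
    # Rank = 1 + number of candidates strictly closer than the true target.
    ranks = 1 + np.sum(dist < true_dist[:, None], axis=1)
    hits_at_1 = float(np.mean(ranks == 1))
    mrr = float(np.mean(1.0 / ranks))
    return hits_at_1, mrr

z_target = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
z_pred = np.array([[0.1, 0.0], [1.9, 0.0], [2.1, 0.0]])
print(ranking_metrics(z_pred, z_target))  # H@1 = 2/3, MRR = (1 + 1/2 + 1) / 3
```

The second prediction lands closer to the third target than to its own, so it gets rank 2 and lowers both metrics.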

Baselines: autoencoder-based World Models and the Physics as Inverse Graphics model.
Results

In the grid-world environments, C-SWM models the latent dynamics almost perfectly.

Removing either the state factorization or the GNN transition model hurts the performance.

C-SWM also performs well on Atari, but the results have high variance.

The optimal value of \(K\) has to be found by hyperparameter tuning.

For the 3-body physics task, both the baselines and the proposed model work quite well.

Interestingly, the paper has a section on limitations:

The object extractor module cannot disambiguate between multiple instances of the same object in a scene.

The current formulation of C-SWM can only be used with deterministic environments.
