# Contrastive Learning of Structured World Models

## Introduction

• The paper introduces Contrastively-trained Structured World Models (C-SWMs).

• These models use a contrastive approach for learning representations in environments with compositional structure.

## Approach

• The training data is in the form of an experience buffer $B = \{(s_t, a_t, s_{t+1})\}_{t=1}^T$ of state transition tuples.

• The goal is to learn:

• an encoder $E$ that maps observed states $s_t$ (pixel observations) to latent states $z_t = E(s_t)$.

• a transition model $T$ that predicts the dynamics in the latent space.

• The model defines the energy of a tuple $(s_t, a_t, s_{t+1})$ as $H = d(z_t + T(z_t, a_t), z_{t+1})$, where $d$ is the squared Euclidean distance.

• This formulation gives the model an inductive bias to model the effect of an action as a translation in the abstract state space.

• An extra hinge-loss term is added: $\max(0, \gamma - d(\tilde{z}_t, z_{t+1}))$, where $\tilde{z}_t = E(\tilde{s}_t)$ is a corrupted latent state corresponding to a randomly sampled state $\tilde{s}_t$. The full loss is therefore $L = H + \max(0, \gamma - d(\tilde{z}_t, z_{t+1}))$.
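The energy and hinge loss described above can be combined into a minimal NumPy sketch. This is not the paper's implementation: the encoder and transition model are hypothetical linear maps with random weights (`W_enc`, `W_trans`), $d$ is taken to be the squared Euclidean distance, negatives are drawn by randomly permuting the states in a batch (a simple stand-in for sampling from the buffer), and `gamma=1.0` is an assumed margin.

```python
import numpy as np

rng = np.random.default_rng(0)

OBS_DIM, LATENT_DIM, NUM_ACTIONS = 16, 4, 3

# Hypothetical stand-ins for the learned networks: a linear encoder E and a
# linear transition model T that returns a latent delta (a translation).
W_enc = rng.normal(size=(OBS_DIM, LATENT_DIM))
W_trans = rng.normal(size=(LATENT_DIM + NUM_ACTIONS, LATENT_DIM))

def encode(s):
    """E: observations (batch, OBS_DIM) -> latent states (batch, LATENT_DIM)."""
    return s @ W_enc

def transition(z, a_onehot):
    """T: predicts the change in latent state from z_t and a one-hot action."""
    return np.concatenate([z, a_onehot], axis=1) @ W_trans

def sq_dist(x, y):
    """d: squared Euclidean distance, computed row by row."""
    return np.sum((x - y) ** 2, axis=1)

def contrastive_loss(s_t, a_t, s_next, gamma=1.0):
    z_t, z_next = encode(s_t), encode(s_next)
    # Positive energy: H = d(z_t + T(z_t, a_t), z_{t+1}).
    h_pos = sq_dist(z_t + transition(z_t, a_t), z_next)
    # Corrupted states: other states from the batch, picked by a random
    # permutation (a permutation may occasionally map a tuple to itself).
    z_neg = encode(s_t[rng.permutation(len(s_t))])
    h_neg = sq_dist(z_neg, z_next)
    # The hinge pushes corrupted states at least gamma away from z_{t+1}.
    return float(np.mean(h_pos + np.maximum(0.0, gamma - h_neg)))

# A toy experience buffer of (s_t, a_t, s_{t+1}) tuples stored as arrays.
T_STEPS = 32
states = rng.normal(size=(T_STEPS + 1, OBS_DIM))
actions = np.eye(NUM_ACTIONS)[rng.integers(NUM_ACTIONS, size=T_STEPS)]
loss = contrastive_loss(states[:-1], actions, states[1:])
print(loss)
```

In training, `W_enc` and `W_trans` would be replaced by neural networks and optimized by gradient descent on this loss; the hinge term is what prevents the trivial solution of mapping every state to the same latent point.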

## Object-Oriented State Factorization

• The goal is to learn object-oriented representations where each state embedding is structured as a set of objects.

• Assuming $K$ object slots, the latent space and the action space are factored into $K$ independent latent spaces ($Z_1 \times \ldots \times Z_K$) and action spaces ($A_1 \times \ldots \times A_K$), respectively.

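• A CNN-based object extractor produces $K$ feature maps (one per object slot), and an MLP-based object encoder, shared across slots, maps each feature map to an object embedding.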

• The actions are represented as one-hot vectors.

• A fully connected graph is induced over the $K$ object representations, and the transition function is modeled as a graph neural network (GNN) over this graph.

• The transition function produces the change in the latent state representation of each object.

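• The factorization is reflected in the loss function by summing the per-object distance terms, i.e. $H = \sum_{k=1}^{K} d(z_t^k + \Delta z_t^k, z_{t+1}^k)$, where $\Delta z_t^k$ is the change predicted by the GNN for object $k$.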
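The object-factored transition and energy described above can be sketched in NumPy as follows. This is an illustrative sketch, not the paper's code: the edge and node MLPs are replaced by hypothetical one-layer maps with random weights (`W_edge`, `W_node`), the sizes `K`, `D`, `A`, `HIDDEN` are arbitrary, and the helper `factored_action` assumes an action targets a single object while the other slots receive a zero action vector.

```python
import numpy as np

rng = np.random.default_rng(0)

K, D, A, HIDDEN = 3, 4, 4, 8  # object slots, latent size per object, actions per object, hidden size

# Hypothetical one-layer stand-ins for the edge and node MLPs (random weights).
W_edge = rng.normal(size=(2 * D, HIDDEN))
W_node = rng.normal(size=(D + A + HIDDEN, D))

def relu(x):
    return np.maximum(x, 0.0)

def factored_action(obj, direction):
    """One-hot action per object slot; only the acted-on object is non-zero."""
    a = np.zeros((K, A))
    a[obj, direction] = 1.0
    return a

def gnn_transition(z, a):
    """z: (K, D) object embeddings, a: (K, A) factored actions -> (K, D) deltas."""
    n = len(z)
    deltas = np.zeros_like(z)
    for i in range(n):
        # Sum edge messages from every other object (fully connected, no self-loops).
        msg = sum((relu(np.concatenate([z[i], z[j]]) @ W_edge)
                   for j in range(n) if j != i), np.zeros(HIDDEN))
        # Node update: predict the change in object i's latent state.
        deltas[i] = np.concatenate([z[i], a[i], msg]) @ W_node
    return deltas

def factored_energy(z_t, a_t, z_next):
    """H = sum_k d(z_t^k + delta_z^k, z_{t+1}^k) with squared Euclidean d."""
    pred = z_t + gnn_transition(z_t, a_t)
    return float(np.sum((pred - z_next) ** 2))

z_t = rng.normal(size=(K, D))
a_t = factored_action(obj=1, direction=2)
# If z_{t+1} equals the model's own prediction, the energy is exactly zero.
z_next = z_t + gnn_transition(z_t, a_t)
print(factored_energy(z_t, a_t, z_next))  # 0.0
```

Because the edge and node weights are shared across all objects, reordering the object slots (together with their actions) simply reorders the predicted deltas, so the transition model does not depend on an arbitrary slot ordering.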

## Environments

• Grid-world environments - 2D shapes, 3D blocks

• Atari games - Pong and Space Invaders

• 3-body physics simulation

## Setup

• A random policy is used to collect the training data.

• Evaluation is performed in the latent space (no reconstruction in pixel space) using ranking metrics (Hits@1 and mean reciprocal rank). The candidate observations to compare against are randomly sampled from the buffer.
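The latent-space ranking evaluation described above can be sketched as follows. This is a simplified sketch under stated assumptions: the candidate set for each prediction is taken to be all target latents in the same batch, distances are squared Euclidean, and `ranking_metrics` is a hypothetical helper name rather than code from the paper.

```python
import numpy as np

def ranking_metrics(pred, target):
    """pred, target: (N, D). Row i of target is the true next state for row i of pred.

    Ranks each true target among all N targets by squared Euclidean distance
    to the prediction and returns (hits_at_1, mean_reciprocal_rank).
    """
    # dists[i, j] = distance from prediction i to candidate target j.
    dists = np.sum((pred[:, None, :] - target[None, :, :]) ** 2, axis=-1)
    true_d = np.diag(dists)
    # Rank = 1 + number of candidates strictly closer than the true target.
    ranks = 1 + np.sum(dists < true_d[:, None], axis=1)
    return float(np.mean(ranks == 1)), float(np.mean(1.0 / ranks))

target = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
pred = np.array([[0.1, 0.0], [0.0, 0.9], [0.0, 1.1]])
# The second prediction is closer to two wrong candidates, so its rank is 3:
# Hits@1 = 2/3 and MRR = (1 + 1/3 + 1) / 3 = 7/9.
print(ranking_metrics(pred, target))
```

Ties are resolved in favor of the true target here (only strictly closer candidates raise the rank); using `<=` instead of `<` for the off-diagonal candidates would give a more pessimistic variant.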

• Baselines - autoencoder-based World Models and the Physics as Inverse Graphics model.

## Results

• In the grid-world environments, C-SWM models the latent dynamics almost perfectly.

• Removing either the state factorization or the GNN transition model hurts the performance.

• C-SWM also performs well on Atari, but the results tend to have high variance.

• The optimal value of $K$ has to be found by hyperparameter tuning.

• For the 3-body physics task, both the baselines and the proposed model work quite well.

• Interestingly, the paper has a section on limitations:

• The object extractor module cannot disambiguate between multiple instances of the same object in a scene.

• The current formulation of C-SWM can only be used with deterministic environments.