Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One
06 Feb 2020

Introduction

The paper proposes a framework for jointly modeling labels and data by interpreting a standard discriminative classifier $p(y \mid x)$ as an energy-based model of the joint distribution $p(x, y)$.

Joint modeling provides benefits like improved calibration (i.e., predictive confidence that aligns with the misclassification rate), robustness to adversarial examples, and out-of-distribution detection.
Motivation

Consider a standard classifier $f_{\theta}(x)$ that produces a $k$-dimensional vector of logits.

$p_{\theta}(y \mid x) = \text{softmax}(f_{\theta}(x))[y] = \frac{\exp(f_{\theta}(x)[y])}{\sum_{y'} \exp(f_{\theta}(x)[y'])}$

Using concepts from energy-based models, we can write $p_{\theta}(x, y) = \frac{\exp(-E_{\theta}(x, y))}{Z_{\theta}}$, where the energy is $E_{\theta}(x, y) = -f_{\theta}(x)[y]$ and $Z_{\theta}$ is the (intractable) normalizing constant.

Marginalizing over $y$ gives $p_{\theta}(x) = \sum_{y}{\frac{\exp(f_{\theta}(x)[y])}{Z_{\theta}}}$,

which corresponds to the energy $E_{\theta}(x) = -\text{LogSumExp}_y(f_{\theta}(x)[y]) = -\log \sum_{y} \exp(f_{\theta}(x)[y])$.

Note that in the standard discriminative setup, shifting the logits $f_{\theta}(x)$ by an arbitrary constant does not affect $p_{\theta}(y \mid x)$, but it does change $p_{\theta}(x)$.

Computing $p_{\theta}(y \mid x)$ as the ratio $p_{\theta}(x, y) / p_{\theta}(x)$ gives back the same softmax parameterization as before, since $Z_{\theta}$ cancels.

This reinterpreted classifier is referred to as a Joint Energy-based Model (JEM).
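A minimal PyTorch sketch of these identities (the toy linear classifier, input size, and batch below are illustrative assumptions, not the paper's setup):

```python
import torch
import torch.nn.functional as F

# A toy stand-in for f_theta: any network producing k logits works.
k = 10
f_theta = torch.nn.Linear(784, k)
x = torch.randn(32, 784)

logits = f_theta(x)                            # f_theta(x), shape (batch, k)

# p(y | x): the usual softmax over the logits -- unchanged by the EBM view.
p_y_given_x = F.softmax(logits, dim=1)

# Joint energy E(x, y) = -f_theta(x)[y], so log p(x, y) = f_theta(x)[y] - log Z.
joint_energy = -logits

# Marginal energy E(x) = -LogSumExp_y f_theta(x)[y], so the unnormalized
# log-density is log p(x) + log Z = logsumexp_y f_theta(x)[y].
log_px_unnorm = torch.logsumexp(logits, dim=1)

# Recovering p(y | x) = p(x, y) / p(x): Z cancels, leaving the same softmax.
assert torch.allclose(
    torch.exp(logits - log_px_unnorm.unsqueeze(1)), p_y_given_x, atol=1e-6
)

# Shifting every logit by a constant leaves p(y | x) untouched,
# but it shifts log p(x), i.e., it changes the induced density over inputs.
shifted = logits + 3.0
assert torch.allclose(F.softmax(shifted, dim=1), p_y_given_x, atol=1e-6)
assert not torch.allclose(torch.logsumexp(shifted, dim=1), log_px_unnorm)
```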
Optimization

The log-likelihood of the data can be factorized as $\log p_{\theta}(x, y) = \log p_{\theta}(x) + \log p_{\theta}(y \mid x)$.

The second factor can be trained using the standard cross-entropy loss. The first factor requires samples from the model distribution, which are drawn with a sampler based on Stochastic Gradient Langevin Dynamics (SGLD).
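A rough sketch of the resulting training step, assuming `model(x)` returns the logits; the SGLD hyperparameters and the uniform-noise initialization are placeholders, and the paper's persistent replay buffer is omitted:

```python
import torch
import torch.nn.functional as F

def log_px_unnorm(model, x):
    # log p(x) up to the constant log Z: LogSumExp over the logits.
    return torch.logsumexp(model(x), dim=1)

def sgld_sample(model, x_init, steps=20, step_size=1.0, noise_std=0.01):
    # Stochastic Gradient Langevin Dynamics in input space:
    # ascend the log-density and add Gaussian noise at each step.
    x = x_init.clone().detach().requires_grad_(True)
    for _ in range(steps):
        grad = torch.autograd.grad(log_px_unnorm(model, x).sum(), x)[0]
        x = x + step_size * grad + noise_std * torch.randn_like(x)
        x = x.detach().requires_grad_(True)
    return x.detach()

def jem_loss(model, x, y):
    # Discriminative term: standard cross-entropy for log p(y | x).
    ce = F.cross_entropy(model(x), y)
    # Generative term: contrastive-divergence-style estimate of -log p(x),
    # using SGLD samples (initialized here from noise) as negative examples.
    x_neg = sgld_sample(model, torch.rand_like(x))
    gen = -(log_px_unnorm(model, x).mean() - log_px_unnorm(model, x_neg).mean())
    return ce + gen
```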
Results
Hybrid Modelling

Datasets: CIFAR10, CIFAR100, SVHN.

Metrics: Inception Score, Frechet Inception Distance

JEM outperforms generative, discriminative, and hybrid models on both generative and discriminative tasks.
Calibration

A calibrated classifier is one whose predictive confidence aligns with its misclassification rate.

Dataset: CIFAR100

JEM improves calibration while retaining high accuracy.
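For reference, a minimal sketch of expected calibration error (ECE), the metric used to quantify calibration; the bin count and the use of the maximum softmax probability as the confidence are common-practice assumptions rather than anything specific to JEM:

```python
import torch

def expected_calibration_error(confidences, predictions, labels, n_bins=20):
    # confidences: max softmax probability per example,
    # predictions: argmax class per example, labels: ground truth.
    # Bucket examples by confidence and compare average confidence
    # against empirical accuracy in each bucket.
    bins = torch.linspace(0.0, 1.0, n_bins + 1)
    ece = torch.zeros(1)
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            acc = (predictions[in_bin] == labels[in_bin]).float().mean()
            conf = confidences[in_bin].mean()
            ece += in_bin.float().mean() * (acc - conf).abs()
    return ece.item()
```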
Out of Distribution (OOD) Detection

One way to detect OOD samples is to learn a density model that assigns a higher likelihood to in-distribution examples and a lower likelihood to out-of-distribution examples.

JEM consistently assigns higher likelihoods to in-distribution examples than to OOD examples.

The paper also proposes an alternative metric, called approximate mass, for detecting OOD examples.

The intuition is that a point could have high likelihood yet be nearly impossible to sample because its surroundings have very low density.

On the other hand, in-distribution data points lie in regions of high probability mass.

Hence, the norm of the gradient of the log-density, $\lVert \nabla_x \log p_{\theta}(x) \rVert$, provides a useful signal for detecting OOD examples.
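A sketch of the two OOD scores discussed above, assuming `model(x)` returns the logits; higher scores should indicate in-distribution inputs:

```python
import torch

def log_density_score(model, x):
    # Unnormalized log p(x): LogSumExp over the logits (log Z is constant
    # across inputs, so it does not affect the ranking).
    return torch.logsumexp(model(x), dim=1)

def approximate_mass_score(model, x):
    # Negative norm of the input-gradient of the log-density: in-distribution
    # points sit in high-mass regions where this gradient is small, so a
    # larger (less negative) score means "more in-distribution".
    x = x.clone().detach().requires_grad_(True)
    log_px = torch.logsumexp(model(x), dim=1)
    grad = torch.autograd.grad(log_px.sum(), x)[0]
    return -grad.flatten(start_dim=1).norm(p=2, dim=1)
```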
Robustness
JEM is more robust to adversarial attacks than purely discriminative classifiers.
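A minimal sketch of an $L_\infty$ projected gradient descent (PGD) attack, the standard way such robustness claims are probed; the epsilon, step size, and number of steps here are illustrative choices, not the paper's exact evaluation settings:

```python
import torch
import torch.nn.functional as F

def pgd_attack(model, x, y, epsilon=8 / 255, alpha=2 / 255, steps=10):
    # Iteratively perturb x to maximize the classification loss while staying
    # inside an epsilon-ball (in the L-infinity norm) around the original input.
    x_adv = x.clone().detach()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        # Ascend the loss, then project back into the epsilon-ball
        # and the valid pixel range.
        x_adv = x_adv.detach() + alpha * grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
        x_adv = x_adv.clamp(0.0, 1.0)
    return x_adv
```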