When Recurrent Models Don’t Need To Be Recurrent
04 Oct 2018
Introduction
- The paper explores “if a well behaved RNN can be replaced by a feed-forward network of comparable size without loss in performance.”
- “Well behaved” is defined in terms of the control-theoretic notion of stability, which roughly requires that the gradients do not explode over time.
- The paper shows that, under the stability assumption, feedforward networks can approximate RNNs for both training and inference. The results are empirically validated as well.
Problem Setting
- Consider a general, non-linear dynamical system given by a differentiable state-transition map Φ_w. The hidden state evolves as h_t = Φ_w(h_{t-1}, x_t).
- Assumptions:
  - Φ_w is smooth in w and h.
  - h_0 = 0
  - Φ_w(0, 0) = 0 (can be ensured by translation)
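As a concrete illustration (my own sketch, not the paper's code), here is a minimal state-transition map of this form, assuming a vanilla tanh RNN Φ_w(h, x) = tanh(W·h + U·x) with zero bias so that Φ_w(0, 0) = 0; the names and dimensions are arbitrary choices.

```python
import numpy as np

# Sketch of a state-transition map Phi_w(h, x) = tanh(W h + U x): smooth in
# (w, h), with h_0 = 0 and, since the bias is zero, Phi_w(0, 0) = 0.
# (Illustration only, not the paper's code.)
rng = np.random.default_rng(0)
d_h, d_x = 8, 4                      # hidden and input dimensions (arbitrary)
W = rng.standard_normal((d_h, d_h)) / np.sqrt(d_h)
U = rng.standard_normal((d_h, d_x)) / np.sqrt(d_x)

def phi(h, x):
    """State-transition map: h_t = tanh(W h_{t-1} + U x_t)."""
    return np.tanh(W @ h + U @ x)

def recurrent_state(xs):
    """Unroll the recurrence from h_0 = 0 over an input sequence xs."""
    h = np.zeros(d_h)
    for x in xs:
        h = phi(h, x)
    return h

xs = rng.standard_normal((20, d_x))  # a random length-20 input sequence
h_T = recurrent_state(xs)
```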
- Stable models are the ones where Φ_w is contractive, i.e., ||Φ_w(h, x) - Φ_w(h', x)|| ≤ Λ * ||h - h'|| for some Λ < 1.
- For example, in an RNN, stability would require that the norm of the recurrent weight matrix W satisfies ||W|| < 1/L_p, where L_p is the Lipschitz constant of the point-wise non-linearity used.
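This condition is easy to check numerically. The sketch below assumes a tanh non-linearity (for which L_p = 1) and takes ||W|| to be the spectral norm; the rescaling step is one simple way to enforce the condition, not a prescription from the paper.

```python
import numpy as np

# Check the stability condition ||W|| < 1 / L_p for a tanh RNN (L_p = 1).
# (Illustration only.)
rng = np.random.default_rng(0)
d_h = 8
W = rng.standard_normal((d_h, d_h)) / np.sqrt(d_h)

L_p = 1.0                                   # Lipschitz constant of tanh
spec_norm = np.linalg.norm(W, ord=2)        # largest singular value of W
print(f"||W|| = {spec_norm:.3f} -> stable: {spec_norm < 1.0 / L_p}")

# A simple way to enforce the condition: rescale W back into the stable region.
if spec_norm >= 1.0 / L_p:
    W *= 0.99 / (L_p * spec_norm)
```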
- The feedforward approximation is a truncated model that uses only a finite context of length k, i.e., the last k inputs of the sequence.
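Here is what this truncation looks like for a tanh RNN (my own sketch; k, the dimensions, and the rescaling of W are arbitrary choices, with ||W|| set below 1 so the system is stable):

```python
import numpy as np

# Full recurrent model vs. its depth-k truncated feedforward approximation.
# (Illustration only, not the paper's code.)
rng = np.random.default_rng(0)
d_h, d_x, k = 8, 4, 32
A = rng.standard_normal((d_h, d_h))
W = 0.8 * A / np.linalg.norm(A, ord=2)       # rescale so ||W|| = 0.8 < 1 (stable)
U = rng.standard_normal((d_h, d_x)) / np.sqrt(d_x)

def phi(h, x):
    return np.tanh(W @ h + U @ x)

def recurrent_state(xs):
    """Full recurrent model: unroll over the entire input sequence."""
    h = np.zeros(d_h)
    for x in xs:
        h = phi(h, x)
    return h

def truncated_state(xs, k):
    """Truncated model: a depth-k feedforward network over the last k inputs."""
    h = np.zeros(d_h)
    for x in xs[-k:]:
        h = phi(h, x)
    return h

xs = rng.standard_normal((100, d_x))
gap = np.linalg.norm(recurrent_state(xs) - truncated_state(xs, k))
print(f"||h_T - h_T^k|| = {gap:.2e}")
```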
- A non-parametric function f maps the output (final hidden state) of the recurrent model to the prediction. If f is desired to be a parametric model, its parameters can be pushed into the recurrent model.
Theoretical Results
- For a Λ-contractive system, it can be proved that, for a sufficiently large k (and under additional Lipschitz assumptions), the difference in prediction between the recurrent and the truncated model is negligible.
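A small numerical illustration of this (my own sketch, not from the paper): a scalar linear recurrence h_t = λ·h_{t-1} + x_t is Λ-contractive with Λ = |λ| < 1, and the gap between the full and truncated states shrinks at roughly the geometric rate Λ^k.

```python
import numpy as np

# Truncation error of a Lambda-contractive scalar system. (Illustration only.)
rng = np.random.default_rng(0)
T, lam = 200, 0.8                 # Lambda-contractive since |lam| < 1
xs = rng.standard_normal(T)

def final_state(window, lam):
    """Run h_t = lam * h_{t-1} + x_t from h = 0 over the given inputs."""
    h = 0.0
    for x in window:
        h = lam * h + x
    return h

h_full = final_state(xs, lam)
for k in (5, 10, 20, 40):
    h_trunc = final_state(xs[-k:], lam)   # truncated: last k inputs only
    gap = abs(h_full - h_trunc)           # decays roughly like Lambda^k
    print(f"k = {k:2d}: gap = {gap:.2e}   (Lambda^k = {lam**k:.2e})")
```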
- If the recurrent model and the truncated feedforward network are initialized at the same point and trained on the same inputs for N steps, then for a suitably large k, the weights of the two models remain very close in Euclidean distance. It can also be shown that this small difference does not lead to large gradient differences during subsequent update steps.
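For intuition about why the gradients stay close, here is a sketch (my own illustration, again using the scalar linear recurrence, with forward-mode accumulation of the derivative) comparing dh_T/dλ for the full model with that of the truncated model:

```python
import numpy as np

# Compare the parameter gradient of the full vs. truncated model.
# (Illustration only, not from the paper.)
rng = np.random.default_rng(0)
T, k, lam = 200, 30, 0.7
xs = rng.standard_normal(T)

def state_and_grad(window, lam):
    """Forward-mode accumulation of h_T and dh_T/dlam for h_t = lam*h_{t-1} + x_t."""
    h, g = 0.0, 0.0
    for x in window:
        g = h + lam * g       # d/dlam of (lam * h_{t-1} + x_t), product rule
        h = lam * h + x
    return h, g

h_full, g_full = state_and_grad(xs, lam)         # full recurrent model
h_trunc, g_trunc = state_and_grad(xs[-k:], lam)  # truncated, last k inputs only

print(f"|h_full - h_trunc| = {abs(h_full - h_trunc):.2e}")
print(f"|g_full - g_trunc| = {abs(g_full - g_trunc):.2e}")
```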
- This can be roughly interpreted as follows: if gradient descent can train a stable recurrent network, it can also train the corresponding feedforward model, and vice versa.
- The stability condition is important: without it, the truncated model can be a poor approximation (even for large values of k), and it is difficult to show that gradient descent converges to a stationary point.
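To see why truncation fails without stability, the same scalar recurrence with λ > 1 (my own sketch) keeps a large gap between the full and truncated states no matter how large k grows (short of the full sequence length), because the influence of early inputs grows instead of decaying:

```python
import numpy as np

# Truncation error of a non-contractive scalar system. (Illustration only.)
rng = np.random.default_rng(0)
T, lam = 60, 1.1                  # |lam| > 1: not contractive, not stable
xs = rng.standard_normal(T)

def final_state(window, lam):
    h = 0.0
    for x in window:
        h = lam * h + x
    return h

h_full = final_state(xs, lam)
for k in (10, 30, 50):
    gap = abs(h_full - final_state(xs[-k:], lam))
    print(f"k = {k:2d}: gap = {gap:.2e}")   # stays large; early inputs still matter
```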