4 min read

Differentiable Memory

Introduction

Before deep learning was a thing, if you had \(N\) vectors and you wanted to match the closest one to some noisy version, you would use \(k\)NN (\(k\) nearest neighbours). This means comparing the input to all candidates and usually weighting the final answer by the distance \(d_i\) to each of the top \(k\) matches. If \(k=N\) and the weighting has the form \(\exp(-d_i / \rho)\), this becomes the Nadaraya–Watson kernel regressor, which turns out to have the same form as the venerable “attention” mechanism from deep learning.
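The kernel regressor can be sketched in a few lines of JAX. The data here is synthetic and the dimensions (\(N=100\), \(m=8\)) are arbitrary, just to show the shape of the computation; note that \(\mathrm{softmax}(-d/\rho)\) is exactly the normalised \(\exp(-d_i/\rho)\) weighting:

```python
import jax, jax.numpy as jnp

X = jax.random.normal(jax.random.PRNGKey(0), (100, 8))   # N stored vectors
Y = jnp.sin(X.sum(axis=1, keepdims=True))                # values to regress
q = X[0] + 0.1 * jax.random.normal(jax.random.PRNGKey(1), (8,))  # noisy query

d = jnp.linalg.norm(X - q, axis=1)   # distances from the query to all N vectors
w = jax.nn.softmax(-d / 1.0)         # exp(-d_i / rho) weights, normalised (rho = 1)
y_hat = w @ Y                        # weighted average of the stored values
```

Swapping the negative Euclidean distance for a dot product gives the attention read-out used below.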

In this post I use “attention” as a standalone component to create a differentiable lossy key-value store, sometimes termed “differentiable memory”. I’ll start by explaining what it is from the perspective of linear algebra rather than neural networks, and then put together a little prototype in JAX/Optax to demonstrate how it all hangs together.

Differentiable Memory

Let \(k \in \mathbb{R}^m\) be keys and \(v \in \mathbb{R}^n\) be values. We have a set of key-value pairs \(\{(k_i, v_i)\}_{i=1}^N\) which we’d like to approximately compress given some fixed storage budget. Let \(K \in \mathbb{R}^{\bar{N} \times m}\), \(V \in \mathbb{R}^{\bar{N} \times n}\) represent the key-value prototypes, where \(\bar{N} < N\) means compression. Given a query key \(q\), the prediction is \(\bar{v}(q) = V^\top \mathrm{softmax}(Kq)\); note that \(Kq\) is a dot-product similarity. The objective is to minimise \(\| v_i - \bar{v}(k_i) \|^2\) over all pairs \(\{(k_i, v_i)\}_{i=1}^N\).
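Concretely, the read operation is just two matrix products with a softmax in between. A minimal sketch with random prototypes, where the sizes (\(\bar{N}=16\), \(m=8\), \(n=4\)) are arbitrary:

```python
import jax, jax.numpy as jnp

Nbar, m, n = 16, 8, 4                                     # prototypes, key dim, value dim
K = jax.random.normal(jax.random.PRNGKey(0), (Nbar, m))   # key prototypes
V = jax.random.normal(jax.random.PRNGKey(1), (Nbar, n))   # value prototypes
q = jax.random.normal(jax.random.PRNGKey(2), (m,))        # query key

w = jax.nn.softmax(K @ q)   # dot-product similarities -> convex weights
v_bar = V.T @ w             # \bar{v}(q): weighted combination of value prototypes
```

Because the weights sum to one, the prediction is always a convex combination of the stored value prototypes.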

Let \(Q \in \mathbb{R}^{N \times m}\) and \(X \in \mathbb{R}^{N \times n}\) be the stacked keys and values we intend to fit. We can then write the weights as \(W = \mathrm{softmax}_{row}(QK^{\top})\) and the stacked predictions as \(\bar{X} = WV\). Hence, what we need is \(X \approx WV\).

Given \(K\), \(W\) is fixed and we end up with a standard linear least-squares problem with a closed-form solution. Setting \(X = WV\), then \(W^\top X = W^\top W V\), and finally \(V=(W^\top W)^{-1}W^\top X\). We can regularise this in the usual way by adding a ridge penalty, so that \(V=(W^\top W + \lambda I)^{-1}W^\top X\). Meanwhile, \(K\) sits inside the softmax, so the problem is non-linear in \(K\) and has no closed-form solution; we need an iterative method. However, at each step of optimising \(K\) we can solve for the optimal \(V\) in closed form.
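As a sanity check on the ridge formula, the gradient of \(\tfrac12\|X - WV\|^2 + \tfrac{\lambda}{2}\|V\|^2\) should vanish at the closed-form solution. A sketch with random \(W\) and \(X\) (sizes here are arbitrary):

```python
import jax, jax.numpy as jnp

N, Nbar, n = 50, 8, 3
W = jax.nn.softmax(jax.random.normal(jax.random.PRNGKey(0), (N, Nbar)), axis=1)
X = jax.random.normal(jax.random.PRNGKey(1), (N, n))

lam = 1e-6
V = jnp.linalg.solve(W.T @ W + lam * jnp.eye(Nbar), W.T @ X)  # ridge normal equations

# Gradient of the ridge objective at the solution: should be ~0.
g = W.T @ (W @ V - X) + lam * V
```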

Once we’ve fitted \((K,V)\) we end up with \(\bar{N}\) learned prototype values, gated by a dot-product softmax over \(\bar{N}\) learned keys: a powerful model.

Fitting

Let’s write down the optimisation with JAX and Optax. I’ll use Adam to update \(K\) and solve for \(V\) in closed form.

import jax, jax.numpy as jnp
import optax
from functools import partial

def loss(K, V, Q, X):
    # Squared reconstruction error of the attention read-out.
    A = jax.nn.softmax(Q @ K.T, axis=1)
    return 0.5 * jnp.sum((X - A @ V) ** 2)

def up_V(K, Q, X, lam=1e-6):
    # Closed-form ridge solution: V = (A^T A + lam I)^{-1} A^T X.
    A = jax.nn.softmax(Q @ K.T, axis=1)
    M = A.shape[1]
    return jnp.linalg.solve(A.T @ A + lam * jnp.eye(M), A.T @ X)

grad_K = jax.grad(loss, 0)

@partial(jax.jit, static_argnums=3)  # opt is a GradientTransformation: static
def step(K, V, state, opt, Q, X):
    V          = up_V(K, Q, X)       # optimal V for the current K
    ups, state = opt.update(grad_K(K, V, Q, X), state, K)
    return optax.apply_updates(K, ups), V, state

def fit(Q, X, N=1000, iters=1000, lr=0.001):
    m, n  = Q.shape[1], X.shape[1]
    K     = 0.01 * jax.random.normal(jax.random.PRNGKey(0), (N, m))
    V     = jnp.zeros((N, n))
    opt   = optax.adam(lr)
    state = opt.init(K)
    for _ in range(iters):
        K, V, state = step(K, V, state, opt, Q, X)
    return (lambda Q_: jax.nn.softmax(Q_ @ K.T, axis=1) @ V), K, V

Let’s fit something to try it out.

An example

Fashion-MNIST images are 784-pixel (28×28) grayscale patches of fashion items. Let’s randomly sample 78 pixel locations from each image to use as the key, and try to recall the rest. Here I train on 10k examples and test on the remaining 60k, since my battered old laptop couldn’t handle more than that. The number of iterations and the learning rate are held fixed; the point is just to see that the loss goes down as memory capacity (\(\bar{N}\), the code’s N) goes up.

from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt

SAMP = 10000
idx  = jax.random.choice(
    jax.random.PRNGKey(1), 784, (78,), replace=False)
X, y = fetch_openml(
    "Fashion-MNIST", version=1, return_X_y=True, as_frame=False)
X    = jnp.array(X, dtype=jnp.float32) / 255.0
P    = jax.random.permutation(jax.random.PRNGKey(0), X.shape[0])
X    = X[P]
Q    = X[:, idx]
Ns   = [2, 4, 8, 16, 32, 64, 128, 256]
errs = []

for N in Ns:
    f,_,_ = fit(Q[:SAMP], X[:SAMP], N=N, iters=1000, lr=0.0005)
    err   = jnp.sqrt(jnp.mean((X[SAMP:]-f(Q[SAMP:]))**2))
    errs.append(float(err))
    print(f"N={N:4d}  rmse={err:.6f}")

plt.plot(Ns, errs, marker="o")
plt.xlabel("N")
plt.ylabel("RMSE")
plt.grid(True)
plt.show()

The output:

>> N=   2 rmse=0.248782
>> N=   4 rmse=0.204451
>> N=   8 rmse=0.173010
>> N=  16 rmse=0.153833
>> N=  32 rmse=0.139799
>> N=  64 rmse=0.128507
>> N= 128 rmse=0.121859
>> N= 256 rmse=0.119588

More parameters, more memory, lower error; as expected.

Why is it useful

Differentiable memory can learn non-linear structure connecting keys to values and so represent complicated relationships. It is smooth and continuous, making it useful as a general interpolation device. One standalone use case is surrogate modelling, where differentiable memory compresses and approximates the responses of a more complex model. Another, when the value vectors are categorical, is to use differentiable memory as a kernel density estimator, mapping a complex key space onto a categorical range and letting the inevitable overlaps express category densities.
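A sketch of the categorical case: if the values are one-hot class labels, the read-out is a convex combination of one-hot vectors, i.e. a probability distribution over classes. Everything here is synthetic (32 prototypes, 8-dimensional keys, 5 classes are arbitrary choices):

```python
import jax, jax.numpy as jnp

Nbar, m, C = 32, 8, 5                                     # prototypes, key dim, classes
K = jax.random.normal(jax.random.PRNGKey(0), (Nbar, m))   # key prototypes
labels = jax.random.randint(jax.random.PRNGKey(1), (Nbar,), 0, C)
V = jax.nn.one_hot(labels, C)                             # one-hot class values

q = jax.random.normal(jax.random.PRNGKey(2), (m,))        # query key
p = V.T @ jax.nn.softmax(K @ q)                           # class probabilities for q
```

Since the softmax weights are non-negative and sum to one, `p` is automatically a valid distribution over the \(C\) classes.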