Metric learning with linear methods

A while ago I read this paper, which sets out the problem of linear metric learning nicely, and I wanted to see whether metric learning could be carried out in closed form. It turned out to be relatively straightforward.

Say we have some feature vectors \(x_i \in \mathbb{R}^p\) and some responses \(y_i \in \mathbb{R}^k\). For each pair of observations \((i, j)\) we want:

\[(Ax_i - Ax_j)^\top (Ax_i - Ax_j) \approx (y_i -y_j)^\top (y_i -y_j)\]

where \(A\) is a \(p \times p\) matrix. That is, we look for a transformation \(A\) under which the squared Euclidean distance between observations \(i\) and \(j\) is as close as possible whether it is calculated from the transformed features \(Ax\) or from the responses \(y\). I'll call this act of changing our feature vectors so that their distances better match those of our responses “metric learning”. It is interesting because many algorithms implicitly rely on similarity between feature vectors (e.g. kernel methods, kNN), so they can be improved by making the distance metric more relevant to the problem at hand.
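To make the objective concrete, here is a minimal sketch of the quantity being minimised: the sum, over sampled pairs, of the squared mismatch between the two squared distances. The helper name metric_loss() and the toy data are my own illustration, not something from the paper. It assumes the responses are an exact linear map \(M\) of the features, in which case \(A = M\) gives zero loss.

```r
# Hypothetical helper (not from the post): the metric-learning objective,
# summing, over pairs (i, j), the squared mismatch between the squared
# distance under A and the squared distance between the responses.
metric_loss <- function(A, X, Y, I, J) {
  D  <- X[I, , drop = FALSE] - X[J, , drop = FALSE]               # rows are (x_i - x_j)^T
  dx <- rowSums((D %*% t(A))^2)                                    # ||A (x_i - x_j)||^2
  dy <- rowSums((Y[I, , drop = FALSE] - Y[J, , drop = FALSE])^2)  # ||y_i - y_j||^2
  sum((dx - dy)^2)
}

# Toy data: three points in R^2 whose responses are an exact linear map M.
X <- matrix(c(0, 1, 3,
              0, 2, 1), ncol = 2)
M <- matrix(c(2, 0,
              1, 1), nrow = 2, byrow = TRUE)
Y <- X %*% t(M)                      # row i is (M x_i)^T
I <- c(1, 1, 2)
J <- c(2, 3, 3)                      # all three pairs

metric_loss(M, X, Y, I, J)           # 0: A = M matches every distance
metric_loss(diag(2), X, Y, I, J)     # positive: the identity does not
```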

We can rewrite the left hand side as follows:

\[(Ax_i - Ax_j)^\top (Ax_i - Ax_j) = (x_i-x_j)^\top A^\top A(x_i-x_j)\]
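A quick numerical check of this identity, with an arbitrary random \(A\) and pair of vectors (the seed and dimension are arbitrary choices for illustration):

```r
# Numerical check of the identity above with an arbitrary A and pair (x_i, x_j).
set.seed(42)
p   <- 4
A   <- matrix(rnorm(p * p), nrow = p)
xi  <- rnorm(p)
xj  <- rnorm(p)
lhs <- sum((A %*% xi - A %*% xj)^2)                    # (A x_i - A x_j)^T (A x_i - A x_j)
rhs <- drop(t(xi - xj) %*% t(A) %*% A %*% (xi - xj))   # (x_i - x_j)^T A^T A (x_i - x_j)
all.equal(lhs, rhs)                                    # TRUE, up to rounding
```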

This allows us to substitute \(B = A^\top A\) and then solve for \(B\) in a problem of the form \(v_k^\top B v_k = \bar{y}_k\), where, for the \(k\)-th pair \((i, j)\), \(v_k = x_i - x_j\) and \(\bar{y}_k = (y_i - y_j)^\top (y_i - y_j)\). My initial thought was to use the Kronecker product identity \(\bar{y}_k = \operatorname{vec}(v_k^\top B v_k) = (v_k^\top \otimes v_k^\top)\operatorname{vec}(B)\), solve for \(\operatorname{vec}(B)\) by least squares, and reshape the result into \(B\). However, \(B\) must be symmetric, and this formulation does not respect that constraint: the columns multiplying \(b_{i,j}\) and \(b_{j,i}\) are identical, so the problem is rank-deficient and nothing forces the solution to be symmetric.
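The Kronecker formulation and its rank deficiency are easy to see numerically. This sketch (random data, illustrative only) checks the identity with base R's kronecker(), where as.vector() gives the column-major \(\operatorname{vec}\), and then shows that stacking such rows yields \(p^2\) columns but only \(p(p+1)/2\) independent ones.

```r
# The identity vec(v^T B v) = (v^T (x) v^T) vec(B), checked on random data.
set.seed(1)
p    <- 3
v    <- rnorm(p)
B    <- crossprod(matrix(rnorm(p * p), p))              # a symmetric PSD matrix
quad <- drop(t(v) %*% B %*% v)
kron <- drop(kronecker(t(v), t(v)) %*% as.vector(B))    # as.vector() is column-major vec
all.equal(quad, kron)                                   # TRUE

# Stack such rows for 50 random difference vectors: p^2 = 9 columns, but
# the columns for b[i, j] and b[j, i] are identical (v_i v_j = v_j v_i).
Vk <- t(replicate(50, { v <- rnorm(p); drop(kronecker(t(v), t(v))) }))
dim(Vk)                                                 # 50 x 9
qr(Vk)$rank                                             # 6 = p (p + 1) / 2, not 9
```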

Instead, consider the case where \(B\) is a \(2\times 2\) matrix. If we write out each row by hand, with parameters only for the upper triangle because of symmetry, we get \(b_{1,1}v_1^2 + b_{2,2}v_2^2 + 2 b_{1,2}v_1 v_2\), which equals \([v_1^2, v_2^2, 2v_1 v_2] \, [b_{1,1}, b_{2,2}, b_{1,2}]^\top\). This generalises to a square \(B\) of any dimension \(p\), with \(p(p+1)/2\) parameters: the \(p\) squared terms \(v_r^2\) and, for each \(r < s\), a cross term \(2 v_r v_s\). Stacking one such row per pair \(v_k\) to form \(V\), we can write \(Vb = \bar{y}\), solve for the vector \(b\) by ordinary least squares, \(b = (V^\top V)^{-1} V^\top \bar{y}\), and then reshape \(b\) into \(B\). Since each off-diagonal parameter fills both \(b_{r,s}\) and \(b_{s,r}\), \(B\) is symmetric by construction.
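The general expansion can be sketched as a pair of helpers: one that maps a difference vector to its row of \(V\), and one that refills a symmetric matrix from \(b\). The names upper_pairs(), expand_pair() and to_sym() are hypothetical, chosen for illustration; the ordering (diagonal first, then \(r < s\)) mirrors the \(2 \times 2\) example.

```r
# Hypothetical helpers (names are illustrative, not from the post).
# upper_pairs(p): index pairs (r, s), diagonal first, then r < s.
upper_pairs <- function(p) {
  rbind(cbind(1:p, 1:p), which(upper.tri(diag(p)), arr.ind = TRUE))
}
# expand_pair(v): the row [v_1^2, ..., v_p^2, 2 v_r v_s for r < s].
expand_pair <- function(v) {
  rs  <- upper_pairs(length(v))
  scl <- ifelse(rs[, 1] == rs[, 2], 1, 2)
  scl * v[rs[, 1]] * v[rs[, 2]]
}
# to_sym(b, p): fill both triangles of a p x p matrix from b.
to_sym <- function(b, p) {
  rs <- upper_pairs(p)
  B  <- matrix(0, p, p)
  B[rs] <- b
  B[rs[, 2:1]] <- b
  B
}

# The 2 x 2 case from the text.
v <- c(3, -1)
B <- matrix(c(2, 0.5,
              0.5, 1), nrow = 2)
b <- c(B[1, 1], B[2, 2], B[1, 2])   # [b_11, b_22, b_12]
sum(expand_pair(v) * b)             # 16
drop(t(v) %*% B %*% v)              # 16
length(expand_pair(rnorm(10)))      # 55 = 10 * 11 / 2 parameters
```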

Recall that \(B = A^\top A\), so \(v^\top B v = \lVert Av \rVert^2 \ge 0\) for every \(v\): \(B\) must be positive semi-definite (positive definite if \(A\) has full column rank). However, the OLS solution for \(b\) is only constrained to be symmetric; nothing guarantees that it is positive semi-definite. So we project \(B\) onto the closest positive semi-definite matrix (closest in Frobenius norm), which we can do using the eigen-decomposition \(B = Q\Delta Q^\top\), where \(\Delta\) is a diagonal matrix of eigenvalues, by setting the negative eigenvalues in \(\Delta\) to zero. Finally, we recover \(A\) as a square-root factor of \(B\): \(A = \sqrt{\Delta} Q^\top\) satisfies \(A^\top A = Q\Delta Q^\top = B\). This \(A\) is unique only up to an orthogonal rotation \(R\), since \((RA)^\top (RA) = A^\top A\).
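The projection and factorisation step can be sketched on its own, here applied to a small symmetric but indefinite matrix chosen for illustration (psd_factor() is a hypothetical name). The last lines show that rotating \(A\) leaves \(A^\top A\) unchanged.

```r
# Sketch of the post-processing step (psd_factor() is an illustrative name).
psd_factor <- function(B) {
  eig <- eigen(B, symmetric = TRUE)
  d   <- pmax(eig$values, 0)                          # clip negative eigenvalues to zero
  diag(sqrt(d), nrow = length(d)) %*% t(eig$vectors)  # A = sqrt(Delta) Q^T
}

# A symmetric but indefinite B: eigenvalues (3 +/- sqrt(37)) / 2, one negative.
B     <- matrix(c(2, 3,
                  3, 1), nrow = 2)
A     <- psd_factor(B)
B_psd <- crossprod(A)                                 # t(A) %*% A, the clipped B
eigen(B_psd, symmetric = TRUE)$values                 # (3 + sqrt(37)) / 2 and ~0

# Any rotation R of A yields the same B, so A is identified only up to R.
R <- matrix(c(cos(1), sin(1),
             -sin(1), cos(1)), nrow = 2)
all.equal(crossprod(R %*% A), B_psd)                  # TRUE
```

If \(B\) is already positive semi-definite, nothing is clipped and crossprod(psd_factor(B)) returns \(B\) itself.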

Let's test it in R. First I'll generate some data and transform \(X\) with a random matrix \(\bar{A}\) (A_ in the code) to form the responses \(Y\).

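P  <- 10                       # number of features
N  <- 100                      # number of observations
# Feature j ~ N(j, j); byrow=TRUE makes the means and variances vary by column
X  <- matrix(rnorm(P*N, mean=1:P, sd=sqrt(1:P)), nrow=N, byrow=TRUE)
A_ <- matrix(rnorm(P*P, mean=0, sd=10), nrow=P)   # random "true" transform
Y  <- X %*% A_                 # responses: row i is x_i' A_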

Next, I need the indices of some sampled pairs \((i, j)\). The squared error over these pairs is what the optimisation minimises.

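# Sample 2500 candidate pairs and keep those with i > j, which drops
# self-pairs and keeps one orientation per pair
I <- matrix(sample(1:N, 5000, replace=TRUE), ncol=2)
I <- I[I[,1] > I[,2],]
J <- I[,2]
I <- I[,1]
y <- rowSums((Y[I,] - Y[J,])^2)   # squared distances between responses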

I then build the design matrix \(V\) as described above, solve for \(b\) by least squares, and recover the transformation \(A\):

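# Parameter layout: the P diagonal entries first, then the upper triangle (r < s)
idx <- expand.grid(1:P, 1:P)
idx <- rbind(idx[idx$Var1 == idx$Var2,],
             idx[idx$Var1 < idx$Var2,])
scl <- c(rep(1,P), rep(2, NROW(idx)-P))   # cross terms appear twice in v'Bv
# Design matrix V: one row per pair, entries scl * v[r] * v[s]
# (the \(x) lambda syntax needs R >= 4.1)
V   <- t(apply(X[I,]-X[J,], 1,
               \(x) scl*x[idx$Var1]*x[idx$Var2]))
# OLS via the normal equations: b = (V'V)^{-1} V'y
a   <- solve(crossprod(V), crossprod(V, y))
ATA <- matrix(0, ncol=P, nrow=P)

# Write each parameter into both (r, s) and (s, r), so B is symmetric by construction
ATA[cbind(idx$Var1, idx$Var2)] <- a
ATA[cbind(idx$Var2, idx$Var1)] <- a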

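# Project B onto the PSD cone by clipping negative eigenvalues to zero
eig <- eigen(ATA, symmetric=TRUE)
Q   <- eig$vectors
E   <- pmax(eig$values, 0)
A   <- diag(E^0.5) %*% t(Q)   # A = sqrt(Delta) Q', so t(A) %*% A = Q Delta Q'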

Finally, I compare the squared distances between the responses with those in the transformed feature space by summing the squared differences:

# Squared distances in the transformed feature space: ||A (x_i - x_j)||^2
y_ <- apply(X[I,]-X[J,], 1, \(x) sum((A %*% x)^2))

print(sum((y - y_)^2))   # sum of squared differences

The result is zero up to floating-point error: as required, \(A\) reproduces the pairwise squared distances between the responses.
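For reference, here is a compact, self-contained sketch that wraps the steps above into one hypothetical function, fit_metric() (my name, not the post's). Two things differ from the code above, and both are my assumptions: it uses qr.solve(), which computes the same least-squares solution via a QR decomposition rather than the normal equations, and it samples distinct pairs with combn(). It then checks the fitted \(A^\top A\) against the \(B\) implied by the simulation: with Y = X %*% A_, each \(y_i - y_j\) equals \(\bar{A}^\top (x_i - x_j)\), so the true \(B\) is \(\bar{A}\bar{A}^\top\). I compare \(B\) rather than \(A\) itself because \(A\) is only determined up to a rotation.

```r
# Hypothetical end-to-end sketch: closed-form linear metric learning.
fit_metric <- function(X, Y, I, J) {
  P   <- ncol(X)
  ab  <- rbind(cbind(1:P, 1:P), which(upper.tri(diag(P)), arr.ind = TRUE))
  scl <- ifelse(ab[, 1] == ab[, 2], 1, 2)              # cross terms appear twice
  D   <- X[I, , drop = FALSE] - X[J, , drop = FALSE]   # rows are (x_i - x_j)^T
  V   <- sweep(D[, ab[, 1], drop = FALSE] * D[, ab[, 2], drop = FALSE], 2, scl, `*`)
  y   <- rowSums((Y[I, , drop = FALSE] - Y[J, , drop = FALSE])^2)
  b   <- qr.solve(V, y)                                # least squares via QR
  B   <- matrix(0, P, P)
  B[ab] <- b                                           # fill both triangles
  B[ab[, 2:1]] <- b
  eig <- eigen(B, symmetric = TRUE)                    # clip to PSD, then factor
  diag(sqrt(pmax(eig$values, 0)), nrow = P) %*% t(eig$vectors)
}

set.seed(123)
P  <- 5
N  <- 200
X  <- matrix(rnorm(N * P), nrow = N)
A_ <- matrix(rnorm(P * P), nrow = P)
Y  <- X %*% A_
ij <- t(combn(N, 2))[sample(choose(N, 2), 1000), ]     # 1000 distinct pairs
A  <- fit_metric(X, Y, ij[, 1], ij[, 2])
all.equal(crossprod(A), A_ %*% t(A_))                  # TRUE: B is recovered
```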