Created
October 9, 2019 12:31
-
-
Save mkocabas/54ea2ff3b03260e3fedf8ad22536f427 to your computer and use it in GitHub Desktop.
PyTorch batch Procrustes implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
def compute_similarity_transform(S1, S2):
    """Align point set S1 to S2 with the optimal similarity transform.

    Solves the orthogonal Procrustes problem: finds the scale ``s``,
    rotation ``R`` (constrained to ``det(R) = 1``) and translation ``t``
    that bring ``s * R @ S1 + t`` closest to ``S2`` in the least-squares
    sense, and returns the transformed ``S1``.

    Args:
        S1: numpy array of points to align, shaped (D, N) with D in {2, 3},
            or (N, D).  Any input whose first dimension is neither 2 nor 3
            is treated as (N, D) and transposed internally; an (N, D) input
            with N in {2, 3} is therefore read as (D, N).
        S2: numpy array of target points, same layout and shape as ``S1``.

    Returns:
        numpy array with the aligned copy of ``S1`` in the caller's layout.
        If every point of ``S1`` coincides (including the single-point
        case), every output point equals the centroid of ``S2``.

    Raises:
        AssertionError: if ``S1`` and ``S2`` hold a different number of
            points.
    """
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.T
        S2 = S2.T
        transposed = True
    assert S2.shape[1] == S1.shape[1], \
        "S1 and S2 must contain the same number of points"

    # 1. Remove mean.
    mu1 = S1.mean(axis=1, keepdims=True)
    mu2 = S2.mean(axis=1, keepdims=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = np.sum(X1 ** 2)

    # 3. The outer product of X1 and X2.
    K = X1.dot(X2.T)

    # 4. Solution that maximizes trace(R'K) is R = U*V', where U, V are
    #    singular vectors of K.
    U, _, Vh = np.linalg.svd(K)
    V = Vh.T
    # Construct Z that fixes the orientation of R to get det(R) = 1
    # (flips the last singular direction when U*V' is a reflection).
    Z = np.eye(U.shape[0])
    Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
    R = V.dot(Z.dot(U.T))

    # 5. Recover scale.  With zero variance the scale is undetermined
    #    (0 / 0) but multiplies a zero-centred cloud, so any finite value
    #    yields the same result; use 0 instead of propagating NaN.
    if var1 == 0:
        scale = 0.0
    else:
        scale = np.trace(R.dot(K)) / var1

    # 6. Recover translation.
    t = mu2 - scale * R.dot(mu1)

    # 7. Apply the transform to the (uncentred) source points.
    S1_hat = scale * R.dot(S1) + t

    if transposed:
        S1_hat = S1_hat.T
    return S1_hat
def compute_similarity_transform_torch(S1, S2):
    """Align point set S1 to S2 with the optimal similarity transform.

    PyTorch version of the orthogonal Procrustes solution: finds the scale
    ``s``, rotation ``R`` (constrained to ``det(R) = 1``) and translation
    ``t`` that bring ``s * R @ S1 + t`` closest to ``S2`` in the
    least-squares sense, and returns the transformed ``S1``.

    Args:
        S1: 2-D floating-point tensor of points to align, shaped (D, N)
            with D in {2, 3}, or (N, D).  Any input whose first dimension
            is neither 2 nor 3 is treated as (N, D) and transposed
            internally; an (N, D) input with N in {2, 3} is therefore read
            as (D, N).
        S2: Tensor of target points, same layout, shape, dtype and device
            as ``S1``.

    Returns:
        Tensor with the aligned copy of ``S1`` in the caller's layout.
        If every point of ``S1`` coincides (including the single-point
        case), every output point equals the centroid of ``S2``.

    Raises:
        AssertionError: if ``S1`` and ``S2`` hold a different number of
            points.
    """
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.t()
        S2 = S2.t()
        transposed = True
    assert S2.shape[1] == S1.shape[1], \
        "S1 and S2 must contain the same number of points"

    # 1. Remove mean.
    mu1 = S1.mean(dim=1, keepdim=True)
    mu2 = S2.mean(dim=1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1 ** 2)

    # 3. The outer product of X1 and X2.
    K = X1.mm(X2.t())

    # 4. Solution that maximizes trace(R'K) is R = U*V', where U, V are
    #    singular vectors of K.  torch.svd returns V itself (not V').
    U, _, V = torch.svd(K)
    # Construct Z that fixes the orientation of R to get det(R) = 1.
    # Z must share S1's dtype, otherwise the matmuls below fail for
    # anything other than the default float type.
    Z = torch.eye(U.shape[0], dtype=S1.dtype, device=S1.device)
    Z[-1, -1] *= torch.sign(torch.det(U @ V.t()))
    R = V.mm(Z.mm(U.t()))

    # 5. Recover scale.  With zero variance the numerator is zero as well,
    #    so dividing by 1 instead gives scale 0 rather than NaN.
    safe_var = torch.where(var1 == 0, torch.ones_like(var1), var1)
    scale = torch.trace(R.mm(K)) / safe_var

    # 6. Recover translation.
    t = mu2 - scale * R.mm(mu1)

    # 7. Apply the transform to the (uncentred) source points.
    S1_hat = scale * R.mm(S1) + t

    if transposed:
        S1_hat = S1_hat.t()
    return S1_hat
def batch_compute_similarity_transform_torch(S1, S2):
    """Align each point set in batch S1 to its counterpart in batch S2.

    Batched orthogonal Procrustes solution: for every sample finds the
    scale ``s``, rotation ``R`` (constrained to ``det(R) = 1``) and
    translation ``t`` that bring ``s * R @ S1[b] + t`` closest to ``S2[b]``
    in the least-squares sense, and returns the transformed ``S1``.

    Args:
        S1: 3-D floating-point tensor of points to align, shaped (B, D, N)
            with D in {2, 3}, or (B, N, D).  The layout is decided from
            dimension 1 (never from the batch size): if it is neither 2
            nor 3 the input is treated as (B, N, D) and permuted
            internally.  A (B, N, D) input with N in {2, 3} is therefore
            read as (B, D, N).
        S2: Tensor of target points, same layout, shape, dtype and device
            as ``S1``.

    Returns:
        Tensor with the aligned copy of ``S1`` in the caller's layout.
        For samples whose source points all coincide, every output point
        equals the centroid of the corresponding target.

    Raises:
        AssertionError: if ``S1`` and ``S2`` do not have the same shape.
    """
    transposed = False
    # Inspect the per-sample layout (dim 1); dim 0 is the batch size and
    # says nothing about whether points are stored row- or column-wise.
    if S1.shape[1] != 3 and S1.shape[1] != 2:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
        transposed = True
    assert S2.shape == S1.shape, "S1 and S2 must have the same shape"

    # 1. Remove mean.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute per-sample variance of X1 used for scale.
    var1 = torch.sum(X1 ** 2, dim=(1, 2))

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R = U*V', where U, V are
    #    singular vectors of K.  torch.svd returns V itself (not V').
    U, _, V = torch.svd(K)
    # Construct Z that fixes the orientation of R to get det(R) = 1.
    # Z must share S1's dtype, otherwise bmm fails for non-default floats.
    Z = torch.eye(U.shape[1], dtype=S1.dtype, device=S1.device)
    Z = Z.unsqueeze(0).repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale: batched trace(R K) / var1.  Where the variance is
    #    zero the trace is zero too, so dividing by 1 gives scale 0
    #    instead of NaN.
    trace_rk = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1)
    safe_var = torch.where(var1 == 0, torch.ones_like(var1), var1)
    scale = (trace_rk / safe_var).unsqueeze(-1).unsqueeze(-1)

    # 6. Recover translation.
    t = mu2 - scale * R.bmm(mu1)

    # 7. Apply the transform to the (uncentred) source points.
    S1_hat = scale * R.bmm(S1) + t

    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)
    return S1_hat
There is an issue with batch_compute_similarity_transform_torch.
The check for the shapes of S1 and S2 should be done for the 1st dimension, and not 0th, i.e. it should be "if S1.shape[1] != 3 and S1.shape[1] != 2:". But then again, there would be a problem if the number of 2D/3D points happens to be 2 or 3, in which case the tensors wouldn't get permuted. This check should probably be removed altogether and the required dimensions should just be explained in the docstring.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
why restrict to 3?