Skip to content

Instantly share code, notes, and snippets.

@roomrys
Created October 8, 2024 15:52
Show Gist options
  • Save roomrys/a053203a511764da871259dfd982fb34 to your computer and use it in GitHub Desktop.
Save roomrys/a053203a511764da871259dfd982fb34 to your computer and use it in GitHub Desktop.
SLEAP: Multiview Association via Pairs of Views and Fundamental Matrix
"""This module implements cycle consistent matching using pairs of views."""
from __future__ import annotations
from typing import Generator
import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.optimize import linear_sum_assignment
import sleap
from sleap.io.cameras import Camcorder, RecordingSession
def undistort_keypoints(keypoints, camera_matrix, distortion_coefficients):
    """Remove lens distortion from 2D keypoints, keeping pixel coordinates.

    The camera matrix is also passed to OpenCV as the new projection matrix `P`,
    so the result is expressed in the same pixel frame as the input.

    Args:
        keypoints: The keypoints to undistort. Any shape whose last axis has size 2;
            the points are flattened to (n_points, 2) internally.
        camera_matrix: The camera matrix. Shape (3, 3).
        distortion_coefficients: The distortion coefficients. Shape (5,).

    Returns:
        The undistorted keypoints, reshaped back to the shape of the input.
    """
    # Plain Python lists are promoted to arrays; anything else is used as given.
    keypoints, camera_matrix, distortion_coefficients = (
        np.array(value) if isinstance(value, list) else value
        for value in (keypoints, camera_matrix, distortion_coefficients)
    )

    input_shape = keypoints.shape
    flat_points = keypoints.reshape(-1, 2)
    corrected_points = cv2.undistortPoints(
        flat_points, camera_matrix, distortion_coefficients, P=camera_matrix
    )
    # Restore the caller's layout.
    return corrected_points.reshape(input_shape)
def get_camera_pairs(camera_cluster, names) -> list[Camcorder]:
    """Look up the cameras with the given names in the camera cluster.

    The cameras are returned as a list in the same order as `names`, so passing
    the two names of a camera pair yields the two cameras of that pair.

    Args:
        camera_cluster: The camera cluster containing the cameras.
        names: The names of the cameras to retrieve.

    Returns:
        A list with one camera per entry in `names`. If several cameras share a
        name, the first match in the cluster is used.

    Raises:
        ValueError: If any name does not match a camera in the cluster.
    """
    cameras = []
    for name in names:
        # Take the first match; the default avoids raising the ValueError from
        # inside a StopIteration handler (which would chain an irrelevant context).
        camera = next(find_camera_by_name(camera_cluster, name), None)
        if camera is None:
            raise ValueError(
                f"Camera {name} not found in the camera cluster: {camera_cluster.cameras}"
            )
        cameras.append(camera)
    return cameras
def find_camera_by_name(
    camera_cluster, name
) -> Generator[Camcorder, None, None]:
    """Generator to find cameras by name in the camera cluster.

    Args:
        camera_cluster: The camera cluster containing the cameras. Only iteration
            over its cameras is required.
        name: The name of the camera to find.

    Yields:
        Every camera whose `name` attribute equals `name`, in cluster order.
        Nothing is yielded when there is no match.
    """
    for camera in camera_cluster:
        if camera.name == name:
            yield camera
def inverse_of_upper_triangular(matrix):
    """Calculate the inverse of an upper triangular matrix.

    The inverse V of an upper triangular matrix U is itself upper triangular and
    satisfies V @ U = I. Reading that identity entry by entry gives, for i < j:

        V[i, j] = -(sum_{k=i}^{j-1} V[i, k] * U[k, j]) / U[j, j]

    which only needs entries of row i that lie to the left of column j. Each row
    is therefore filled left to right, starting from its diagonal entry.

    Args:
        matrix: The upper triangular matrix. Shape (n, n). Entries below the
            diagonal are never read.

    Returns:
        The inverse matrix as a float array of shape (n, n).

    Note:
        Zeros on the diagonal (a singular matrix) are not checked for; they lead
        to inf/nan entries following NumPy's division rules.
    """
    matrix = np.asarray(matrix, dtype=float)
    n = matrix.shape[0]
    inverse = np.zeros((n, n))

    for i in range(n):
        # Diagonal entry first: everything to its right in this row depends on it.
        inverse[i, i] = 1 / matrix[i, i]
        for j in range(i + 1, n):
            # Uses the already-computed entries inverse[i, i:j] of the current row.
            partial_sum = inverse[i, i:j] @ matrix[i:j, j]
            inverse[i, j] = -partial_sum / matrix[j, j]
    return inverse
def calculate_essential_matrix(rotation, translation):
    """Build the essential matrix E = [t]_x @ R.

    [t]_x is the skew-symmetric cross-product matrix of the translation, i.e.
    [t]_x @ v equals the cross product t x v.

    Args:
        rotation: The rotation matrix between two cameras. Shape (3, 3).
        translation: The translation vector between two cameras. Only its first
            three elements are read.

    Returns:
        The essential matrix. Shape (3, 3).
    """
    t_x, t_y, t_z = translation[0], translation[1], translation[2]
    # Skew-symmetric matrix of the translation, written out entry by entry.
    skew_translation = np.array(
        [
            [0, -t_z, t_y],
            [t_z, 0, -t_x],
            [-t_y, t_x, 0],
        ]
    )
    return skew_translation @ rotation
def enforce_rotation_matrix(rot):
    """Return the rotation as a (3, 3) rotation matrix.

    A (3, 3) input is returned as-is (it is not checked for orthonormality). Any
    3-element input, e.g. shape (3,), (3, 1) or (1, 3), is assumed to be an
    axis-angle (Rodrigues) vector and is converted with `cv2.Rodrigues`.

    Args:
        rot: The rotation, either a (3, 3) matrix or a 3-element axis-angle
            vector. Lists and other array-likes are converted to arrays.

    Returns:
        The rotation matrix. Shape (3, 3).

    Raises:
        ValueError: If `rot` is neither (3, 3) nor a 3-element vector.
    """
    rot = np.asarray(rot)
    if rot.shape == (3, 3):
        return rot
    if rot.size == 3:
        # cv2.Rodrigues needs a floating point vector; normalize the layout so
        # (3,), (3, 1) and (1, 3) inputs are all handled identically.
        rotation_vector = rot.astype(np.float64).reshape(3, 1)
        return cv2.Rodrigues(rotation_vector)[0]
    raise ValueError(
        "Rotation must be a (3, 3) matrix or a 3-element axis-angle vector. "
        f"Received shape: {rot.shape}"
    )
def unrotate(points, rotation):
    """Apply the inverse of a rotation to row-vector points.

    Right-multiplying row vectors by R is the same as left-multiplying the
    corresponding column vectors by R.T, which for a rotation matrix is its
    inverse. Hence if points = rotation @ output (as columns), this returns output.

    Args:
        points: The points to unrotate. Shape (n_points, 3).
        rotation: The rotation, as a (3, 3) matrix or in any other form accepted
            by `enforce_rotation_matrix`.

    Returns:
        The unrotated points. Shape (n_points, 3). output = points @ rotation.
    """
    rotation_matrix = enforce_rotation_matrix(rotation)
    unrotated_points = points @ rotation_matrix
    return unrotated_points
def calculate_fundamental_matrix(cam_matrices_1, cam_matrices_2):
    """Calculate the fundamental matrix between two cameras.

    The extrinsics are taken to map world coordinates into each camera frame,
    X_i = R_i @ X_world + t_i. Eliminating X_world gives the pose of camera 1
    relative to camera 2:

        X_2 = R @ X_1 + t,  with  R = R_2 @ R_1.T  and  t = t_2 - R @ t_1

    from which E = [t]_x @ R and F = K_2^-T @ E @ K_1^-1.

    Args:
        cam_matrices_1: The camera matrices for camera 1. A dict with the keys
            "rotation" (matrix or axis-angle vector), "translation" (3 values)
            and "matrix" (upper triangular intrinsics, shape (3, 3)).
        cam_matrices_2: The camera matrices for camera 2, same layout.

    Returns:
        The fundamental matrix F. Shape (3, 3). x_2.T @ F @ x_1 = 0.
    """
    # Rotations from world to each camera.
    rot_1 = enforce_rotation_matrix(cam_matrices_1["rotation"])
    rot_2 = enforce_rotation_matrix(cam_matrices_2["rotation"])

    # Translations from world to each camera, flattened to plain 3-vectors.
    translation_1 = np.asarray(cam_matrices_1["translation"], dtype=float).reshape(3)
    translation_2 = np.asarray(cam_matrices_2["translation"], dtype=float).reshape(3)

    # The epipolar constraint is defined by the pose of one camera relative to
    # the other, so both cameras' extrinsics must enter the essential matrix.
    relative_rotation = rot_2 @ rot_1.T
    relative_translation = translation_2 - relative_rotation @ translation_1
    essential_matrix = calculate_essential_matrix(
        relative_rotation, relative_translation
    )

    # Move from normalized camera coordinates to pixel coordinates.
    inverse_intrinsics_1 = inverse_of_upper_triangular(cam_matrices_1["matrix"])
    inverse_intrinsics_2 = inverse_of_upper_triangular(cam_matrices_2["matrix"])
    fundamental_matrix = (
        inverse_intrinsics_2.T @ essential_matrix @ inverse_intrinsics_1
    )
    return fundamental_matrix
def get_homogenous_keypoints(keypoints):
    """Convert 2D keypoints to homogenous coordinates.

    A trailing coordinate of 1 is appended along the last axis, so the function
    works for a single point (2,), a set of points (n_points, 2), or batched
    points such as (n_poses, n_points, 2).

    Args:
        keypoints: The keypoints. Shape (..., 2). Array-likes are converted to
            arrays.

    Returns:
        The keypoints in homogenous coordinates. Shape (..., 3). If the last
        axis already has size 3, the input array is returned unchanged (and a
        message is printed).
    """
    keypoints = np.asarray(keypoints)
    if keypoints.shape[-1] == 3:
        print("Keypoints are already in homogenous coordinates.")
        return keypoints

    # One "1" per point, shaped to line up with every leading axis.
    ones = np.ones(keypoints.shape[:-1] + (1,))
    return np.concatenate((keypoints, ones), axis=-1)
def undistort_and_homogenize_keypoints(
    keypoints, camera_matrix, distortion_coefficients
):
    """Undistort keypoints (when calibration is given) and make them homogenous.

    Args:
        keypoints: The keypoints to process. Last axis has size 2.
        camera_matrix: The camera matrix, shape (3, 3), or None.
        distortion_coefficients: The distortion coefficients, shape (5,), or None.

    Returns:
        The keypoints in homogenous coordinates (last axis of size 3). The
        undistortion step is skipped if either calibration input is None.
    """
    has_calibration = (
        camera_matrix is not None and distortion_coefficients is not None
    )
    if has_calibration:
        keypoints = undistort_keypoints(
            keypoints, camera_matrix, distortion_coefficients
        )
    return get_homogenous_keypoints(keypoints)
def calculate_error(
    fundamental_matrix,
    keypoints_1,
    keypoints_2,
    camera_matrices_1: dict = None,
    camera_matrices_2: dict = None,
):
    """Score every pose in view 2 against every pose in view 1.

    For each keypoint the epipolar residual x_2.T @ F @ x_1 is evaluated for all
    pose combinations. Residuals are NaN-averaged over the keypoints, the
    absolute value is taken, and the maximum of the result is subtracted so that
    every returned value is <= 0 (the worst pairing scores 0).

    Args:
        fundamental_matrix: The fundamental matrix F. Shape (3, 3).
        keypoints_1: The keypoints in the first view. Shape (n_poses_1, n_points, 2).
        keypoints_2: The keypoints in the second view. Shape (n_poses_2, n_points, 2).
        camera_matrices_1: The camera matrices for camera 1. The "matrix" and
            "distortions" entries are used for undistortion when both are present.
        camera_matrices_2: The camera matrices for camera 2, same layout.

    Returns:
        The shifted error for each pose pairing. Shape (n_poses_2, n_poses_1).
    """
    calibration_1 = camera_matrices_1 if camera_matrices_1 is not None else {}
    calibration_2 = camera_matrices_2 if camera_matrices_2 is not None else {}

    homogenous_1 = undistort_and_homogenize_keypoints(
        keypoints=keypoints_1,
        camera_matrix=calibration_1.get("matrix", None),
        distortion_coefficients=calibration_1.get("distortions", None),
    )
    homogenous_2 = undistort_and_homogenize_keypoints(
        keypoints=keypoints_2,
        camera_matrix=calibration_2.get("matrix", None),
        distortion_coefficients=calibration_2.get("distortions", None),
    )

    # Put the keypoint axis first so the matmul is batched per keypoint:
    # (n_points, n_poses_2, 3) @ (3, 3) @ (n_points, 3, n_poses_1).
    view_2_rows = homogenous_2.transpose(1, 0, 2)
    view_1_columns = np.transpose(homogenous_1, (1, 2, 0))
    residuals = view_2_rows @ fundamental_matrix @ view_1_columns

    # NOTE(review): the signed residuals are averaged before taking the absolute
    # value, so residuals of opposite sign can cancel - confirm this is intended.
    mean_residual = np.abs(np.nanmean(residuals, axis=0))
    worst_residual = np.nanmax(mean_residual)
    return -(worst_residual - mean_residual)
def plot_adjacency_matrices(adjacency_matrices: dict):
    """Plot each adjacency matrix as an annotated heatmap in a two-column grid.

    Args:
        adjacency_matrices: Maps a title of the form "<x label> and <y label>" to
            a tuple (matrix, row_ind, col_ind). The matrix is drawn as a heatmap
            and each cell (row_ind[k], col_ind[k]) is outlined in green.

    Returns:
        None. The figure is displayed with `plt.show()`. Nothing is drawn when
        `adjacency_matrices` is empty.
    """
    num_plots = len(adjacency_matrices)
    if num_plots == 0:
        # A grid with zero rows cannot be created; there is nothing to show.
        return

    # Lay the plots out on a grid with just enough rows (ceiling division).
    num_cols = 2
    num_rows = -(-num_plots // num_cols)
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(7, 7))
    # Flatten the axes array for easy iteration
    axes = axes.flatten()

    for ax, (title, (matrix, row_ind, col_ind)) in zip(
        axes, adjacency_matrices.items()
    ):
        sns.heatmap(
            matrix,
            annot=True,
            fmt=".2f",
            linewidths=0.5,
            ax=ax,
            annot_kws={"fontsize": 8},  # Adjust the fontsize here
        )
        ax.set_title(f"{title}")
        # Columns of the matrix belong to the first name, rows to the second.
        x_label, y_label = title.split(" and ")
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        # Outline the selected cells; heatmap cell (r, c) has its corner at (c, r).
        for r, c in zip(row_ind, col_ind):
            rect = patches.Rectangle(
                (c, r), 1, 1, linewidth=3, edgecolor=(0, 1, 0), facecolor="none"
            )
            ax.add_patch(rect)

    # Hide any unused subplots
    for unused_ax in axes[num_plots:]:
        fig.delaxes(unused_ax)

    # Set a title for the entire figure
    fig.suptitle(
        r"Adjacency Matrices $e_{neg}$ for Multiview Association"
        "\n"
        r"$e = \mathbf{\hat{p}^T} \mathbf{F} \mathbf{p}$"
        "\t"
        r"$e_{neg} = |e|-\max|e|$",
        fontsize=16,
    )
    plt.tight_layout()
    plt.show()
def evaluate_camera_pair(cam_1, cam_2, keypoints_1, keypoints_2):
    """Evaluate the pair of cameras.

    Args:
        cam_1: The first camera. Its `get_dict()` is expected to provide the
            camera matrices used by `calculate_fundamental_matrix` and
            `calculate_error`.
        cam_2: The second camera.
        keypoints_1: The keypoints in the first view. Shape (n_poses_1, n_points, 2).
        keypoints_2: The keypoints in the second view. Shape (n_poses_2, n_points, 2).

    Returns:
        A tuple (error, row_inds, col_inds). `error` is the adjacency matrix of
        shape (n_poses_2, n_poses_1). `row_inds` and `col_inds` are lists holding
        the matched (row, column) positions in `error` that minimize the total
        error; rows and columns that are entirely NaN are never matched.
    """
    # Retrieve the camera matrices for the pair of cameras
    cam_matrices_1 = cam_1.get_dict()
    cam_matrices_2 = cam_2.get_dict()

    # Calculate the fundamental matrix for the pair of cameras
    fundamental_matrix = calculate_fundamental_matrix(
        cam_matrices_1=cam_matrices_1, cam_matrices_2=cam_matrices_2
    )
    print(f"Fundamental matrix:\n{fundamental_matrix}", end="\n\n")

    # For each multi-animal grouping, calculate x_1 * F * x_2 where x is T x N x 3
    error = calculate_error(
        fundamental_matrix=fundamental_matrix,
        keypoints_1=keypoints_1,
        keypoints_2=keypoints_2,
        camera_matrices_1=cam_matrices_1,
        camera_matrices_2=cam_matrices_2,
    )
    print(f"Error:\n{error}", end="\n\n")

    # Find rows and columns that contain only NaN values
    nan_mask = np.isnan(error)
    nan_rows = np.all(nan_mask, axis=1)
    nan_cols = np.all(nan_mask, axis=0)

    # Find which multi-animal grouping minimizes x_1 * F * x_2, ignoring the
    # all-NaN rows and columns.
    row_ind, col_ind = linear_sum_assignment(error[~nan_rows][:, ~nan_cols])
    print(f"Row indices: {row_ind}\nColumn indices: {col_ind}", end="\n\n")

    # Map indices of the compacted matrix back to the original matrix. The k-th
    # kept row is the k-th non-NaN row; counting NaN rows within the first k + 1
    # original rows is not enough when several NaN rows are adjacent.
    kept_rows = np.flatnonzero(~nan_rows)
    kept_cols = np.flatnonzero(~nan_cols)
    row_inds = kept_rows[row_ind].tolist()
    col_inds = kept_cols[col_ind].tolist()
    return error, row_inds, col_inds
def main(
    session: RecordingSession,
    frame_idx: int | None = None,
    camera_pairs: list[tuple[str, str]] | None = None,
):
    """Match poses across pairs of camera views for one frame and plot the result.

    Args:
        session: The recording session providing the camera cluster, the cameras
            to include and the frame groups.
        frame_idx: The frame to evaluate. Defaults to the first key of
            `session.frame_groups`.
        camera_pairs: Pairs of camera names to evaluate. Defaults to the chain
            back-left, left-right, right-front, front-center.

    Returns:
        None. The adjacency matrices of all pairs are shown in a figure.
    """
    # TODO(LM): Use camera extrinsics to find pairs of cameras with similar views
    if camera_pairs is None:
        # A second pairing used to be assigned here and was immediately
        # overwritten: (side, sideL), (top, topL), (mid, midL) as similar views
        # plus (side, mid), (top, mid) as connecting views. Pass it via
        # `camera_pairs` when the session uses those camera names.
        camera_pairs = [
            ("back", "left"),  # Similar views
            ("left", "right"),
            ("right", "front"),
            ("front", "center"),
        ]

    # TODO(LM): Add method to return all ungrouped coordinates
    if frame_idx is None:
        frame_idx = next(iter(session.frame_groups.keys()))
    frame_group_numpy = session.frame_groups[frame_idx].numpy()

    errors = {}
    for camera_names in camera_pairs:
        # Retrieve the Camcorder objects for the pair of cameras
        cam_1, cam_2 = get_camera_pairs(session.camera_cluster, camera_names)

        # Get the keypoints for the pair of cameras
        cam_idx_1 = session.cams_to_include.index(cam_1)
        cam_idx_2 = session.cams_to_include.index(cam_2)
        keypoints_1 = frame_group_numpy[cam_idx_1]
        keypoints_2 = frame_group_numpy[cam_idx_2]

        # Evaluate the pair of cameras
        error, row_ind, col_ind = evaluate_camera_pair(
            cam_1, cam_2, keypoints_1, keypoints_2
        )
        errors[f"{cam_1.name} and {cam_2.name}"] = (error, row_ind, col_ind)

    plot_adjacency_matrices(errors)
if __name__ == "__main__":
    import os

    # The project file path is read from an environment variable
    # (the "dsmview" variable was used here previously).
    project_path = os.environ["dsmviewgerbils"]

    # Load the project and run the association on its first session.
    loaded_labels = sleap.load_file(project_path)
    main(session=loaded_labels.sessions[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment