Created
October 8, 2024 15:52
-
-
Save roomrys/a053203a511764da871259dfd982fb34 to your computer and use it in GitHub Desktop.
SLEAP: Multiview Association via Pairs of Views and Fundamental Matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""This module implements cycle consistent matching using pairs of views.""" | |
from __future__ import annotations | |
from typing import Generator | |
import cv2 | |
import matplotlib.patches as patches | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import seaborn as sns | |
from scipy.optimize import linear_sum_assignment | |
import sleap | |
from sleap.io.cameras import Camcorder, RecordingSession | |
def undistort_keypoints(keypoints, camera_matrix, distortion_coefficients):
    """Undistort keypoints using the camera matrix and distortion coefficients.

    The points are re-projected with ``P=camera_matrix``, so the output is passed
    through the same camera matrix again instead of being left in normalized
    coordinates.

    Args:
        keypoints: The keypoints to undistort. Any array-like whose trailing axis
            has length 2, e.g. shape (n_points, 2) or (n_poses, n_points, 2).
        camera_matrix: The camera matrix. Shape (3, 3).
        distortion_coefficients: The distortion coefficients. Shape (5,).

    Returns:
        The undistorted keypoints as an array with the same shape as `keypoints`.
    """
    # np.asarray also covers tuples and nested sequences, not only lists.
    keypoints = np.asarray(keypoints)
    # OpenCV only accepts 32- or 64-bit floating point coordinates.
    if keypoints.dtype not in (np.float32, np.float64):
        keypoints = keypoints.astype(np.float64)
    camera_matrix = np.asarray(camera_matrix, dtype=np.float64)
    distortion_coefficients = np.asarray(distortion_coefficients, dtype=np.float64)

    if keypoints.size == 0:
        # Nothing to undistort; avoid handing an empty point set to OpenCV.
        return keypoints.copy()

    original_shape = keypoints.shape
    # Flatten any leading axes so OpenCV sees a plain (n, 2) list of points.
    reshaped_keypoints = np.ascontiguousarray(keypoints.reshape(-1, 2))
    undistorted_keypoints = cv2.undistortPoints(
        reshaped_keypoints, camera_matrix, distortion_coefficients, P=camera_matrix
    )
    return undistorted_keypoints.reshape(original_shape)
def get_camera_pairs(camera_cluster, names) -> list[Camcorder]:
    """Look up the cameras with the given names in the camera cluster.

    Args:
        camera_cluster: The camera cluster containing the cameras.
        names: The names of the cameras to retrieve (e.g. a pair of names).

    Returns:
        The first matching camera for each name, in the same order as `names`.

    Raises:
        ValueError: If a name does not match any camera in the cluster.
    """
    cameras = []
    for name in names:
        # A default of None avoids raising from inside an `except StopIteration`
        # block, which would attach an irrelevant chained traceback.
        camera = next(find_camera_by_name(camera_cluster, name), None)
        if camera is None:
            raise ValueError(
                f"Camera {name} not found in the camera cluster: {camera_cluster.cameras}"
            )
        cameras.append(camera)
    return cameras
def find_camera_by_name(camera_cluster, name) -> Generator[Camcorder, None, None]:
    """Generator to find cameras by name in the camera cluster.

    Args:
        camera_cluster: An iterable of cameras; each camera exposes a `name`
            attribute.
        name: The name of the camera to find.

    Yields:
        Every camera whose `name` equals `name`, in iteration order. Nothing is
        yielded when no camera matches.
    """
    for camera in camera_cluster:
        if camera.name == name:
            yield camera
def inverse_of_upper_triangular(matrix):
    """Calculate the inverse of an upper triangular matrix by back substitution.

    Only the diagonal and the entries above it are read; anything below the
    diagonal is treated as zero.

    Args:
        matrix: The upper triangular matrix. Shape (n, n).

    Returns:
        The inverse of the matrix, which is also upper triangular. Shape (n, n).

    Raises:
        ValueError: If the matrix is not square or has a zero on its diagonal
            (i.e. it is singular).
    """
    matrix = np.asarray(matrix, dtype=float)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"Expected a square matrix. Received shape: {matrix.shape}")
    if np.any(np.diag(matrix) == 0):
        raise ValueError("Matrix is singular: zero found on the diagonal.")

    n = matrix.shape[0]
    inverse = np.zeros((n, n))
    # Solve matrix @ inverse = I one column at a time. Within a column the rows
    # are filled bottom-up so that every inverse[k, j] with k > i is already
    # known when inverse[i, j] is computed.
    for j in range(n):
        inverse[j, j] = 1.0 / matrix[j, j]
        for i in range(j - 1, -1, -1):
            partial = matrix[i, i + 1 : j + 1] @ inverse[i + 1 : j + 1, j]
            inverse[i, j] = -partial / matrix[i, i]
    return inverse
def calculate_essential_matrix(rotation, translation):
    """Calculate the essential matrix from rotation and translation.

    The result is ``[t]_x @ R`` where ``[t]_x`` is the skew-symmetric
    cross-product matrix of the translation.

    Args:
        rotation: The rotation matrix between two cameras. Shape (3, 3).
        translation: The translation vector between two cameras. Any array-like
            with exactly three elements, e.g. shape (3,), (3, 1) or (1, 3).

    Returns:
        The essential matrix. Shape (3, 3).
    """
    rotation = np.asarray(rotation, dtype=float)
    # Flatten so that column/row vectors index to scalars instead of 1-element
    # arrays (which would make the matrix literal below ragged).
    t_x, t_y, t_z = np.asarray(translation, dtype=float).reshape(3)
    t_cross = np.array(
        [
            [0.0, -t_z, t_y],
            [t_z, 0.0, -t_x],
            [-t_y, t_x, 0.0],
        ]
    )
    return t_cross @ rotation
def enforce_rotation_matrix(rot):
    """Return the rotation as a (3, 3) matrix.

    Only the shape is checked; a (3, 3) input is returned without verifying that
    it is orthonormal. A three-element input is assumed to be an axis-angle
    (Rodrigues) vector and is converted with `cv2.Rodrigues`.

    Args:
        rot: The rotation. Shape (3, 3), or an axis-angle vector of shape (3,),
            (3, 1) or (1, 3).

    Returns:
        The rotation matrix. Shape (3, 3).

    Raises:
        ValueError: If `rot` has any other shape.
    """
    rot = np.asarray(rot, dtype=float)
    if rot.shape == (3, 3):
        return rot
    if rot.shape in ((3,), (3, 1), (1, 3)):
        # Assuming axis-angle representation, convert to rotation matrix
        return cv2.Rodrigues(rot.reshape(3, 1))[0]
    raise ValueError(
        "Rotation must be of shape (3, 3), (3,), (3, 1) or (1, 3). "
        f"Received: {rot.shape}"
    )
def unrotate(points, rotation):
    """Apply the inverse of a rotation to row-vector points.

    Right-multiplying row vectors by ``rotation`` is equivalent to
    left-multiplying the corresponding column vectors by ``rotation.T``, which
    for a rotation matrix is its inverse.

    Args:
        points: The points to unrotate. Shape (n_points, 3).
        rotation: The rotation, in any form accepted by
            `enforce_rotation_matrix`.

    Returns:
        The unrotated points, computed as ``points @ rotation``.
        Shape (n_points, 3).
    """
    rotation_matrix = enforce_rotation_matrix(rotation)
    unrotated_points = points @ rotation_matrix
    return unrotated_points
def calculate_fundamental_matrix(cam_matrices_1, cam_matrices_2):
    """Calculate the fundamental matrix between two cameras.

    The matrix is built from the pose of camera 2 relative to camera 1, so the
    extrinsics of BOTH cameras contribute:

        R_rel = R_2 @ R_1.T
        t_rel = t_2 - R_rel @ t_1
        E     = [t_rel]_x @ R_rel
        F     = K_2^-T @ E @ K_1^-1

    NOTE(review): this assumes the "rotation"/"translation" entries are
    world-to-camera extrinsics (x_cam = R @ x_world + t), as the comments in the
    earlier version of this function stated -- confirm against the camera model.

    Args:
        cam_matrices_1: The camera matrices for camera 1. A mapping with the keys
            "rotation", "translation" and "matrix" (intrinsics).
        cam_matrices_2: The camera matrices for camera 2. Same keys as camera 1.

    Returns:
        The fundamental matrix F. Shape (3, 3). x_2.T @ F @ x_1 = 0.
    """
    # Retrieve the rotation and translation from world to each camera
    rot_1 = enforce_rotation_matrix(cam_matrices_1["rotation"])
    rot_2 = enforce_rotation_matrix(cam_matrices_2["rotation"])
    translation_1 = np.asarray(cam_matrices_1["translation"], dtype=float).reshape(3)
    translation_2 = np.asarray(cam_matrices_2["translation"], dtype=float).reshape(3)

    # Pose of camera 2 expressed in the frame of camera 1:
    # x_cam2 = relative_rotation @ x_cam1 + relative_translation
    relative_rotation = rot_2 @ rot_1.T
    relative_translation = translation_2 - relative_rotation @ translation_1

    # Calculate the essential matrix of the camera pair
    essential_matrix = calculate_essential_matrix(
        relative_rotation, relative_translation
    )

    # Calculate the fundamental matrix by moving from normalized to pixel coordinates
    inverse_intrinsics_1 = inverse_of_upper_triangular(cam_matrices_1["matrix"])
    inverse_intrinsics_2 = inverse_of_upper_triangular(cam_matrices_2["matrix"])
    fundamental_matrix = (
        inverse_intrinsics_2.T @ essential_matrix @ inverse_intrinsics_1
    )
    return fundamental_matrix
def get_homogenous_keypoints(keypoints):
    """Convert 2D keypoints to homogenous coordinates.

    A column of ones is appended along the last axis, whatever the number of
    leading axes is.

    Args:
        keypoints: The keypoints. Any array-like whose last axis holds the
            coordinates, e.g. shape (2,), (n_points, 2) or
            (n_poses, n_points, 2).

    Returns:
        The keypoints in homogenous coordinates, with the last axis extended by
        one. If the last axis already has length 3 the input array is returned
        unchanged (and a notice is printed).
    """
    keypoints = np.asarray(keypoints)
    if keypoints.shape[-1] == 3:
        print("Keypoints are already in homogenous coordinates.")
        return keypoints
    # Concatenating on the last axis works for any rank, unlike hstack/dstack.
    ones = np.ones(keypoints.shape[:-1] + (1,))
    return np.concatenate((keypoints, ones), axis=-1)
def undistort_and_homogenize_keypoints(
    keypoints, camera_matrix, distortion_coefficients
):
    """Undistort keypoints (when possible) and convert them to homogenous form.

    Undistortion is skipped when either the camera matrix or the distortion
    coefficients are missing; the keypoints are then homogenized as they are.

    Args:
        keypoints: The keypoints to undistort. Shape (n_points, 2).
        camera_matrix: The camera matrix, or None. Shape (3, 3).
        distortion_coefficients: The distortion coefficients, or None.
            Shape (5,).

    Returns:
        The undistorted and homogenized keypoints. Shape (n_points, 3).
    """
    have_calibration = (
        camera_matrix is not None and distortion_coefficients is not None
    )
    if have_calibration:
        points = undistort_keypoints(
            keypoints, camera_matrix, distortion_coefficients
        )
    else:
        points = keypoints
    return get_homogenous_keypoints(points)
def calculate_error(
    fundamental_matrix,
    keypoints_1,
    keypoints_2,
    camera_matrices_1: dict | None = None,
    camera_matrices_2: dict | None = None,
):
    """Calculate the epipolar error between every pair of poses from two views.

    For each pose ``a`` in view 1 and pose ``b`` in view 2 the algebraic
    residual ``x_2.T @ F @ x_1`` is evaluated per keypoint. The absolute
    residuals are averaged over the keypoints (ignoring NaNs), so residuals of
    opposite sign cannot cancel each other out. Finally the largest mean error
    is subtracted, which makes every entry <= 0.

    Args:
        fundamental_matrix: The fundamental matrix F. Shape (3, 3).
        keypoints_1: The keypoints in the first view.
            Shape (n_poses_1, n_points, 2); a single pose of shape (n_points, 2)
            is also accepted.
        keypoints_2: The keypoints in the second view.
            Shape (n_poses_2, n_points, 2); a single pose of shape (n_points, 2)
            is also accepted.
        camera_matrices_1: The camera matrices for camera 1. When the keys
            "matrix" and "distortions" are both present the keypoints are
            undistorted first.
        camera_matrices_2: The camera matrices for camera 2. Same keys as
            camera 1.

    Returns:
        The shifted mean absolute error for each pose pair.
        Shape (n_poses_2, n_poses_1). Entries are NaN where no keypoint had a
        finite residual.
    """
    camera_matrices_1 = {} if camera_matrices_1 is None else camera_matrices_1
    camera_matrices_2 = {} if camera_matrices_2 is None else camera_matrices_2

    keypoints_1 = undistort_and_homogenize_keypoints(
        keypoints=keypoints_1,
        camera_matrix=camera_matrices_1.get("matrix", None),
        distortion_coefficients=camera_matrices_1.get("distortions", None),
    )
    keypoints_2 = undistort_and_homogenize_keypoints(
        keypoints=keypoints_2,
        camera_matrix=camera_matrices_2.get("matrix", None),
        distortion_coefficients=camera_matrices_2.get("distortions", None),
    )

    # Treat a single pose (n_points, 3) as a batch containing one pose.
    if keypoints_1.ndim == 2:
        keypoints_1 = keypoints_1[np.newaxis]
    if keypoints_2.ndim == 2:
        keypoints_2 = keypoints_2[np.newaxis]

    # residuals[p, b, a] = x_2[b, p] @ F @ x_1[a, p]
    residuals = (
        keypoints_2.transpose(1, 0, 2)
        @ fundamental_matrix
        @ np.transpose(keypoints_1, (1, 2, 0))
    )

    # Take the magnitude BEFORE averaging over keypoints (axis 0).
    mean_abs_error = np.nanmean(np.abs(residuals), axis=0)
    max_error = np.nanmax(mean_abs_error)
    return mean_abs_error - max_error
def plot_adjacency_matrices(adjacency_matrices: dict) -> None:
    """Plot one annotated heatmap per camera pair and outline the matched cells.

    Args:
        adjacency_matrices: Mapping from a title of the form
            "<camera 1> and <camera 2>" to a tuple ``(matrix, row_ind, col_ind)``.
            ``matrix`` is drawn as a heatmap and every ``(row, col)`` pair from
            ``zip(row_ind, col_ind)`` is outlined in green. The text before
            " and " labels the x axis, the text after it labels the y axis.
            Nothing is plotted when the mapping is empty.
    """
    num_plots = len(adjacency_matrices)
    if num_plots == 0:
        # plt.subplots cannot create a grid with zero rows.
        return

    # Calculate the number of rows needed for a two-column grid
    num_cols = 2
    num_rows = (num_plots + num_cols - 1) // num_cols

    # squeeze=False always yields a 2D array of axes that can be flattened
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(7, 7), squeeze=False)
    axes = axes.flatten()

    for ax, (title, (matrix, row_ind, col_ind)) in zip(
        axes, adjacency_matrices.items()
    ):
        sns.heatmap(
            matrix,
            annot=True,
            fmt=".2f",
            linewidths=0.5,
            ax=ax,
            annot_kws={"fontsize": 8},
        )
        ax.set_title(f"{title}")
        # partition never raises, even if the title has no (or several) " and "
        x_label, _, y_label = title.partition(" and ")
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        # Outline the matched cells; heatmap cell (r, c) spans x in [c, c + 1]
        # and y in [r, r + 1].
        for r, c in zip(row_ind, col_ind):
            rect = patches.Rectangle(
                (c, r), 1, 1, linewidth=3, edgecolor=(0, 1, 0), facecolor="none"
            )
            ax.add_patch(rect)

    # Hide any unused subplots
    for unused_ax in axes[num_plots:]:
        fig.delaxes(unused_ax)

    # Set a title for the entire figure
    fig.suptitle(
        r"Adjacency Matrices $e_{neg}$ for Multiview Association"
        "\n"
        r"$e = \mathbf{\hat{p}^T} \mathbf{F} \mathbf{p}$"
        "\t"
        r"$e_{neg} = |e|-\max|e|$",
        fontsize=16,
    )
    plt.tight_layout()
    plt.show()
def evaluate_camera_pair(cam_1, cam_2, keypoints_1, keypoints_2):
    """Evaluate the pair of cameras.

    Args:
        cam_1: The first camera. Must provide `get_dict()`.
        cam_2: The second camera. Must provide `get_dict()`.
        keypoints_1: The keypoints in the first view. Shape (n_poses_1, n_points, 2).
        keypoints_2: The keypoints in the second view. Shape (n_poses_2, n_points, 2).

    Returns:
        A tuple ``(error, row_inds, col_inds)``. ``error`` is the adjacency
        matrix returned by `calculate_error`, of shape (n_poses_2, n_poses_1).
        ``row_inds`` and ``col_inds`` are lists holding the optimal assignment,
        expressed as indices into the FULL ``error`` matrix. Rows and columns
        that are entirely NaN are never assigned.
    """
    # Retrieve the camera matrices for the pair of cameras
    cam_matrices_1 = cam_1.get_dict()
    cam_matrices_2 = cam_2.get_dict()

    # Calculate the fundamental matrix for the pair of cameras
    fundamental_matrix = calculate_fundamental_matrix(
        cam_matrices_1=cam_matrices_1, cam_matrices_2=cam_matrices_2
    )
    print(f"Fundamental matrix:\n{fundamental_matrix}", end="\n\n")

    # For each multi-animal grouping, calculate x_1 * F * x_2 where x is T x N x 3
    error = calculate_error(
        fundamental_matrix=fundamental_matrix,
        keypoints_1=keypoints_1,
        keypoints_2=keypoints_2,
        camera_matrices_1=cam_matrices_1,
        camera_matrices_2=cam_matrices_2,
    )
    print(f"Error:\n{error}", end="\n\n")

    # Keep only the rows and columns that hold at least one non-NaN value
    is_nan = np.isnan(error)
    valid_rows = np.flatnonzero(~np.all(is_nan, axis=1))
    valid_cols = np.flatnonzero(~np.all(is_nan, axis=0))

    # Find which multi-animal grouping minimizes x_1 * F * x_2
    row_ind, col_ind = linear_sum_assignment(error[valid_rows][:, valid_cols])
    print(f"Row indices: {row_ind}\nColumn indices: {col_ind}", end="\n\n")

    # Map indices of the filtered matrix back to the full matrix. Looking them up
    # in the list of kept rows/columns stays correct even when several NaN
    # rows/columns precede (or are interleaved with) the valid ones.
    row_inds = valid_rows[row_ind].tolist()
    col_inds = valid_cols[col_ind].tolist()
    return error, row_inds, col_inds
def main(
    session: RecordingSession,
    frame_idx: int | None = None,
    camera_pairs: list[tuple[str, str]] | None = None,
):
    """Associate poses across pairs of camera views for one frame and plot them.

    Args:
        session: The recording session. Its `camera_cluster`, `cams_to_include`
            and `frame_groups` are used.
        frame_idx: The key of the frame group to evaluate. Defaults to the first
            key of `session.frame_groups`.
        camera_pairs: Pairs of camera names to compare. Defaults to the
            back/left/right/front/center chain. A rig with differently named
            cameras can pass its own pairs, for example
            ``[("side", "sideL"), ("top", "topL"), ("mid", "midL"),
            ("side", "mid"), ("top", "mid")]``.
    """
    # TODO(LM): Use camera extrinsics to find pairs of cameras with similar views
    if camera_pairs is None:
        # Pair cameras and connect camera pairs
        camera_pairs = [
            ("back", "left"),
            ("left", "right"),
            ("right", "front"),
            ("front", "center"),
        ]

    # TODO(LM): Add method to return all ungrouped coordinates
    if frame_idx is None:
        frame_idx = next(iter(session.frame_groups.keys()))
    # Indexed by camera below -- presumably the leading axis follows the order
    # of `session.cams_to_include`; verify against `FrameGroup.numpy()`.
    frame_group_numpy = session.frame_groups[frame_idx].numpy()

    errors = {}
    for camera_names in camera_pairs:
        # Retrieve the Camcorder objects for the pair of cameras
        cam_1, cam_2 = get_camera_pairs(session.camera_cluster, camera_names)

        # Get the keypoints for the pair of cameras
        cam_idx_1 = session.cams_to_include.index(cam_1)
        cam_idx_2 = session.cams_to_include.index(cam_2)
        keypoints_1 = frame_group_numpy[cam_idx_1]
        keypoints_2 = frame_group_numpy[cam_idx_2]

        # Evaluate the pair of cameras
        error, row_ind, col_ind = evaluate_camera_pair(
            cam_1, cam_2, keypoints_1, keypoints_2
        )
        errors[f"{cam_1.name} and {cam_2.name}"] = (error, row_ind, col_ind)

    plot_adjacency_matrices(errors)
if __name__ == "__main__":
    import os

    # The path of the SLEAP project is read from the "dsmviewgerbils"
    # environment variable (the "dsmview" variable points at another dataset).
    project_path = os.environ["dsmviewgerbils"]
    first_session = sleap.load_file(project_path).sessions[0]
    main(session=first_session)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment