Skip to content

Instantly share code, notes, and snippets.

@moorepants
Created January 25, 2023 10:06
Show Gist options
  • Save moorepants/983142d69e87978f669e6ce5099ba49b to your computer and use it in GitHub Desktop.
Save moorepants/983142d69e87978f669e6ce5099ba49b to your computer and use it in GitHub Desktop.
Script that runs a weighted k-means on overlapping rectangles on a grid to find clusters of "focus"
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as pat
import sklearn.cluster as clu
# Script: scatter random rectangles over an integer grid, count how many
# rectangles cover each grid point, and run a K-Means weighted by that count
# to find clusters of "focus".

# set the seed if you want the same random numbers on each execution
#np.random.seed(5)

# generate data for random overlapping rectangles; np.random.randint excludes
# its upper bound, so every upper bound below is one past the largest value
# wanted (widths/heights in 2..8, x in 0..40, y in 0..20)
num_rectangles = 50
widths = np.random.randint(2, 9, num_rectangles)
heights = np.random.randint(2, 9, num_rectangles)
locations_x = np.random.randint(0, 41, num_rectangles)
locations_y = np.random.randint(0, 21, num_rectangles)

# make a figure to plot on
fig, ax = plt.subplots(1, 1)
ax.set_xlim((0, 50))
ax.set_ylim((0, 30))

# create a matrix of zeros to represent each (x, y) coordinate in the grid; it
# will hold a 1 if the (x, y) contains one or more rectangles and a 0 if not
filled = np.zeros((60, 40))

# create a matrix of zeros to represent each (x, y) coordinate in the grid that
# will hold a value representing how many rectangles overlap at that location
weights = np.zeros_like(filled)

for x, y, w, h in zip(locations_x, locations_y, widths, heights):
    # a rectangle of width w anchored at x touches the grid points x..x+w
    # inclusive, hence the +1 on the (exclusive) slice ends
    # set all locations that have a rectangle to 1
    filled[x:x+w+1, y:y+h+1] = 1
    # set all locations that have overlapping rectangles to the number of
    # overlaps
    weights[x:x+w+1, y:y+h+1] += 1
    # create a rectangle patch to plot
    rect = pat.Rectangle(
        (x, y), w, h,
        edgecolor='grey',
        facecolor='grey',
        alpha=0.5)
    ax.add_patch(rect)

# convert the grid of filled locations to pairs of (x, y) coordinates where a
# rectangle is present; one row per covered grid point
X = np.asarray(np.nonzero(filled)).T

# assign the number of overlaps at each row of X as the weight (integer-array
# indexing picks weights[x, y] for every row in one step)
W = weights[X[:, 0], X[:, 1]]

# simple weighted K-Means cluster; fit_predict returns one cluster label per
# row of X
kmeans = clu.KMeans(n_clusters=5).fit_predict(X, sample_weight=W)

# plot the clusters, coloring each covered grid point by its cluster label
ax.scatter(X[:, 0], X[:, 1], c=kmeans)

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment