Created March 1, 2019 20:22
Move a matplotlib Axes from one figure to another.
import matplotlib.pyplot as plt

def move_axes(ax, fig, subplot_spec=111):
    """Move an Axes object from a figure to a new pyplot managed Figure in
    the specified subplot."""

    # get a reference to the old figure context so we can release it
    old_fig = ax.figure

    # remove the Axes from its original Figure context
    ax.remove()

    # set the pointer from the Axes to the new figure
    ax.figure = fig

    # add the Axes to the registry of axes for the figure
    fig.axes.append(ax)
    # twice, I don't know why...
    fig.add_axes(ax)

    # then to actually show the Axes in the new figure we have to make
    # a subplot with the positions etc for the Axes to go, so make a
    # subplot which will have a dummy Axes
    dummy_ax = fig.add_subplot(subplot_spec)

    # then copy the relevant data from the dummy to the ax
    ax.set_position(dummy_ax.get_position())

    # then remove the dummy
    dummy_ax.remove()

    # close the figure the original axis was bound to
    plt.close(old_fig)
It was the return value of pandas.DataFrame.plot(), which (with subplots=True) is a NumPy array of Axes.
More details:
figsize = (23, 11.5)
nrows = 4
ncols = 7

all_axes = []
for category in categories:
    # group/filter a dataframe based on the category. assume it is df here
    axes = df.plot(kind='pie', subplots=True, layout=(nrows, ncols), figsize=figsize, legend=False, ylabel='', title='example')
    all_axes.append((category, axes))

fig = plt.figure(figsize=figsize)
# plt.subplots_adjust(...) as needed

old_figs = []
for i_row, (cat, axes) in enumerate(all_axes):
    # axes[0] is the first row of the 2D array of Axes returned by df.plot
    for i_col, ax in enumerate(axes[0]):
        old_fig = move_axes(ax, fig, (nrows, ncols, 1 + i_col + i_row * ncols))
        if i_col == 0:
            # every Axes in a row shares one old figure, so record it once
            old_figs.append(old_fig)
            ax.text(-0.25, 0.5, cat, transform=ax.transAxes, rotation=90, ha='center', va='center')

for old_fig in old_figs:
    plt.close(old_fig)
Sorry it's not a complete example, but hopefully it works if you play with it.
Also, in case it was unclear, this is using my fork: https://gist.github.com/digitalsignalperson/546e80ae1965b83df0a82ba12ae8aac7
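The subplot position passed to move_axes in the snippet above is a 1-based, row-major index: the zero-based row is multiplied by the number of columns (7 in that snippet), then the zero-based column and 1 are added. A minimal sketch of that arithmetic follows. The helper name subplot_index is hypothetical; it is not part of the gist, the fork, or Matplotlib.

```python
def subplot_index(i_row, i_col, ncols):
    """Return the 1-based, row-major subplot index for a zero-based (row, col).

    Hypothetical helper for illustration only; it mirrors the expression
    1 + i_col + i_row * ncols used when computing the subplot position.
    """
    if i_row < 0:
        raise ValueError("i_row must be non-negative")
    if not 0 <= i_col < ncols:
        raise ValueError("i_col must be in range(ncols)")
    return 1 + i_col + i_row * ncols


ncols = 7
# Indices for the first two rows of a grid that is 7 columns wide.
grid = [[subplot_index(r, c, ncols) for c in range(ncols)] for r in range(2)]
print(grid[0])  # [1, 2, 3, 4, 5, 6, 7]
print(grid[1])  # [8, 9, 10, 11, 12, 13, 14]
```

Using ncols instead of a literal 7 keeps the index correct if the grid width changes.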
Right, I see. Thanks!
@digitalsignalperson What does make_complicated_plot() return? It seems like it should be a matplotlib.axes.Axes, but then you compute len(axes[0]), so the return of make_complicated_plot() must then be subscriptable.