Skip to content

Instantly share code, notes, and snippets.

@CalebFenton
Created November 23, 2016 19:23
Show Gist options
  • Save CalebFenton/5045c848f65538fe8cc2287ef9fb8179 to your computer and use it in GitHub Desktop.
Save CalebFenton/5045c848f65538fe8cc2287ef9fb8179 to your computer and use it in GitHub Desktop.
Graph the results of a scikit-learn grid search
#!/usr/bin/env python
import collections
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import sklearn as skl
import itertools
def get_data(gs, x_param='n_estimators',
             series_params=(('features', 'max_features'),
                            ('split', 'min_samples_split'),
                            ('oob', 'oob_score'))):
    """Group a fitted grid search's mean scores into one series per curve.

    Every evaluated parameter combination contributes one point: its
    ``x_param`` value on the x axis and its mean cross-validated score on
    the y axis. Combinations whose ``series_params`` values match share a
    series, keyed by a label such as ``'features: sqrt, split: 2, oob: True'``.

    Works with ``cv_results_`` (scikit-learn >= 0.18) as well as the legacy
    ``grid_scores_`` attribute (deprecated in 0.18, removed in 0.20).

    :param gs: a fitted ``GridSearchCV``-like object.
    :param x_param: name of the parameter plotted on the x axis.
    :param series_params: ``(label, param_name)`` pairs that together
        identify one series; ``label`` is the text used in the series key.
    :returns: ``OrderedDict`` mapping each series label (sorted
        alphabetically) to ``{'x': [...], 'y': [...]}``. Points are ordered
        by ascending x when the x values are mutually comparable, otherwise
        they keep the order in which the search evaluated them.
    :raises KeyError: if a combination lacks ``x_param`` or one of the
        ``series_params`` names.
    """
    series = collections.defaultdict(list)
    for params, mean_score in _iter_scores(gs):
        key = ', '.join('%s: %s' % (label, params[name])
                        for label, name in series_params)
        series[key].append((params[x_param], mean_score))

    graph_data = collections.OrderedDict()
    for key in sorted(series):
        points = series[key]
        try:
            # Draw each line left to right regardless of evaluation order;
            # sorted() is stable, so equal x values keep their grid order.
            points = sorted(points, key=lambda point: point[0])
        except TypeError:
            # Unorderable x values (e.g. None mixed with ints on Python 3):
            # deliberately fall back to evaluation order.
            pass
        graph_data[key] = {'x': [x for x, _ in points],
                           'y': [y for _, y in points]}
    return graph_data


def _iter_scores(gs):
    """Yield ``(params, mean_score)`` pairs from a fitted search object."""
    results = getattr(gs, 'cv_results_', None)
    if results is not None:
        return zip(results['params'], results['mean_test_score'])
    # Legacy entries: (parameters, mean_validation_score, cv_validation_scores).
    return ((entry[0], entry[1]) for entry in gs.grid_scores_)
def main(path='gridsearch.pkl'):
    """Load a saved grid search and plot mean CV score against n_estimators.

    One line is drawn per series returned by ``get_data``, each with its own
    colour and marker. The legend is placed above the axes.

    :param path: file previously written with ``joblib.dump()`` holding a
        fitted grid-search object; defaults to ``'gridsearch.pkl'`` in the
        working directory.
    """
    # NOTE(review): joblib.load unpickles arbitrary objects, so only load
    # files you produced yourself. sklearn.externals.joblib was removed in
    # scikit-learn 0.23; newer installs need the standalone joblib package.
    gs = skl.externals.joblib.load(path)
    graph_data = get_data(gs)

    plt.figure(1)
    ax = plt.gca()
    if graph_data:
        # Evenly spaced colours from gist_ncar, stopping at 0.9 (presumably
        # to skip the map's pale top end). set_prop_cycle replaces
        # Axes.set_color_cycle, deprecated in matplotlib 1.5, removed in 2.0.
        colormap = plt.cm.gist_ncar
        ax.set_prop_cycle(
            color=[colormap(i) for i in np.linspace(0, 0.9, len(graph_data))])
    markers = itertools.cycle(('o', 'v', '^', '<', '>', 's', '8', 'p', '*', 'D', 'x', 'H'))

    labels = []
    for label, series in graph_data.items():  # .items() works on Python 2 and 3
        labels.append(label)
        # Builtin next() instead of the Python-2-only iterator .next() method;
        # clip_on=False lets markers lying on the axes border draw in full.
        plt.plot(series['x'], series['y'], marker=next(markers), clip_on=False)

    plt.legend(labels, ncol=4, loc='upper center',
               bbox_to_anchor=[0.5, 1.1],
               columnspacing=1.0, labelspacing=0.0,
               handletextpad=0.0, handlelength=1.5,
               fancybox=True, shadow=True)
    plt.xlabel('# Estimators')
    plt.ylabel('Mean score')
    plt.grid(True)
    # Label y ticks with absolute scores rather than an offset (e.g. "+0.9").
    ax.get_yaxis().get_major_formatter().set_useOffset(False)
    plt.show()
if __name__ == '__main__':
    # Draw the plot only when run as a script; importing this module just
    # exposes get_data/main without side effects.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment