Last active
August 29, 2015 13:57
-
-
Save has2k1/9637948 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class geom(object):
    """Base class of all geoms"""
    # Class-level defaults that concrete geoms override.
    DEFAULT_AES = {}
    REQUIRED_AES = set()
    DEFAULT_PARAMS = {}

    # Per-layer state; meant to be filled in by __init__.
    data = aes = manual_aes = params = None

    def __init__(self, *args, **kwargs):
        """Placeholder: intended to assign data, aes, manual_aes, params."""

    def __radd__(self, gg):
        """Placeholder: intended to add this layer to a ggplot object."""
        return None

    def _rename_aes(self, data, translations):
        """
        Helper for the geoms to convert from the ggplot2 api
        to matplotlib.

        Currently a pass-through: `data` is returned untouched and
        `translations` is not consulted.
        """
        return data

    def plot_layer(self, data, ax):
        """Abstract hook; each geom implements its own drawing here."""
        return None
# | |
class geom_point(geom):
    """Point geom: draws the layer with one ``ax.scatter`` call per group."""
    DEFAULT_AES = {'alpha': 1, 'color': 'black', 'fill': None,
                   'shape': 'o', 'size': 20}
    REQUIRED_AES = {'x', 'y'}
    DEFAULT_PARAMS = {'stat': 'identity', 'position': 'identity',
                      'cmap': None, 'label': ''}

    def plot_layer(self, data, ax):
        """
        Draw the points of this layer onto `ax`.

        Parameters
        ----------
        data : dataframe
            Layer data; columns are named after the aesthetics.
        ax : axes
            Object whose ``scatter`` method receives the keyword
            arguments built from each group.
        """
        translations = {'size': 's', 'shape': 'marker',
                        'color': 'edgecolor', 'fill': 'color'}

        # TODO: Not sure if position adjustments are applied before or
        # after the grouping.
        if self.params['position'] == 'jitter':
            # geom has no _jitter method; the jitter logic lives in
            # position_jitter.
            data = position_jitter().adjust(data)

        # groupby needs a list (not a set) and rejects an empty key list;
        # self.aes may still be None when nothing has been mapped.
        groupable = {'color', 'fill', 'shape', 'alpha', 'size'}
        group_cols = sorted(groupable.intersection(self.aes or ()))
        if group_cols:
            pieces = (piece for _, piece in data.groupby(group_cols))
        else:
            pieces = [data]

        for piece in pieces:
            # orient='list' gives {column: [values]}; the default orient
            # gives nested {column: {index: value}} dicts.
            kwargs = piece.to_dict(orient='list')
            # Every row of a group shares the same value in the grouping
            # columns, so pass one value (matplotlib's `marker` has to be
            # a single marker, not a sequence).
            for col in group_cols:
                kwargs[col] = kwargs[col][0]
            kwargs = self._rename_aes(kwargs, translations)
            ax.scatter(**kwargs)
class stat(object):
    """Base class of all stats"""
    REQUIRED_AES = set()
    DEFAULT_PARAMS = {}
    # Extra columns created by the stat
    CREATES = set()

    # Per-layer state; meant to be filled in by __init__.
    data = aes = params = None

    def __init__(self, *args, **kwargs):
        """Placeholder: intended to assign data, aes, params."""

    def __radd__(self, gg):
        """Placeholder: intended to add this layer to a ggplot object."""
        return None

    def compute(self, data):
        """
        Abstract hook to be implemented by each stat.

        The base version hands `data` back unchanged.
        """
        return data
# An example of a stat | |
class stat_identity(stat):
    """Stat that leaves the data as it is."""
    DEFAULT_PARAMS = dict(geom='point', position='identity',
                          width=None, height=None)

    def compute(self, data):
        """Return `data` unchanged."""
        return data
class position(object):
    """Base class for all positions"""
    # Aesthetics that map onto the x and y scales
    X = {'x', 'xmin', 'xmax', 'xend', 'xintercept'}
    Y = {'y', 'ymin', 'ymax', 'yend', 'yintercept'}

    def __init__(self, width=None, height=None, **kwargs):
        """
        Parameters
        ----------
        width : number or None
        height : number or None
        **kwargs
            `w` and `h` are accepted as short forms of `width` and
            `height`; when given they take precedence.
        """
        self.width = kwargs.get('w', width)
        self.height = kwargs.get('h', height)

    def adjust(self, data):
        """
        Positions must override this function

        How?
        ----
        Make the necessary adjustments to the columns in the dataframe.
        Create the position transformation functions and
        use self._transform_position() to do the rest.

        See: position_jitter.adjust()
        """
        return data

    def _transform_position(self, data, trans_x=None, trans_y=None):
        """
        Transform all the variables that map onto the x and y scales.

        Helper function for self.adjust

        Parameters
        ----------
        data : dataframe
            Modified in place; the same object is returned.
        trans_x : function
            Transforms x scale mappings
            Takes one argument, either a scalar or an array-type
        trans_y : function
            Transforms y scale mappings
            Takes one argument, either a scalar or an array-type

        Returns
        -------
        data : dataframe
        """
        if trans_x is not None:
            # A real list is needed for column indexing; filter() is a
            # one-shot iterator on Python 3.
            xs = [name for name in data.columns if name in self.X]
            if xs:
                data[xs] = data[xs].apply(trans_x)
        if trans_y is not None:
            ys = [name for name in data.columns if name in self.Y]
            if ys:
                data[ys] = data[ys].apply(trans_y)
        return data
class position_identity(position):
    """Position that adds nothing to the `position` base class."""
class position_jitter(position):
    """Position that adds random noise to the x and y mappings."""

    def adjust(self, data):
        """
        Jitter the columns of `data` that map onto the x and y scales.

        A width/height of None defaults to 40% of the resolution of the
        'x'/'y' column. A width/height of 0 leaves that direction alone.

        Parameters
        ----------
        data : dataframe
            Must have 'x' and 'y' columns when width/height are None.

        Returns
        -------
        data : dataframe
            The result of self._transform_position().
        """
        # Work on locals: writing the computed defaults back onto self
        # would carry one dataset's resolution over to the next call.
        width = self.width
        height = self.height
        # `is None`, not truthiness, so an explicit 0 is respected.
        if width is None:
            width = resolution(data['x']) * .4
        if height is None:
            height = resolution(data['y']) * .4

        trans_x = None
        trans_y = None
        # width/height are the jitter size itself, so they go in as
        # `amount`; `factor` only scales an automatically derived amount.
        if width > 0:
            trans_x = lambda x: jitter(x, amount=width)
        if height > 0:
            trans_y = lambda y: jitter(y, amount=height)
        return self._transform_position(data, trans_x, trans_y)
########### Ported functions for position='jitter' ############## | |
def resolution(x, zero=True):
    """
    Compute the resolution of a data vector

    Resolution is the smallest non-zero distance between adjacent values

    Parameters
    ----------
    x : 1D array_like
    zero : Boolean
        Whether to include zero values in the computation

    Result
    ------
    res : resolution of x
        If x is an integer array, then the resolution is 1.
        1 is also returned when x is empty, has an effective range
        of zero, or has fewer than two distinct values left to compare.
    """
    # asarray handles lists, tuples and other array-likes alike
    x = np.asarray(x)
    if x.size == 0:
        return 1

    # (unsigned) integers or an effective range of zero
    if (x.dtype.kind in ('i', 'u') or
            np.ptp(x) < np.finfo(float).resolution):
        return 1

    if not zero:
        x = x[x != 0]

    # unique() sorts and drops duplicates, so every step is non-zero
    steps = np.diff(np.unique(x))
    if steps.size == 0:
        return 1
    return steps.min()
def jitter(x, factor=1, amount=None):
    """
    Add a small amount of noise to values in an array_like

    Parameters
    ----------
    x : array_like
        Values to jitter. An empty `x` is returned as is.
    factor : number
        Scales the noise when `amount` is None or 0.
    amount : number or None
        Half-width of the uniform noise added to each value.
        None : derive it from the smallest gap between the distinct
            (rounded) finite values: factor/5 * gap.
        0 : use factor * z/50, where z is the range of the finite values.

    Returns
    -------
    out : x plus noise drawn uniformly from [-amount, amount)
    """
    if len(x) == 0:
        return x

    if isinstance(x, (list, tuple)):
        x = np.array(x)

    # Only finite values take part in choosing the noise size
    values = np.asarray(x, dtype=float)
    finite = values[np.isfinite(values)]

    z = 0
    if finite.size:
        z = finite.max() - finite.min()
        if z == 0:
            z = abs(finite.min())
    if z == 0:
        z = 1

    if amount is None:
        # Round to about 3 significant digits below the magnitude of z.
        # The values stay floats; casting to int would wipe out the
        # gaps between fractional values.
        digits = 3 - int(np.floor(np.log10(z)))
        xx = np.unique(np.round(finite, digits))
        d = np.diff(xx)
        if len(d):
            d = d.min()
        elif xx.size and xx[0] != 0:
            d = xx[0] / 10.
        else:
            d = z / 10.
        amount = factor / 5. * abs(d)
    elif amount == 0:
        amount = factor * (z / 50.)

    return x + np.random.uniform(-amount, amount, len(x))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment