|
# -*- encoding: utf-8 -*- |
|
|
|
"""
A set of Utility Function(s) to Visualize Correlation of a DataFrame

"Correlation is a statistical measure that expresses the extent to
which two variables are linearly related." Popular methods like
`Pearson's R` quantify the strength of the relationship between
features. Programmatically, there are various in-built functions like
`pd.DataFrame.corr()` which calculate correlation.

Often, it is easier to just visualize the correlation information
with the help of heat maps to understand the relationship among many
variables. The code uses the `seaborn` and `matplotlib` libraries to
showcase the correlation heat-map of a dataframe.

@author: Debmalya Pramanik
"""
|
|
|
import numpy as np |
|
import pandas as pd |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
|
|
def corr_heatmap(
    df : pd.DataFrame,
    plot_bar : bool = True,
    target_column : str = None,
    annot : bool = True,
    annot_thresh : float = 0.65,
    orient : str = "v",
    tri : str = "all",
    **kwargs
) -> plt.Figure:
    """
    Calculate Correlation of a DataFrame (`df`) and Plot Heat-Map

    The function uses the in-built `pandas` method `DataFrame.corr()` to
    calculate the correlation of all numeric (and boolean) columns and
    returns a heatmap figure drawn with `seaborn` for a better
    understanding of the inter-relationship between all features.
    Optionally, a bar plot of the correlation of `target_column`
    against every other feature is drawn alongside the heatmap (see
    `corr_barplot()`).

    :type  df: pd.DataFrame
    :param df: Original dataframe on which the correlation heatmap is
               to be computed. Non-numeric columns are ignored.

    :type  plot_bar: bool
    :param plot_bar: When True (default), a bar plot of the correlation
                     of `target_column` against all other numeric
                     features is drawn next to the heatmap. When False,
                     only the heatmap is drawn.

    :type  target_column: str
    :param target_column: Name of the target column used for the bar
                          plot. Required when `plot_bar` is True, else
                          ignored by the function.

    :type  annot: bool
    :param annot: Annotate heatmap cells. Defaults to True. When True,
                  `annot_thresh` controls which cells get a label.

    :type  annot_thresh: float
    :param annot_thresh: Threshold value for annotation, in [0, 1]. A
                         cell is annotated only when `abs(corr) >= x`.
                         Defaults to 0.65. The value is also passed to
                         `corr_barplot()` as its `threshold` parameter to
                         decide the important features.

    :type  orient: str
    :param orient: Layout of heatmap and bar plot (h|v). Defaults to
                   `v`, i.e. the two plots are stacked vertically; pass
                   `h` to place them side by side. Only relevant when
                   `plot_bar` is True, but always validated.

    :type  tri: str
    :param tri: Which part of the (symmetric) correlation matrix to
                display (all|upper|lower). Defaults to `all`, i.e. the
                upper and lower triangles along with the diagonal are
                displayed. `upper` shows only the cells above the
                diagonal, `lower` only the cells below it; in both
                cases the diagonal is hidden.

    :raises ValueError: If `tri` or `orient` is not an accepted value,
                        or if `plot_bar` is True and `target_column` is
                        None.

    :rtype:   plt.Figure
    :returns: The figure containing the heatmap (and bar plot).

    Keyword Arguments
    -----------------
    The function accepts almost all keyword arguments accepted by
    `DataFrame.corr()` and `sns.heatmap()`. Additionally, the behaviour
    of the plot and correlation can be controlled with the below
    arguments.

        * *method* (`str`): Method of correlation. Accepts all values
          supported by `DataFrame.corr()`. Defaults to "pearson".
        * *min_periods* (`int`): Minimum number of observations
          required per pair of columns to have a valid result. Check
          the documentation of `DataFrame.corr()` for more information.
        * *round* (`int`): Number of decimal digits used in the plot
          annotations. Defaults to 2.
        * *vmin* (`float`): As accepted by the `sns.heatmap` function.
        * *vmax* (`float`): As accepted by the `sns.heatmap` function.
        * *cmap* (colormap or `str`): As accepted by `sns.heatmap`.
        * *cbar* (`bool`): As accepted by the `sns.heatmap` function.
        * *square* (`bool`): As accepted by the `sns.heatmap` function.
        * *linecolor* (color): As accepted by the `sns.heatmap` function.
        * *linewidths* (`float`): As accepted by `sns.heatmap`.
        * *keep_all_feature* (`bool`): Passed to `corr_barplot()`.
        * *y_annot_pos_adjust* (`array-like`): Passed to
          `corr_barplot()`. Defaults to `(5e-4, -25e-3)`.
        * *_environment* (`str`): When "jupyter", the figure is closed
          after drawing so it is not displayed twice.
    """

    # validate options first, so a typo fails before any work is done
    if tri not in ("all", "upper", "lower"):
        raise ValueError(f"tri ( == {tri}) is not understood. Accepted values: all|upper|lower.")

    if orient not in ("h", "v"):
        raise ValueError(f"orient ( == {orient}) is not understood. Accepted values: h|v.")

    if plot_bar and target_column is None:
        raise ValueError("`target_column` is required when `plot_bar` is True.")

    # only numeric (and boolean) columns have a correlation; recent pandas
    # raises on object columns instead of silently dropping them
    numeric = df.select_dtypes(include = ["number", "bool"])
    corr = numeric.corr(
        method = kwargs.get("method", "pearson"),
        min_periods = kwargs.get("min_periods", 1)
    )

    # * get additional keyword arguments for control
    round_ = kwargs.get("round", 2)
    _environment = kwargs.get("_environment", "terminal")

    if annot:
        # label only the "strong" correlations; every other cell (including
        # NaN, whose comparison is False) gets an empty label
        labels = corr.round(round_).astype(str).where(corr.abs() >= annot_thresh, "")
    else:
        # annotation is not required in heatmap
        labels = None

    # define masking for heatmap: seaborn HIDES cells where mask is True,
    # so the mask selects the triangle that must not be shown
    # https://stackoverflow.com/q/57414771/6623589
    if tri == "all":
        mask = None # default, show both upper and lower triangle
    elif tri == "upper":
        # hide diagonal + lower triangle -> upper triangle is displayed
        mask = np.tril(np.ones(corr.shape, dtype = bool))
    else:
        # hide diagonal + upper triangle -> lower triangle is displayed
        mask = np.triu(np.ones(corr.shape, dtype = bool))

    # create the figure: one axes for the heatmap, plus one for the bar plot
    if plot_bar:
        nrows, ncols = (1, 2) if orient == "h" else (2, 1)
        fig, axs = plt.subplots(nrows = nrows, ncols = ncols)
        heatmap_ax, barplot_ax = axs
    else:
        fig, heatmap_ax = plt.subplots(nrows = 1, ncols = 1)

    # plot heatmap using seaborn library
    _ = sns.heatmap(
        corr,
        fmt = "",
        ax = heatmap_ax,
        mask = mask,
        annot = labels,
        vmin = kwargs.get("vmin", None),
        vmax = kwargs.get("vmax", None),
        cmap = kwargs.get("cmap", None),
        cbar = kwargs.get("cbar", True),
        square = kwargs.get("square", False),
        linewidths = kwargs.get("linewidths", 0),
        linecolor = kwargs.get("linecolor", "white"),
    )

    if plot_bar:
        # draw on the second axes; corr_barplot() returns None in this mode
        _ = corr_barplot(
            corr,
            target_column = target_column,
            keep_all_feature = kwargs.get("keep_all_feature", False),
            threshold = annot_thresh,
            ax = barplot_ax,
            round = round_,
            y_annot_pos_adjust = kwargs.get("y_annot_pos_adjust", (5e-4, -25e-3))
        )

    if _environment == "jupyter":
        # https://stackoverflow.com/q/35422988/6623589
        plt.close(fig)

    return fig
|
|
|
def corr_barplot(
    correlations : pd.DataFrame,
    target_column : str,
    keep_all_feature : bool = False,
    threshold : float = 0.65,
    **kwargs
) -> plt.Figure:
    """
    Bar Plot the Correlation Values of all Features against Target

    Given a target column name from `correlations = df.corr()`, the
    function draws a bar plot where the bar height of each feature is
    its correlation coefficient with the target.

    :type  correlations: pd.DataFrame
    :param correlations: Correlation values, typically obtained with
                         `df.corr()`, and controlled externally.

    :type  target_column: str
    :param target_column: Name of the target column. It must be present
                          in both the columns and the index of the
                          correlation dataframe.

    :type  keep_all_feature: bool
    :param keep_all_feature: Keep all numeric features in the bar plot,
                             or show only those features whose absolute
                             correlation is at least `threshold`.
                             Defaults to False, i.e. only essential
                             features are displayed.

    :type  threshold: float
    :param threshold: Threshold value used to decide important features,
                      in [0, 1]. For any given value (x), a feature is
                      important iff its correlation is greater than or
                      equal to `x` or less than or equal to `-x`. NaN
                      correlations are never important. Defaults to 0.65.

    :rtype:   plt.Figure or None
    :returns: The new figure when called standalone (no `ax` keyword),
              otherwise None, because the caller owns the figure.

    Keyword Arguments
    -----------------
    The function accepts almost all keyword arguments accepted by
    `df.sort_values()` and `sns.barplot()`. Additionally, the behaviour
    of the plot and correlation can be controlled with the below
    arguments.

        * *ascending* (`bool`): Sort the feature correlations in
          ascending order. Defaults to True. This parameter is passed
          directly to `df.sort_values()` for sorting.
        * *y_annot_pos_adjust* (`array-like`): A pair of `(up, low)`
          offsets added to the y-position of the value annotation of
          positive and negative bars respectively. Defaults to (0, 0).
        * *round* (`int`): Decimal digits of the annotations. Defaults
          to 2.
        * *ax* (`matplotlib.axes.Axes`): Axes to draw on. When omitted,
          a new figure is created and returned.
        * *bar_xlabel*, *bar_ylabel* (`str`): Axis labels.
        * *_environment* (`str`): When "jupyter" (standalone mode only),
          the figure is closed so that it is not displayed twice.
    """

    # long format: one row per feature with its correlation to the target.
    # `rename_axis` guarantees the label column is called "feature" even
    # when the index already carries a name (reset_index would use it).
    correlations = (
        correlations[target_column]
        .rename_axis("feature")
        .reset_index()
        .sort_values(target_column, ascending = kwargs.get("ascending", True))
    )

    # the target is trivially correlated with itself
    correlations = correlations[correlations["feature"] != target_column]

    if not keep_all_feature:
        # keep |corr| >= threshold (same rule as the heatmap annotations);
        # NaN compares False and is dropped
        correlations = correlations[correlations[target_column].abs() >= threshold]

    # decide axis: draw on the caller's axes, or create a standalone figure
    ax = kwargs.get("ax", None)
    _environment = kwargs.get("_environment", "terminal")
    standalone = ax is None
    if standalone:
        fig, ax = plt.subplots(nrows = 1, ncols = 1)

    ax = sns.barplot(
        x = "feature", y = target_column, data = correlations,
        palette = sns.color_palette("RdYlBu", correlations.shape[0]).as_hex(),
        ax = ax,
    )

    # annotate every bar with its (rounded) value; bars sit at x = 0..n-1
    # in the order of the rows of `correlations`
    y_pos_u, y_pos_l = kwargs.get("y_annot_pos_adjust", (0, 0))
    values = correlations[target_column].values.round(kwargs.get("round", 2))
    for position, value in enumerate(values):
        y_pos = value + (y_pos_u if value >= 0 else y_pos_l)
        ax.text(position, y_pos, str(value), ha = "center", weight = "bold")

    ax.set(xlabel = kwargs.get("bar_xlabel", "Feature Names"))
    ax.set(ylabel = kwargs.get("bar_ylabel", f"Correlation Value with `{target_column}`"))

    if not standalone:
        # ! no return, as controlled by the parent `fig` object
        return None

    if _environment == "jupyter":
        # https://stackoverflow.com/q/35422988/6623589
        plt.close(fig)

    return fig