"""Mini-library for all the other modules."""
import altair as alt
from boltons.iterutils import remap
from logging import warning
import numpy as np
import pandas as pd
from toolz.dicttoolz import keyfilter, valfilter


def default(*args):
    """Return the first not None of the arguments.

    Parameters
    ----------
    *args : Any
        Any number of arguments.

    Returns
    -------
    Any
        One of the arguments.

    """
    return next(x for x in args if x is not None)
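
# Illustrative usage sketch (not executed here): `default` picks the first
# argument that is not None, e.g.
#
#     default(None, "variable")   # -> "variable"
#     default(0, "fallback")      # -> 0, since 0 is not None
#
# If every argument is None, the generator is exhausted and `next` raises
# StopIteration.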


# testing
def viz_reg_test(test_f):
    """Decorate recipe tests.

    Transforms a function into a regression test. It also saves the chart as
    HTML in the test directory, named after the file and function of the test,
    for visual inspection. If invoked with None as its sole argument, the
    decorated function just produces the chart, disabling the regression
    machinery, again for manual inspection.

    Parameters
    ----------
    test_f : function
        Simple chart-generating argumentless function, pytest-style

    Returns
    -------
    function
        The decorated function.

    """

    def fun(regtest):
        with alt.data_transformers.enable(consolidate_datasets=False):
            np.random.seed(seed=0)
            plot = test_f()
            if regtest is not None:
                regtest.write(
                    alt.Chart.from_dict(round_floats(plot.to_dict(), 13)).to_json()
                )
                plot.save(
                    test_f.__code__.co_filename + "_" + test_f.__qualname__ + ".html"
                )
            return plot

    test_f.__doc__ = (
        (
            test_f.__doc__
            or "Test for function {test_f}".format(test_f=test_f.__qualname__)
        )
        + """
    Parameters
    ----------
    regtest
        Pass None to execute manually outside regression testing; in that case
        the decorated function returns the chart.
    """
    )
    return fun
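
# A minimal sketch of how a recipe test might use this decorator; the recipe
# `ar.scatterplot` and the data are placeholders, not defined in this module:
#
#     @viz_reg_test
#     def test_scatterplot():
#         data = pd.DataFrame({"x": range(10), "y": range(10)})
#         return ar.scatterplot(data, x="x", y="y")
#
# Under pytest, pytest-regtest supplies the `regtest` fixture; calling
# `test_scatterplot(None)` instead skips the regression machinery and simply
# returns the chart for manual inspection.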


# collections


def check_distinct(data, col, group=None):
    """Check that data[col] has no repeated values, optionally within each group."""
    if group is None:
        x = data[col]
        return x.size == x.nunique()
    else:
        x = data.groupby(group)[col]
        return all(x.size() == x.nunique())


def warn_not_distinct(data, col, group=None):
    """Warn if data[col] has repeated values, optionally within each group."""
    if not check_distinct(data, col, group):
        warning("The relation to plot is not a function")


def choose_kwargs(from_, which):
    """Choose entries for a dictionary with key in `which` and not None value."""
    return keyfilter(lambda x: x in which, valfilter(lambda x: x is not None, from_))
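
# Illustrative sketch (argument names are arbitrary): keep only the keyword
# arguments of interest whose values were actually set:
#
#     choose_kwargs({"color": "red", "opacity": None, "bogus": 1},
#                   which=["color", "opacity"])
#     # -> {"color": "red"}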


def round_floats(a_dict, precision):
    """Find all the floats in `a_dict` (recursive) and round them to `precision`."""
    return remap(
        a_dict,
        lambda p, k, v: (k, round(v, precision)) if isinstance(v, float) else (k, v),
    )
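
# Illustrative sketch: `remap` walks the nested structure, so floats are
# rounded at any depth:
#
#     round_floats({"a": 1.23456, "b": {"c": [0.987654]}}, 2)
#     # -> {"a": 1.23, "b": {"c": [0.99]}}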


def ndistinct(data, column):
    """Return number of distinct elements in data[column]."""
    return len(data[column].unique())


def col_cardinality(data, column, condition=None, default=1):
    """Return number of distinct elements in `data[column]` if `condition` is True (defaults to `column` being different from None), otherwise return `default`."""
    if condition is None:
        condition = column is not None
    return ndistinct(data, column) if condition else default


def gather(data, key, value, columns):
    """Convert wide format data frame to long format.

    Do so while concatenating selected columns into one and using a new column
    to track their origin.

    Parameters
    ----------
    data : pandas DataFrame
        The data to operate on.
    key : str
        The name of the column tracking the origin of that record.
    value : str
        The name of the column holding all the values previously in a number of
        columns.
    columns : list of str
        The names of the columns to reduce to a single column.

    Returns
    -------
    pandas.DataFrame
        A data frame with a reduced number of columns but the same information.

    """
    return pd.melt(
        data,
        id_vars=[col for col in data.columns if col not in columns],
        value_vars=columns,
        var_name=key,
        value_name=value,
    )
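
# A small sketch of the wide-to-long conversion (column names are arbitrary):
#
#     wide = pd.DataFrame({"id": [1, 2], "a": [10, 20], "b": [30, 40]})
#     gather(wide, key="variable", value="value", columns=["a", "b"])
#     #    id variable  value
#     # 0   1        a     10
#     # 1   2        a     20
#     # 2   1        b     30
#     # 3   2        b     40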


# TODO: this doesn't cover multiscatterplot, which is a multivariate viz but
# requires the data in wide format; this function only converts to long (gather).
# To include multiscatterplot one would need an index column, or a set thereof.


def multivariate_preprocess(data, columns, group_by):
    """Preprocess data for multivariate graphs.

    Converts wide-format data to long format when no `group_by` column is given.

    Parameters
    ----------
    data : pandas.DataFrame
        The data to be processed.
    columns : list of str
        The columns that need to be gathered.
    group_by : str
        The column indicating which variable a record refers to (long format).

    Returns
    -------
    (pandas.DataFrame, str, str)
        A tuple with the data in long format, the name of the column indicating
        the variable, and the name of the column holding the values.

    """
    assert (
        isinstance(columns, str) or len(columns) == 1 or group_by is None
    ), "Wide or long format but not both"
    if group_by is None:  # convert wide to long
        key = default(data.columns.name, "variable")
        value = "value"
        data = gather(data, key=key, value=value, columns=columns)
    else:
        key = group_by
        value = columns if isinstance(columns, str) else columns[0]
    return data, key, value
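
# Illustrative sketch of both accepted input formats (column names are
# arbitrary):
#
#     wide = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
#     multivariate_preprocess(wide, columns=["a", "b"], group_by=None)
#     # -> (long-format frame, "variable", "value")
#
#     long = pd.DataFrame({"y": [1, 2], "group": ["a", "b"]})
#     multivariate_preprocess(long, columns="y", group_by="group")
#     # -> (long unchanged, "group", "y")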


# constant luminosity and chroma scales

hue_scale_light = alt.Scale(type="linear", range=["#F271B8", "#00B4D7"])
hue_scale_dark = alt.Scale(type="linear", range=["#C10083", "#0080A7"])

# chart combinators


def layer(*layers, **kwargs):
    """Layer charts: a drop-in replacement for altair.layer.

    Makes a deep copy of the layers to avoid side effects and, when the layers
    share an identical dataset, lifts it to the top level of the layered chart.
    """
    layers = [l.copy() for l in layers]
    data = layers[0].data
    if all(map(lambda l: data.equals(l.data), layers)):
        layered = alt.layer(*layers, **kwargs, data=data)
        for l in layered.layer:
            del l._kwds["data"]
    else:
        layered = alt.layer(*layers, **kwargs)
    return layered
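
# A minimal usage sketch (the chart specs are placeholders):
#
#     source = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
#     points = alt.Chart(source).mark_point().encode(x="x", y="y")
#     line = alt.Chart(source).mark_line().encode(x="x", y="y")
#     combined = layer(points, line)
#
# Since both layers carry an equal dataset, it is attached once to the layered
# chart and removed from the individual layers.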