Source code for trelliscope.facets

"""Create dataframes with figure facets from dataframes containing input data.

Each facet is a figure belonging to a group of data.
"""

from typing import Any, Callable

import pandas as pd
import plotly.graph_objects as go


[docs] def facet_panels( df: pd.DataFrame, panel_column_name: str, facet_columns: list, plot_function: Callable[[pd.DataFrame, ...], go.Figure], params: dict[str, Any], ) -> pd.DataFrame: """Create figure facets from a dataframe by applying a plot function. Args: df: Dataframe input containing data for plotting and grouping. panel_column_name: The names of the output column containing the figure objects. facet_columns: Columns to group data by, creating a figure for each group. plot_function: Plotting function to apply to each group, taking a pd.DataFrame as the first argument. params: Additional keyword argument passed to the ``plot_function``. Returns: A dataframe with 1 column with the name of ``panel_column_name`` containing figure objects, The dataframe will have a multi-level index, with the level names given by ``facet_columns``. Examples: .. code-block:: python from trelliscope.facets import facet_panels import plotly.graph_objs as go import pandas as pd df = pd.DataFrame( [ ("Belgium","Europe",1977,9821800), ("Belgium","Europe",1982,9856303), ("Belgium","Europe",1987,9870200), ("Belgium","Europe",1992,10045622), ("Belgium","Europe",1997,10199787), ("Belgium","Europe",2002,10311970), ("Belgium","Europe",2007,10392226), ("Singapore","Asia",1977,2325300), ("Singapore","Asia",1982,2651869), ("Singapore","Asia",1987,2794552), ("Singapore","Asia",1992,3235865), ("Singapore","Asia",1997,3802309), ("Singapore","Asia",2002,4197776) ], columns=["Country", "Continent", "Year", "Population"] ) def plot_population(df: pd.DataFrame, **kwargs) -> go.Figure: return go.Figure(go.Scatter(x=df["Year"], y=df["Population"])) df_facets = facet_panels( df=df, panel_column_name="panel", facet_columns=["Country", "Continent"], plot_function=plot_population, params={"layout": {"title": "Population over time"}} ) """ result_df = df.groupby(facet_columns).apply( lambda mini_df: plot_function(mini_df, **params) ) result_df = result_df.to_frame(name=panel_column_name) return result_df