Skip to content
Snippets Groups Projects
plotting.py 4.1 KiB
Newer Older
from plotly.graph_objects import Figure
import plotly.graph_objects as go
from plotly import colors

__all__ = ['bar_plot','pareto_chart']

def choose_appropriate_storage_unit(size, starting_unit='B'):
    """Pick the largest binary storage unit in which ``size`` is still >= 1.

    Parameters
    ----------
    size : number or sized container
        A quantity expressed in ``starting_unit``. If it has a length
        (e.g. a list, numpy array or pandas Series), its maximum element is
        used so a single unit can be chosen for the whole collection.
    starting_unit : str, default 'B'
        Unit ``size`` is expressed in: one of 'B', 'kiB', 'MiB', 'GiB',
        'TiB'. The base-10 spellings 'kB', 'MB', 'GB', 'TB' are accepted but
        interpreted as their base-2 counterparts (steps of 1024, not 1000).

    Returns
    -------
    tuple[int, str]
        ``(exp, unit)`` where ``unit == ['B','kiB','MiB','GiB','TiB'][exp]``.
        ``exp`` is the absolute power of 1024 relative to bytes (0 for 'B',
        1 for 'kiB', ...), not relative to ``starting_unit``. 'TiB' is the
        largest unit ever returned.

    Raises
    ------
    ValueError
        If ``starting_unit`` is not a recognised unit.
    """
    units = ['B', 'kiB', 'MiB', 'GiB', 'TiB']
    units_b10 = ['kB', 'MB', 'GB', 'TB']

    if hasattr(size, "__len__"):
        # numpy/pandas containers expose .max(); plain sequences need max().
        size = size.max() if hasattr(size, "max") else max(size)

    if starting_unit in units_b10:
        # TODO: log that the base-10 unit is being interpreted as base 2.
        starting_unit = starting_unit.replace('B', 'iB')
    try:
        exp = units.index(starting_unit)
    except ValueError:
        raise ValueError(
            f"{starting_unit} is not a valid storage unit. "
            "Choose from 'B','kB','MB','GB', or 'TB'"
        ) from None

    max_exp = len(units) - 1
    while size / 1024 >= 1 and exp < max_exp:
        size = size / 1024
        exp += 1
    return exp, units[exp]

def _format_number(num,dec=2):
    return f"{num:,.{dec}f}"  # Format with commas and 3 decimal places

def bar_plot(df, x, y, show_legend=True, legend_labels=None, add_text=True, textposition=None, text_decimals=2,
             group_colors=colors.qualitative.Plotly, title=None, xlabel=None, ylabel=None,
             enable_text_hover=False) -> Figure:
    """Draw one bar trace per column in ``y`` against column ``x`` of ``df``.

    Parameters
    ----------
    df
        Tabular data. Columns are looked up by name and ``.apply`` is called
        on the ``y`` columns (presumably a pandas DataFrame).
    x : str
        Column used for the x axis.
    y : str or list of str
        Column(s) to plot as bars; several columns produce a grouped chart.
    show_legend : bool, default True
        When False, the legend is hidden explicitly. When True, Plotly's
        default legend behaviour applies.
    legend_labels : list of str, optional
        Trace names, one per entry of ``y``. Defaults to the column names.
    add_text : bool, default True
        Label each bar with its value, formatted by ``_format_number``.
    textposition : str, optional
        Plotly bar ``textposition``. Ignored when ``add_text`` is False.
    text_decimals : int, default 2
        Decimal places used in the bar labels.
    group_colors : sequence of colors
        Bar colours, cycled if there are more series than colours.
    title, xlabel, ylabel : str, optional
        Figure title (centred, size 24) and axis titles.
    enable_text_hover : bool or str, default False
        False disables hover. True maps to Plotly's 'closest' mode (Plotly
        rejects a literal True). A string is passed through as the Plotly
        ``hovermode``, e.g. 'x unified'.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    if not isinstance(y, list):
        y = [y]

    # Trace names are needed even when the legend is hidden.
    if legend_labels is None:
        legend_labels = y

    textposition = textposition if add_text else None

    if isinstance(enable_text_hover, str):
        hovermode = enable_text_hover
    else:
        hovermode = 'closest' if enable_text_hover else False

    grouped = len(y) > 1

    fig = go.Figure()
    for idx, col in enumerate(y):
        text = df[col].apply(_format_number, dec=text_decimals) if add_text else None
        fig.add_bar(
            x=df[x],
            y=df[col],
            text=text,
            textposition=textposition,
            name=legend_labels[idx],
            marker_color=group_colors[idx % len(group_colors)],
            uid=str(idx),
            # Set directly: Plotly stores uid as a string, so selecting by an
            # int uid afterwards never matches.
            offsetgroup=str(idx) if grouped else None,
        )

    if grouped:
        fig.update_layout(
            barmode='group',  # Grouped bar chart
            bargap=0.3
        )

    fig.update_layout(
        title_text=title,
        title_x=0.5,
        title_xanchor='center',
        title_font_size=24,
        xaxis_title=xlabel,
        yaxis_title=ylabel,
        margin=dict(t=100, b=20, l=40, r=40),
        template='plotly_white',
        hovermode=hovermode
    )
    if not show_legend:
        fig.update_layout(showlegend=False)

    return fig

def pareto_chart(df, x, y, cumsum_col=None, show_legend=True, legend_labels=['Raw','Cumulative'], add_text=True,
                 textposition_bar=None, textposition_scatter=None, text_decimals=2,
                 group_colors=colors.qualitative.Plotly, title=None, xlabel=None, ylabel=None,
                 enable_text_hover=False) -> Figure:
    """Plot column ``y`` as bars and its running total as a line on the same axes.

    Rows are plotted in the order given; this function does not sort, so
    callers presumably sort by ``y`` descending beforehand.

    Parameters
    ----------
    df
        Tabular data (presumably a pandas DataFrame). It is copied and never
        modified.
    x : str
        Column used for the x axis.
    y : str
        Column plotted as bars.
    cumsum_col : str, optional
        Existing column holding the cumulative values. When omitted,
        ``df[y].cumsum()`` is computed into a column named ``f'{y}_cumsum'``.
    show_legend : bool, default True
        When False, the legend is hidden explicitly.
    legend_labels : sequence of two str
        Names for the bar and cumulative traces. Passing None uses
        ``[y, cumsum_col]``. The default list is only read, never mutated.
    add_text : bool, default True
        Label bars and line points with values formatted by ``_format_number``.
    textposition_bar, textposition_scatter : str, optional
        Text positions. When ``add_text`` is True they default to 'outside'
        and 'top center'.
    text_decimals : int, default 2
        Decimal places used in the labels.
    group_colors : sequence of colors
        Colours for the bar (index 0) and line (index 1), cycled if shorter.
    title, xlabel, ylabel : str, optional
        Figure title (centred, size 24) and axis titles.
    enable_text_hover : bool or str, default False
        False disables hover. True maps to Plotly's 'closest' mode. A string
        is passed through as the Plotly ``hovermode``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    df_ = df.copy()
    if cumsum_col is None:
        cumsum_col = f'{y}_cumsum'
        df_[cumsum_col] = df_[y].cumsum()

    # Trace names are needed even when the legend is hidden.
    if legend_labels is None:
        legend_labels = [y, cumsum_col]

    if add_text and textposition_bar is None:
        textposition_bar = 'outside'

    if add_text and textposition_scatter is None:
        textposition_scatter = 'top center'

    if isinstance(enable_text_hover, str):
        hovermode = enable_text_hover
    else:
        hovermode = 'closest' if enable_text_hover else False

    fig = go.Figure()
    bar_text = df_[y].apply(_format_number, dec=text_decimals) if add_text else None
    fig.add_bar(
        x=df_[x],
        y=df_[y],
        text=bar_text,
        textposition=textposition_bar,
        name=legend_labels[0],
        marker_color=group_colors[0 % len(group_colors)]
    )

    if add_text:
        scatter_text = df_[cumsum_col].apply(_format_number, dec=text_decimals)
        if len(scatter_text):
            # Blank the first line label. When the cumsum is computed here it
            # equals the first bar value, so the label would presumably
            # duplicate the bar's. Use positional .iloc: a re-sorted frame need
            # not have label 0 first, and label-based assignment would append a
            # row if label 0 were absent.
            scatter_text.iloc[0] = None
    else:
        scatter_text = None

    fig.add_scatter(
        x=df_[x],
        y=df_[cumsum_col],
        text=scatter_text,
        textposition=textposition_scatter,
        name=legend_labels[1],
        mode='lines+markers+text',
        marker_color=group_colors[1 % len(group_colors)]
    )

    fig.update_layout(
        title_text=title,
        title_x=0.5,
        title_xanchor='center',
        title_font_size=24,
        xaxis_title=xlabel,
        yaxis_title=ylabel,
        margin=dict(t=100, b=20, l=40, r=40),
        template='plotly_white',
        hovermode=hovermode
    )
    if not show_legend:
        fig.update_layout(showlegend=False)

    return fig