import cudf
import pandas as pd
import dask.dataframe as dd
import dask_cudf
from .utils import as_timedelta
from typing import Literal
from typeguard import typechecked

__all__ = ['get_aggregator']

# ENH: Eventually the backend name and CUDA flag probably belong in a small
# manager/wrapper class that also owns read_parquet, instead of duplicating
# the I/O method across the aggregator classes. That would separate concerns
# more cleanly.
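# A rough sketch of that shape, purely illustrative (names below are
# placeholders, not part of this module):
#
#     class BackendManager:
#         def __init__(self, backend: str):
#             self.aggregator = get_aggregator(backend)
#
#         def read_parquet(self, dataset_path, **kwargs):
#             # I/O would live here rather than on each aggregator.
#             return self.aggregator.read_parquet(dataset_path, **kwargs)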

@typechecked
class Aggregator:
    """Shared binning logic; subclasses supply backend-specific I/O and aggregation."""

    def __init__(self):
        self.backend = None
        self.cuda = None

    def _cut(
        self,
        ser: pd.Series | cudf.Series,
        bins: list[int | pd.Timestamp],
        labels: list[str] | None = None,
        **kwargs
    ) -> pd.Series | cudf.Series:
        right = kwargs.pop('right', False)

        if self.cuda:
            func = cudf.cut
            # cudf.cut needs numeric bins, so compare timestamps as int64
            # nanoseconds (create_timedelta_cutoffs emits int64 on this path).
            ser = ser.astype('int64')
        else:
            func = pd.cut

        grps = func(ser, bins=bins, labels=labels, right=right, **kwargs)
        if labels is not None:
            # Reverse the category order so the most recent bucket sorts first.
            grps = grps.cat.reorder_categories(labels[::-1], ordered=True)
        return grps

    def create_timedelta_cutoffs(
        self,
        delta_vals: int | list[int],
        delta_unit: Literal['D', 'W', 'M', 'Y'],
        run_date: pd.Timestamp
    ) -> list[int | pd.Timestamp]:
        # Accept a single delta as well as a list.
        if isinstance(delta_vals, int):
            delta_vals = [delta_vals]
        deltas = pd.Series([as_timedelta(c, delta_unit) for c in delta_vals])
        cutoffs = pd.to_datetime(run_date - deltas)
        # Bracket the cutoffs with the epoch and today so every row lands in
        # a bin, then sort ascending as pd.cut/cudf.cut require.
        cutoffs = (
            pd.concat(
                [
                    cutoffs,
                    pd.Series([pd.to_datetime('today'), pd.to_datetime('1970-01-01')])
                ]
            )
            .sort_values()
        )

        return cutoffs.astype('int64').to_list() if self.cuda else cutoffs.to_list()
    
    def create_timedelta_labels(
        self,
        delta_vals: list[int],
        delta_unit: Literal['D', 'W', 'M', 'Y'],
    ) -> list[str]:
        # Sort a copy (descending) rather than mutating the caller's list.
        delta_vals = sorted(delta_vals, reverse=True)
        deltas = [f'{d}{delta_unit}' for d in delta_vals]
        labels = (
            [f'>{deltas[0]}']
            + [f'{deltas[i + 1]}-{deltas[i]}' for i in range(len(deltas) - 1)]
            + [f'<{deltas[-1]}']
        )
        return labels
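
    # Illustrative pairing of the two helpers above (assumed inputs):
    # delta_vals=[30, 90, 365], delta_unit='D', run_date=2024-01-01 gives
    # ascending cutoffs [1970-01-01, 2023-01-01, 2023-10-03, 2023-12-02,
    # today] and labels ['>365D', '90D-365D', '30D-90D', '<30D'] -- one
    # label per bin, oldest bin first.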

class PandasAggregator(Aggregator):
    def __init__(self):
        self.backend = 'pandas'
        self.cuda = False

    def read_parquet(self, dataset_path, **kwargs) -> pd.DataFrame:
        return pd.read_parquet(dataset_path, **kwargs)

    def cut_dt(self, series, *args, **kwargs) -> pd.Series:
        return self._cut(series, *args, **kwargs)

    def aggregate(
        self,
        df: pd.DataFrame,
        col: str | list[str],
        grps: str | list[str],
        funcs: str | list[str]
    ) -> pd.DataFrame:
        df_agg = (
            df.groupby(grps, observed=True)[col]
            .agg(funcs)
            .sort_index(level=[0, 1])
            .reset_index()
        )
        return df_agg

    
class CUDFAggregator(Aggregator):
    def __init__(self):
        self.backend = 'cudf'
        self.cuda = True

    def read_parquet(self, dataset_path, **kwargs) -> cudf.DataFrame:
        return cudf.read_parquet(dataset_path, **kwargs)

    def cut_dt(self, series, *args, **kwargs) -> cudf.Series:
        return self._cut(series, *args, **kwargs)

    def aggregate(
        self,
        df: cudf.DataFrame,
        col: str | list[str],
        grps: str | list[str],
        funcs: str | list[str]
    ) -> pd.DataFrame:
        # Aggregate on GPU, then move the (small) result to pandas.
        df_agg = (
            df.groupby(grps, observed=True)[col]
            .agg(funcs)
            .sort_index(level=[0, 1])
            .to_pandas()
            .reset_index()
        )
        return df_agg


class DaskAggregator(Aggregator):
    def __init__(self):
        self.backend = 'dask'
        self.cuda = False

    def cut_dt(self, series, *args, **kwargs) -> dd.Series:
        # Apply the pandas cut partition-wise.
        return series.map_partitions(self._cut, *args, **kwargs)

    def aggregate(
        self,
        df: dd.DataFrame,
        col: str | list[str],
        grps: str | list[str],
        funcs: str | list[str]
    ) -> pd.DataFrame:
        df_agg = (
            df.groupby(grps, observed=True)[col]
            .agg(funcs)
            .compute()
            .sort_index(level=[0, 1])
            .reset_index()
        )
        return df_agg

    def read_parquet(self, dataset_path, **kwargs) -> dd.DataFrame:
        split_row_groups = kwargs.pop('split_row_groups', False)
        return dd.read_parquet(dataset_path, split_row_groups=split_row_groups, **kwargs)


class DaskCUDFAggregator(Aggregator):
    def __init__(self):
        self.backend = 'dask_cuda'
        self.cuda = True

    def cut_dt(self, series, *args, **kwargs) -> dask_cudf.Series:
        # Apply the cudf cut partition-wise.
        return series.map_partitions(self._cut, *args, **kwargs)

    def aggregate(
        self,
        df: dask_cudf.DataFrame,
        col: str | list[str],
        grps: str | list[str],
        funcs: str | list[str]
    ) -> pd.DataFrame:
        df_agg = (
            df.groupby(grps, observed=True)[col]
            .agg(funcs)
            .compute()
            .sort_index(level=[0, 1])
            .to_pandas()
            .reset_index()
        )
        return df_agg

    def read_parquet(self, dataset_path, **kwargs) -> dask_cudf.DataFrame:
        # Read via dask_cudf so partitions are cudf DataFrames; dd.read_parquet
        # would yield pandas-backed partitions and break the CUDA cut path.
        split_row_groups = kwargs.pop('split_row_groups', False)
        return dask_cudf.read_parquet(dataset_path, split_row_groups=split_row_groups, **kwargs)


def get_aggregator(backend: str) -> PandasAggregator | CUDFAggregator | DaskAggregator | DaskCUDFAggregator:
    """Return the aggregator implementation for the given backend name."""
    match backend:
        case 'pandas':
            return PandasAggregator()
        case 'cudf':
            return CUDFAggregator()
        case 'dask':
            return DaskAggregator()
        case 'dask_cuda':
            return DaskCUDFAggregator()
        case _:
            raise ValueError(f"Unsupported backend: {backend}")
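

# Minimal usage sketch (assumes a local 'events.parquet' with columns
# 'region', 'event_ts' (timestamps), and 'amount'; those names are
# placeholders for illustration, not part of this module).
if __name__ == '__main__':
    agg = get_aggregator('pandas')
    df = agg.read_parquet('events.parquet')

    run_date = pd.Timestamp('2024-01-01')
    cutoffs = agg.create_timedelta_cutoffs([30, 90, 365], 'D', run_date)
    labels = agg.create_timedelta_labels([30, 90, 365], 'D')

    # Bucket each event by age relative to run_date, then aggregate per bucket.
    df['age_bucket'] = agg.cut_dt(df['event_ts'], cutoffs, labels)
    print(agg.aggregate(df, 'amount', ['region', 'age_bucket'], ['sum', 'mean']))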