From 7452dc22e6690cb45052cae297460b3e8f21c033 Mon Sep 17 00:00:00 2001
From: Matthew K Defenderfer <mdefende@uab.edu>
Date: Fri, 10 Jan 2025 01:17:53 -0600
Subject: [PATCH] Fix how list types are specified in type hints for type checking

---
 src/rc_gpfs/policy/convert.py  |  4 ++--
 src/rc_gpfs/process/factory.py | 38 +++++++++++++++++-----------------
 src/rc_gpfs/process/process.py |  4 ++--
 3 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/src/rc_gpfs/policy/convert.py b/src/rc_gpfs/policy/convert.py
index 1414580..1f2788b 100755
--- a/src/rc_gpfs/policy/convert.py
+++ b/src/rc_gpfs/policy/convert.py
@@ -5,7 +5,7 @@ import random
 import string
 import shutil
 from pathlib import Path
-from typing import Literal
+from typing import Literal, List
 from urllib.parse import unquote
 
 import pandas as pd
@@ -81,7 +81,7 @@ def convert(
 def hivize(
     parquet_path: str | Path,
     hive_path: str | Path,
-    tld: str | list[str] | None = None,
+    tld: str | List[str] | None = None,
     staging_path: str | Path | None = None,
     partition_size: str = '100MiB',
     with_cuda: bool | Literal['infer'] = 'infer',
diff --git a/src/rc_gpfs/process/factory.py b/src/rc_gpfs/process/factory.py
index 1610f65..c608f82 100644
--- a/src/rc_gpfs/process/factory.py
+++ b/src/rc_gpfs/process/factory.py
@@ -3,7 +3,7 @@ import pandas as pd
 import dask.dataframe as dd
 import dask_cudf
 from .utils import as_timedelta
-from typing import Literal
+from typing import Literal, List
 from typeguard import typechecked
 
 __all__ = ['get_aggregator']
@@ -21,8 +21,8 @@ class Aggregator:
     def _cut(
         self,
         ser: pd.Series | cudf.Series,
-        bins: list[int | pd.Timestamp],
-        labels: list[str] | None = None,
+        bins: List[int | pd.Timestamp],
+        labels: List[str] | None = None,
         **kwargs
     ) -> pd.Series | cudf.Series:
         right = kwargs.pop('right',False)
@@ -40,10 +40,10 @@
 
     def create_timedelta_cutoffs(
         self,
-        delta_vals: int | list[int],
+        delta_vals: int | List[int],
         delta_unit: Literal['D','W','M','Y'],
         run_date: pd.Timestamp
-    ) -> list[int | pd.Timestamp]:
+    ) -> List[int | pd.Timestamp]:
         deltas = pd.Series([as_timedelta(c,delta_unit) for c in delta_vals])
         cutoffs = pd.to_datetime(run_date - deltas)
         cutoffs = (
@@ -60,9 +60,9 @@
 
     def create_timedelta_labels(
         self,
-        delta_vals: list[int],
+        delta_vals: List[int],
         delta_unit: Literal['D','W','M','Y'],
-    ) -> list[str]:
+    ) -> List[str]:
         delta_vals.sort(reverse=True)
         deltas = [f'{d}{delta_unit}' for d in delta_vals]
         labels = [f'>{deltas[0]}'] + [f'{deltas[i+1]}-{deltas[i]}' for i in range(len(deltas)-1)] + [f'<{deltas[-1]}']
@@ -82,9 +82,9 @@ class PandasAggregator(Aggregator):
     def aggregate(
         self,
         df: cudf.DataFrame,
-        col: str | list[str],
-        grps: str | list[str],
-        funcs: str | list[str]
+        col: str | List[str],
+        grps: str | List[str],
+        funcs: str | List[str]
     ) -> pd.DataFrame:
 
         df_agg = (
@@ -110,9 +110,9 @@ class CUDFAggregator(Aggregator):
     def aggregate(
         self,
         df: cudf.DataFrame,
-        col: str | list[str],
-        grps: str | list[str],
-        funcs: str | list[str]
+        col: str | List[str],
+        grps: str | List[str],
+        funcs: str | List[str]
     ) -> pd.DataFrame:
         df_agg = (
             df.groupby(grps,observed = True)[col]
@@ -135,9 +135,9 @@ class DaskAggregator(Aggregator):
     def aggregate(
         self,
         df: dd.DataFrame,
-        col: str | list[str],
-        grps: str | list[str],
-        funcs: str | list[str]
+        col: str | List[str],
+        grps: str | List[str],
+        funcs: str | List[str]
     ) -> pd.DataFrame:
         df_agg = (
             df.groupby(grps,observed = True)[col]
@@ -164,9 +164,9 @@ class DaskCUDFAggregator(Aggregator):
     def aggregate(
         self,
         df: dask_cudf.DataFrame,
-        col: str | list[str],
-        grps: str | list[str],
-        funcs: str | list[str]
+        col: str | List[str],
+        grps: str | List[str],
+        funcs: str | List[str]
     ) -> pd.DataFrame:
         df_agg = (
             df.groupby(grps,observed = True)[col]
diff --git a/src/rc_gpfs/process/process.py b/src/rc_gpfs/process/process.py
index 4c61b27..2de789f 100644
--- a/src/rc_gpfs/process/process.py
+++ b/src/rc_gpfs/process/process.py
@@ -3,7 +3,7 @@ import pandas as pd
 from ..compute import start_backend
 from .utils import extract_run_date_from_filename
 from .factory import get_aggregator
-from typing import Literal
+from typing import Literal, List
 from typeguard import typechecked
 
 __all__ = ['aggregate_gpfs_dataset']
@@ -56,7 +56,7 @@ def _check_timedelta_values(vals,unit):
 def aggregate_gpfs_dataset(
     dataset_path: str | Path,
     run_date: pd.Timestamp | None = None,
-    delta_vals: int | list[int] | None = None,
+    delta_vals: int | List[int] | None = None,
     delta_unit: Literal['D','W','M','Y'] | None = None,
     time_val: Literal['access','modify','create'] = 'access',
     report_dir: str | Path | None = None,
-- 
GitLab
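
Note on the change above: the patch swaps builtin generics such as list[str] for typing.List[str] throughout the annotated signatures. The commit message does not spell out the motivation; a plausible reading (an assumption here, not confirmed by the patch) is runtime type checking with typeguard on Python versions older than 3.9, where evaluating an annotation like list[str] raises TypeError because the builtin list is not subscriptable, while typing.List[str] works on every version typeguard supports. A minimal sketch of that spelling with typeguard's @typechecked, using a hypothetical label_deltas helper rather than any function from rc_gpfs:

    # Minimal sketch, assuming typeguard is installed; label_deltas is a
    # hypothetical example, not part of rc_gpfs.
    from typing import List

    from typeguard import typechecked


    @typechecked
    def label_deltas(delta_vals: List[int], delta_unit: str) -> List[str]:
        # typing.List[...] is importable and subscriptable on all supported
        # Python versions; evaluating the builtin form list[int] raises
        # TypeError before Python 3.9.
        return [f"{d}{delta_unit}" for d in delta_vals]


    print(label_deltas([30, 60, 90], "D"))  # ['30D', '60D', '90D']
    # Passing list[str] data, e.g. label_deltas(["30"], "D"), fails the
    # runtime check instead of silently producing "30D".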