Skip to content
Snippets Groups Projects

Fix how list types are specified in type hints for type checking

2 files
+ 20
20
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -3,7 +3,7 @@ import pandas as pd
import dask.dataframe as dd
import dask_cudf
from .utils import as_timedelta
from typing import Literal
from typing import Literal, List
from typeguard import typechecked
__all__ = ['get_aggregator']
@@ -21,8 +21,8 @@ class Aggregator:
def _cut(
self,
ser: pd.Series | cudf.Series,
bins: list[int | pd.Timestamp],
labels: list[str] | None = None,
bins: List[int | pd.Timestamp],
labels: List[str] | None = None,
**kwargs
) -> pd.Series | cudf.Series:
right = kwargs.pop('right',False)
@@ -40,10 +40,10 @@ class Aggregator:
def create_timedelta_cutoffs(
self,
delta_vals: int | list[int],
delta_vals: int | List[int],
delta_unit: Literal['D','W','M','Y'],
run_date: pd.Timestamp
) -> list[int | pd.Timestamp]:
) -> List[int | pd.Timestamp]:
deltas = pd.Series([as_timedelta(c,delta_unit) for c in delta_vals])
cutoffs = pd.to_datetime(run_date - deltas)
cutoffs = (
@@ -60,9 +60,9 @@ class Aggregator:
def create_timedelta_labels(
self,
delta_vals: list[int],
delta_vals: List[int],
delta_unit: Literal['D','W','M','Y'],
) -> list[str]:
) -> List[str]:
delta_vals.sort(reverse=True)
deltas = [f'{d}{delta_unit}' for d in delta_vals]
labels = [f'>{deltas[0]}'] + [f'{deltas[i+1]}-{deltas[i]}' for i in range(len(deltas)-1)] + [f'<{deltas[-1]}']
@@ -82,9 +82,9 @@ class PandasAggregator(Aggregator):
def aggregate(
self,
df: cudf.DataFrame,
col: str | list[str],
grps: str | list[str],
funcs: str | list[str]
col: str | List[str],
grps: str | List[str],
funcs: str | List[str]
) -> pd.DataFrame:
df_agg = (
@@ -110,9 +110,9 @@ class CUDFAggregator(Aggregator):
def aggregate(
self,
df: cudf.DataFrame,
col: str | list[str],
grps: str | list[str],
funcs: str | list[str]
col: str | List[str],
grps: str | List[str],
funcs: str | List[str]
) -> pd.DataFrame:
df_agg = (
df.groupby(grps,observed = True)[col]
@@ -135,9 +135,9 @@ class DaskAggregator(Aggregator):
def aggregate(
self,
df: dd.DataFrame,
col: str | list[str],
grps: str | list[str],
funcs: str | list[str]
col: str | List[str],
grps: str | List[str],
funcs: str | List[str]
) -> pd.DataFrame:
df_agg = (
df.groupby(grps,observed = True)[col]
@@ -164,9 +164,9 @@ class DaskCUDFAggregator(Aggregator):
def aggregate(
self,
df: dask_cudf.DataFrame,
col: str | list[str],
grps: str | list[str],
funcs: str | list[str]
col: str | List[str],
grps: str | List[str],
funcs: str | List[str]
) -> pd.DataFrame:
df_agg = (
df.groupby(grps,observed = True)[col]
Loading