From 7452dc22e6690cb45052cae297460b3e8f21c033 Mon Sep 17 00:00:00 2001
From: Matthew K Defenderfer <mdefende@uab.edu>
Date: Fri, 10 Jan 2025 01:17:53 -0600
Subject: [PATCH] Fix how list types are specified in type hints for type checking

---
 src/rc_gpfs/policy/convert.py  |  4 ++--
 src/rc_gpfs/process/factory.py | 38 +++++++++++++++++-----------------
 src/rc_gpfs/process/process.py |  4 ++--
 3 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/src/rc_gpfs/policy/convert.py b/src/rc_gpfs/policy/convert.py
index 1414580..1f2788b 100755
--- a/src/rc_gpfs/policy/convert.py
+++ b/src/rc_gpfs/policy/convert.py
@@ -5,7 +5,7 @@ import random
 import string
 import shutil
 from pathlib import Path
-from typing import Literal
+from typing import Literal, List
 from urllib.parse import unquote
 
 import pandas as pd
@@ -81,7 +81,7 @@ def convert(
 def hivize(
     parquet_path: str | Path,
     hive_path: str | Path,
-    tld: str | list[str] | None = None,
+    tld: str | List[str] | None = None,
     staging_path: str | Path | None = None,
     partition_size: str = '100MiB',
     with_cuda: bool | Literal['infer'] = 'infer',
diff --git a/src/rc_gpfs/process/factory.py b/src/rc_gpfs/process/factory.py
index 1610f65..c608f82 100644
--- a/src/rc_gpfs/process/factory.py
+++ b/src/rc_gpfs/process/factory.py
@@ -3,7 +3,7 @@ import pandas as pd
 import dask.dataframe as dd
 import dask_cudf
 from .utils import as_timedelta
-from typing import Literal
+from typing import Literal, List
 from typeguard import typechecked
 
 __all__ = ['get_aggregator']
@@ -21,8 +21,8 @@ class Aggregator:
     def _cut(
         self,
         ser: pd.Series | cudf.Series,
-        bins: list[int | pd.Timestamp],
-        labels: list[str] | None = None,
+        bins: List[int | pd.Timestamp],
+        labels: List[str] | None = None,
         **kwargs
     ) -> pd.Series | cudf.Series:
         right = kwargs.pop('right',False)
@@ -40,10 +40,10 @@
 
     def create_timedelta_cutoffs(
         self,
-        delta_vals: int | list[int],
+        delta_vals: int | List[int],
         delta_unit: Literal['D','W','M','Y'],
         run_date: pd.Timestamp
-    ) -> list[int | pd.Timestamp]:
+    ) -> List[int | pd.Timestamp]:
         deltas = pd.Series([as_timedelta(c,delta_unit) for c in delta_vals])
         cutoffs = pd.to_datetime(run_date - deltas)
         cutoffs = (
@@ -60,9 +60,9 @@
 
     def create_timedelta_labels(
         self,
-        delta_vals: list[int],
+        delta_vals: List[int],
         delta_unit: Literal['D','W','M','Y'],
-    ) -> list[str]:
+    ) -> List[str]:
         delta_vals.sort(reverse=True)
         deltas = [f'{d}{delta_unit}' for d in delta_vals]
         labels = [f'>{deltas[0]}'] + [f'{deltas[i+1]}-{deltas[i]}' for i in range(len(deltas)-1)] + [f'<{deltas[-1]}']
@@ -82,9 +82,9 @@ class PandasAggregator(Aggregator):
     def aggregate(
         self,
         df: cudf.DataFrame,
-        col: str | list[str],
-        grps: str | list[str],
-        funcs: str | list[str]
+        col: str | List[str],
+        grps: str | List[str],
+        funcs: str | List[str]
     ) -> pd.DataFrame:
 
         df_agg = (
@@ -110,9 +110,9 @@ class CUDFAggregator(Aggregator):
     def aggregate(
         self,
         df: cudf.DataFrame,
-        col: str | list[str],
-        grps: str | list[str],
-        funcs: str | list[str]
+        col: str | List[str],
+        grps: str | List[str],
+        funcs: str | List[str]
     ) -> pd.DataFrame:
         df_agg = (
             df.groupby(grps,observed = True)[col]
@@ -135,9 +135,9 @@ class DaskAggregator(Aggregator):
     def aggregate(
         self,
         df: dd.DataFrame,
-        col: str | list[str],
-        grps: str | list[str],
-        funcs: str | list[str]
+        col: str | List[str],
+        grps: str | List[str],
+        funcs: str | List[str]
     ) -> pd.DataFrame:
         df_agg = (
             df.groupby(grps,observed = True)[col]
@@ -164,9 +164,9 @@ class DaskCUDFAggregator(Aggregator):
     def aggregate(
         self,
         df: dask_cudf.DataFrame,
-        col: str | list[str],
-        grps: str | list[str],
-        funcs: str | list[str]
+        col: str | List[str],
+        grps: str | List[str],
+        funcs: str | List[str]
     ) -> pd.DataFrame:
         df_agg = (
             df.groupby(grps,observed = True)[col]
diff --git a/src/rc_gpfs/process/process.py b/src/rc_gpfs/process/process.py
index 4c61b27..2de789f 100644
--- a/src/rc_gpfs/process/process.py
+++ b/src/rc_gpfs/process/process.py
@@ -3,7 +3,7 @@ import pandas as pd
 from ..compute import start_backend
 from .utils import extract_run_date_from_filename
 from .factory import get_aggregator
-from typing import Literal
+from typing import Literal, List
 from typeguard import typechecked
 
 __all__ = ['aggregate_gpfs_dataset']
@@ -56,7 +56,7 @@ def _check_timedelta_values(vals,unit):
 def aggregate_gpfs_dataset(
     dataset_path: str | Path,
     run_date: pd.Timestamp | None = None,
-    delta_vals: int | list[int] | None = None,
+    delta_vals: int | List[int] | None = None,
     delta_unit: Literal['D','W','M','Y'] | None = None,
     time_val: Literal['access','modify','create'] = 'access',
     report_dir: str | Path | None = None,
-- 
GitLab
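
Note on the change above: the patch swaps builtin generics such as list[str] for typing.List[str] throughout the annotated signatures. The commit message does not spell out the motivation; a plausible reading (an assumption here, not confirmed by the patch) is runtime type checking with typeguard on Python versions older than 3.9, where evaluating an annotation like list[str] raises TypeError because the builtin list is not subscriptable, while typing.List[str] works on every version typeguard supports. A minimal sketch of that spelling with typeguard's @typechecked, using a hypothetical label_deltas helper rather than any function from rc_gpfs:

    # Minimal sketch, assuming typeguard is installed; label_deltas is a
    # hypothetical example, not part of rc_gpfs.
    from typing import List

    from typeguard import typechecked


    @typechecked
    def label_deltas(delta_vals: List[int], delta_unit: str) -> List[str]:
        # typing.List[...] is importable and subscriptable on all supported
        # Python versions; evaluating the builtin form list[int] raises
        # TypeError before Python 3.9.
        return [f"{d}{delta_unit}" for d in delta_vals]


    print(label_deltas([30, 60, 90], "D"))  # ['30D', '60D', '90D']
    # Passing list[str] data, e.g. label_deltas(["30"], "D"), fails the
    # runtime check instead of silently producing "30D".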