import sys
import ast
import pyarrow.parquet as pq
import pynvml
import importlib
from typing import Any
from pathlib import Path

def estimate_dataset_size(path: str | Path) -> float:
    if not isinstance(path,Path):
        path = Path(path)
    
    # only read first parquet file if path leads to directory
    if path.is_dir():
        parquet_files = sorted([f for f in path.glob('*.parquet')])
    else:
        parquet_files = [path]
    
    total_uncompressed_size = 0

    # Loop through each row group in the Parquet file
    for file_path in parquet_files:
        with pq.ParquetFile(file_path) as parquet:
            for i in range(parquet.metadata.num_row_groups):
                row_group = parquet.metadata.row_group(i)
            
                # Sum up the uncompressed sizes for each row group
                total_uncompressed_size += row_group.total_byte_size
    
    return total_uncompressed_size / (1024**3)

def get_gpu_info():
    """
    Query NVML for the GPU count and the total memory of GPU 0.

    Returns
    -------
    list
        ``[gpu_count, vram_gib]``, where ``vram_gib`` is the total memory of
        device 0 in GiB (``1024**3`` bytes). Returns ``[0, 0]`` when NVML
        cannot be initialised or reports no devices.
    """
    try:
        pynvml.nvmlInit()
    except Exception:
        # Best-effort probe: any NVML init failure (e.g. no driver/library)
        # means we fall back to the CPU backend.
        print("INFO: No GPU found. Using CPU backend.")
        return [0, 0]

    # Pair every successful nvmlInit with nvmlShutdown, even if a query raises.
    try:
        gpus = pynvml.nvmlDeviceGetCount()
        if gpus == 0:
            # Querying device 0 would fail here; treat it like "no GPU".
            print("INFO: No GPU found. Using CPU backend.")
            return [0, 0]
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        vram = pynvml.nvmlDeviceGetMemoryInfo(handle).total / (1024**3)
    finally:
        pynvml.nvmlShutdown()
    return [gpus, vram]

def import_package(package_name, alias=None):
    """
    Dynamically import a module and bind it in the ``__main__`` namespace.

    Args:
        package_name (str): The module to import; may be dotted
            (e.g. ``"dask.dataframe"``).
        alias (str, optional): Name to bind the module under in ``__main__``.
            Defaults to ``package_name`` itself.

    Each parent package of a dotted name is also imported and bound under its
    last path component (e.g. ``xml`` and ``etree`` for
    ``"xml.etree.ElementTree"``). This happens only when that name is not
    already defined in ``__main__``, so existing globals are never clobbered.

    Raises:
        ImportError: If the module (or one of its parents) cannot be imported.
            An error message is printed before re-raising.
    """
    try:
        # Dynamically import the library or submodule
        imported_library = importlib.import_module(package_name)

        # NOTE: without an alias, a dotted name is stored under the full
        # dotted string, which is not usable as a plain identifier; the parent
        # bindings below are what make e.g. ``dask.config`` reachable.
        global_name = alias if alias else package_name

        # Access the main module's global namespace
        main_namespace = sys.modules['__main__'].__dict__
        main_namespace[global_name] = imported_library

        # Bind parent packages. Check and write the SAME key so an existing
        # global with that name is left untouched.
        parts = package_name.split('.')
        for i in range(1, len(parts)):
            short_name = parts[i - 1]
            if short_name not in main_namespace:
                main_namespace[short_name] = importlib.import_module('.'.join(parts[:i]))

        print(f"Successfully imported '{package_name}' into the main namespace as '{global_name}'.")
    except ImportError as e:
        print(f"Error importing '{package_name}': {e}")
        raise

def parse_hook_string(hook_string: str) -> tuple[str, list[Any], dict[str | None, Any]]:
    """
    Parses a string representing a function call and extracts the function and its arguments.

    Parameters
    ----------
    hook_string : str
        A string representation of the function call 
            (e.g., "dask.config.set({'dataframe.backend': 'cudf'})").

    Returns
    -------
    tuple[str, list[Any], dict[str | None, Any]]
        - hook (str): The dotted path to the function (e.g., "dask.config.set").
        - args (list): A list of positional arguments.
        - kwargs (dict): A dictionary of keyword arguments.

    Raises
    ------
    ValueError
        If the input string cannot be parsed as a function call.
    """
    try:
        # Parse the string into an AST (Abstract Syntax Tree)
        tree = ast.parse(hook_string, mode="eval")

        # Ensure the root of the tree is a function call
        if not isinstance(tree.body, ast.Call):
            raise ValueError(f"Invalid hook string: {hook_string}")

        # Extract the function name (dotted path)
        func = tree.body.func

        def get_full_name(node):
            if isinstance(node, ast.Name):
                return node.id
            elif isinstance(node, ast.Attribute):
                return f"{get_full_name(node.value)}.{node.attr}"
            raise ValueError(f"Unsupported function structure in: {hook_string}")

        hook = get_full_name(func)

        # Extract positional arguments
        args = [ast.literal_eval(arg) for arg in tree.body.args]

        # Extract keyword arguments
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in tree.body.keywords}

        return hook, args, kwargs
    except Exception as e:
        raise ValueError(f"Error parsing hook string '{hook_string}': {e}")

def _resolve_hook(hook: str):
    """Resolve a dotted path via its longest importable module prefix, then attributes."""
    parts = hook.split('.')
    last_error = None
    # Try the longest module prefix first, so plain "module.func" paths resolve
    # exactly as before; shorter prefixes handle "module.Class.method".
    for split in range(len(parts) - 1, 0, -1):
        module_name = '.'.join(parts[:split])
        try:
            target = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            missing = e.name
            # Only fall back when the candidate itself (or one of its parents)
            # is what's missing. A missing dependency *inside* an existing
            # module is a real error and propagates.
            if missing is None or not (
                module_name == missing or module_name.startswith(missing + '.')
            ):
                raise
            last_error = e
            continue
        for attr in parts[split:]:
            target = getattr(target, attr)
        return target
    if last_error is not None:
        raise last_error
    raise ValueError(f"Hook '{hook}' is not a dotted path such as 'module.function'")


def run_hook(hook: str, *args, **kwargs):
    """
    Executes a hook function specified by its dotted path, with optional arguments and keyword arguments.

    Parameters
    ----------
    hook : str
        The dotted path to the hook function (e.g. 'dask.config.set' or
        'package.module.Class.method').
    *args : Any
        Positional arguments to pass to the hook function.
    **kwargs : Any
        Keyword arguments to pass to the hook function.

    Returns
    -------
    Any
        Whatever the hook function returns.

    Raises
    ------
    ValueError
        If ``hook`` contains no dot.
    ImportError
        If no prefix of ``hook`` is an importable module.
    Exception
        Any other error raised while resolving or executing the hook. An error
        message is printed before re-raising.
    """
    try:
        hook_func = _resolve_hook(hook)
        result = hook_func(*args, **kwargs)
        print(f"Executed hook: {hook} with args: {args} and kwargs: {kwargs}")
        return result
    except Exception as e:
        print(f"Error executing hook '{hook}': {e}")
        raise

def wrap_hook(hook: str) -> None:
    """
    Parse a hook call string and execute it.

    Parameters
    ----------
    hook : str
        A call expression such as
        "dask.config.set({'dataframe.backend': 'cudf'})".

    Raises
    ------
    ValueError
        If ``hook`` cannot be parsed (propagated from ``parse_hook_string``).
    Exception
        Any error raised while importing or running the hook (propagated from
        ``run_hook``).
    """
    hook_path, positional, keyword = parse_hook_string(hook)
    run_hook(hook_path, *positional, **keyword)