Skip to content
Snippets Groups Projects
utils.py 5.73 KiB
Newer Older
import sys
import ast
import pyarrow.parquet as pq
import pynvml
import importlib
from typing import Any
from pathlib import Path

def estimate_dataset_size(path: str | Path) -> float:
    if not isinstance(path,Path):
        path = Path(path)
    
    # only read first parquet file if path leads to directory
    if path.is_dir():
        parquet_files = sorted([f for f in path.glob('*.parquet')])
    else:
        parquet_files = [path]
    
    total_uncompressed_size = 0

    # Loop through each row group in the Parquet file
    for file_path in parquet_files:
        with pq.ParquetFile(file_path) as parquet:
            for i in range(parquet.metadata.num_row_groups):
                row_group = parquet.metadata.row_group(i)
            
                # Sum up the uncompressed sizes for each row group
                total_uncompressed_size += row_group.total_byte_size
    
    return total_uncompressed_size / (1024**3)

def get_gpu_info():
    """
    Report the number of NVIDIA GPUs and the total VRAM of the first one.

    Returns
    -------
    list
        ``[gpu_count, vram_gib]``, where ``vram_gib`` is the total memory of
        device 0 in GiB (bytes / 1024**3). Returns ``[0, 0]`` and prints an
        INFO message when NVML cannot be initialised or reports no devices.
    """
    try:
        pynvml.nvmlInit()
    except Exception:
        # Any NVML initialisation failure is treated as "no GPU available".
        print("INFO: No GPU found. Using CPU backend.")
        return [0, 0]

    try:
        gpus = pynvml.nvmlDeviceGetCount()
        if gpus == 0:
            # NVML is up but enumerates no devices; querying index 0 would fail.
            print("INFO: No GPU found. Using CPU backend.")
            return [0, 0]
        # Only device 0 is inspected; multi-GPU setups are assumed homogeneous
        # (NOTE(review): confirm this assumption holds for callers).
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        vram = pynvml.nvmlDeviceGetMemoryInfo(handle).total / (1024**3)
    finally:
        # Always balance the successful nvmlInit, even if a query raised.
        pynvml.nvmlShutdown()
    return [gpus, vram]

def import_package(package_name, alias=None):
    """
    Dynamically import a package or submodule and bind it in ``__main__``.

    Args:
        package_name (str): Dotted name of the package or submodule to import
            (e.g. ``"dask.dataframe"``).
        alias (str, optional): Name under which the imported module is stored
            in the ``__main__`` namespace. Defaults to ``package_name`` itself.
            Note that a dotted default is stored under the dotted key, which is
            not reachable as a plain identifier; the parent bindings below are
            what make ``a.b.c`` usable by attribute access.

    For a dotted ``package_name`` each parent is also imported and bound under
    its last name component (``"a.b.c"`` binds ``a`` -> module ``a`` and
    ``b`` -> module ``a.b``). A parent binding is only added when that name is
    not already present in ``__main__``, so existing names are never clobbered.

    Raises:
        ImportError: If ``package_name`` or one of its parents cannot be
            imported. The error is printed, then re-raised.
    """
    try:
        # Dynamically import the library or submodule
        imported_library = importlib.import_module(package_name)

        global_name = alias if alias else package_name

        # Global namespace of the running script, not of this module.
        main_namespace = sys.modules['__main__'].__dict__
        main_namespace[global_name] = imported_library

        parts = package_name.split('.')
        for i in range(1, len(parts)):
            parent_name = '.'.join(parts[:i])
            # Guard on the same key we bind. The dotted parent_name can never
            # be a key here, so checking it let nested parents overwrite
            # whatever the caller already had under that short name.
            binding_name = parts[i - 1]
            if binding_name not in main_namespace:
                main_namespace[binding_name] = importlib.import_module(parent_name)

        print(f"Successfully imported '{package_name}' into the main namespace as '{global_name}'.")
    except ImportError as e:
        print(f"Error importing '{package_name}': {e}")
        raise

def parse_hook_string(hook_string: str) -> tuple[str, list[Any], dict[str | None, Any]]:
    """
    Parses a string representing a function call and extracts the function and its arguments.

    Parameters
    ----------
    hook_string : str
        A string representation of the function call 
            (e.g., "dask.config.set({'dataframe.backend': 'cudf'})").

    Returns
    -------
    tuple[str, list[Any], dict[str | None, Any]]
        - hook (str): The dotted path to the function (e.g., "dask.config.set").
        - args (list): A list of positional arguments.
        - kwargs (dict): A dictionary of keyword arguments.

    Raises
    ------
    ValueError
        If the input string cannot be parsed as a function call.
    """
    try:
        # Parse the string into an AST (Abstract Syntax Tree)
        tree = ast.parse(hook_string, mode="eval")

        # Ensure the root of the tree is a function call
        if not isinstance(tree.body, ast.Call):
            raise ValueError(f"Invalid hook string: {hook_string}")

        # Extract the function name (dotted path)
        func = tree.body.func

        def get_full_name(node):
            if isinstance(node, ast.Name):
                return node.id
            elif isinstance(node, ast.Attribute):
                return f"{get_full_name(node.value)}.{node.attr}"
            raise ValueError(f"Unsupported function structure in: {hook_string}")

        hook = get_full_name(func)

        # Extract positional arguments
        args = [ast.literal_eval(arg) for arg in tree.body.args]

        # Extract keyword arguments
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in tree.body.keywords}

        return hook, args, kwargs
    except Exception as e:
        raise ValueError(f"Error parsing hook string '{hook_string}': {e}")

def _resolve_hook(hook: str):
    """Return the object named by the dotted path ``hook`` (module prefix + attribute chain)."""
    parts = hook.split('.')
    if len(parts) < 2:
        raise ValueError(f"Hook must be a dotted path such as 'module.func', got {hook!r}")

    # Try the longest module prefix first, so 'pkg.mod.func' resolves exactly
    # as a plain "import pkg.mod; pkg.mod.func" would. Shorter prefixes are a
    # fallback for paths such as 'pkg.mod.Class.method'.
    for split in range(len(parts) - 1, 0, -1):
        module_name = '.'.join(parts[:split])
        try:
            target = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            missing = err.name or ""
            prefix_is_missing = bool(missing) and (
                module_name == missing or module_name.startswith(missing + '.')
            )
            # Re-raise when a *different* module is missing (a broken
            # dependency inside the hook's module) or no shorter prefix is left.
            if not prefix_is_missing or split == 1:
                raise
            continue
        for attr in parts[split:]:
            target = getattr(target, attr)
        return target


def run_hook(hook: str, *args, **kwargs) -> Any:
    """
    Executes a hook function specified by its dotted path, with optional arguments and keyword arguments.

    Parameters
    ----------
    hook : str
        The dotted path to the hook function (e.g., 'dask.config.set').
        Attribute chains below a module are supported as well
        (e.g., 'collections.OrderedDict.fromkeys').
    *args : Any
        Positional arguments to pass to the hook function.
    **kwargs : Any
        Keyword arguments to pass to the hook function.

    Returns
    -------
    Any
        Whatever the hook function returns.

    Raises
    ------
    Exception
        If the specified hook cannot be imported or executed. The specific
        cases are:

        - ``ValueError`` when ``hook`` contains no dot.
        - ``ModuleNotFoundError`` or ``ImportError`` when the module is missing.
        - ``AttributeError`` when the function is missing.
        - Anything the hook itself raises.

        The error is printed before being re-raised.
    """
    try:
        hook_func = _resolve_hook(hook)
        result = hook_func(*args, **kwargs)
        print(f"Executed hook: {hook} with args: {args} and kwargs: {kwargs}")
        return result
    except Exception as e:
        print(f"Error executing hook '{hook}': {e}")
        raise

def wrap_hook(hook: str) -> None:
    """
    Parse a hook call string and execute it.

    Parameters
    ----------
    hook : str
        A call expression such as
        "dask.config.set({'dataframe.backend': 'cudf'})". It is split into a
        dotted function path and literal arguments by ``parse_hook_string``,
        then executed via ``run_hook``. Errors from either step propagate.
    """
    target, call_args, call_kwargs = parse_hook_string(hook)
    run_hook(target, *call_args, **call_kwargs)