Newer
Older
Matthew K Defenderfer
committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import sys
import ast
import pyarrow.parquet as pq
import pynvml
import importlib
from typing import Any
from pathlib import Path
def estimate_dataset_size(path: str | Path) -> float:
if not isinstance(path,Path):
path = Path(path)
# only read first parquet file if path leads to directory
if path.is_dir():
parquet_files = sorted([f for f in path.glob('*.parquet')])
else:
parquet_files = [path]
total_uncompressed_size = 0
# Loop through each row group in the Parquet file
for file_path in parquet_files:
with pq.ParquetFile(file_path) as parquet:
for i in range(parquet.metadata.num_row_groups):
row_group = parquet.metadata.row_group(i)
# Sum up the uncompressed sizes for each row group
total_uncompressed_size += row_group.total_byte_size
return total_uncompressed_size / (1024**3)
def get_gpu_info():
    """
    Report how many GPUs NVML can see and the memory size of the first one.

    Returns
    -------
    list
        ``[gpus, vram]`` where ``gpus`` is the NVML device count and ``vram``
        is the total memory of device index 0 divided by ``1024**3``.
        ``[0, 0]`` is returned when NVML cannot be initialised or reports no
        devices.
    """
    try:
        pynvml.nvmlInit()
    except Exception:
        # Deliberately broad: any failure to bring up NVML (missing driver,
        # missing library, ...) means we fall back to the CPU backend.
        print("INFO: No GPU found. Using CPU backend.")
        return [0, 0]

    # From here on NVML is initialised, so it must be shut down on every
    # exit path, including when one of the queries below raises.
    try:
        gpus = pynvml.nvmlDeviceGetCount()
        if gpus == 0:
            # NVML loaded but there is no device to query at index 0.
            print("INFO: No GPU found. Using CPU backend.")
            return [0, 0]
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        vram = pynvml.nvmlDeviceGetMemoryInfo(handle).total / (1024**3)
    finally:
        pynvml.nvmlShutdown()

    return [gpus, vram]
def import_package(package_name, alias=None):
    """
    Dynamically imports a library and adds it to the main (top-level) namespace.

    Args:
        package_name (str): The (possibly dotted) name of the module to import.
        alias (str, optional): The name to bind the module to in the main
            namespace. Defaults to ``package_name``.

    Raises:
        ImportError: If the module cannot be imported. The error is printed
            and then re-raised.
    """
    try:
        # Dynamically import the library or submodule
        imported_library = importlib.import_module(package_name)

        # Determine the alias or default name
        global_name = alias if alias else package_name

        # Access the main module's global namespace
        main_namespace = sys.modules['__main__'].__dict__

        # Add the full library or submodule to the main namespace
        main_namespace[global_name] = imported_library

        # Also make each parent module available under its own short name
        # (for 'a.b.c': 'a' -> module a, 'b' -> module a.b).
        parts = package_name.split('.')
        for depth in range(1, len(parts)):
            short_name = parts[depth - 1]
            # Guard on the key that is actually written, so a name that
            # already exists in the main namespace is never overwritten.
            if short_name not in main_namespace:
                main_namespace[short_name] = importlib.import_module('.'.join(parts[:depth]))

        print(f"Successfully imported '{package_name}' into the main namespace as '{global_name}'.")
    except ImportError as e:
        print(f"Error importing '{package_name}': {e}")
        raise
def parse_hook_string(hook_string: str) -> tuple[str, list[Any], dict[str | None, Any]]:
"""
Parses a string representing a function call and extracts the function and its arguments.
Parameters
----------
hook_string : str
A string representation of the function call
(e.g., "dask.config.set({'dataframe.backend': 'cudf'})").
Returns
-------
tuple[str, list[Any], dict[str | None, Any]]
- hook (str): The dotted path to the function (e.g., "dask.config.set").
- args (list): A list of positional arguments.
- kwargs (dict): A dictionary of keyword arguments.
Raises
------
ValueError
If the input string cannot be parsed as a function call.
"""
try:
# Parse the string into an AST (Abstract Syntax Tree)
tree = ast.parse(hook_string, mode="eval")
# Ensure the root of the tree is a function call
if not isinstance(tree.body, ast.Call):
raise ValueError(f"Invalid hook string: {hook_string}")
# Extract the function name (dotted path)
func = tree.body.func
def get_full_name(node):
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
return f"{get_full_name(node.value)}.{node.attr}"
raise ValueError(f"Unsupported function structure in: {hook_string}")
hook = get_full_name(func)
# Extract positional arguments
args = [ast.literal_eval(arg) for arg in tree.body.args]
# Extract keyword arguments
kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in tree.body.keywords}
return hook, args, kwargs
except Exception as e:
raise ValueError(f"Error parsing hook string '{hook_string}': {e}")
def run_hook(hook: str, *args, **kwargs):
    """
    Import and call the function named by a dotted path.

    Everything before the last dot in ``hook`` is imported as a module and
    the final component is looked up on it as an attribute. The hook's return
    value is discarded.

    Parameters
    ----------
    hook : str
        Dotted path of the callable (e.g., 'dask.config.set')
    *args : Any
        Positional arguments forwarded to the callable.
    **kwargs : Any
        Keyword arguments forwarded to the callable.

    Raises
    ------
    Exception
        Any error raised while splitting the path, importing the module,
        resolving the attribute, or running the callable is printed and then
        re-raised unchanged.
    """
    try:
        # "pkg.mod.func" -> module path "pkg.mod", attribute "func"
        module_path, attr_name = hook.rsplit('.', 1)

        # Resolve the callable on the freshly imported module
        target = getattr(importlib.import_module(module_path), attr_name)

        # Invoke it with whatever the caller supplied
        target(*args, **kwargs)

        print(f"Executed hook: {hook} with args: {args} and kwargs: {kwargs}")
    except Exception as exc:
        print(f"Error executing hook '{hook}': {exc}")
        raise
def wrap_hook(hook: str) -> None:
    """
    Parse a function-call string and execute the call it describes.

    Parameters
    ----------
    hook : str
        A call expression such as "dask.config.set({'dataframe.backend': 'cudf'})".
        It is split into target and arguments by ``parse_hook_string`` and
        then executed through ``run_hook``.
    """
    target, positional, keyword = parse_hook_string(hook)
    run_hook(target, *positional, **keyword)