import os
import re
import subprocess
from pathlib import Path
from typing import List, Literal, Tuple
import polars as pl
import pyarrow.parquet as pq
import numpy as np

from .units import as_bytes, convert_si, create_size_bin_labels
from .datetime import as_datetime, create_timedelta_breakpoints, create_timedelta_labels

def parse_scontrol():
    """Query `scontrol` for the current SLURM job's allocated CPU cores and memory (GiB)."""
    job_id = os.getenv('SLURM_JOB_ID')
    if job_id is None:
        raise RuntimeError("SLURM_JOB_ID is not set; parse_scontrol() must run inside a SLURM job.")

    command = f"scontrol show job {job_id} | grep TRES="
    result = subprocess.run(command, shell=True, capture_output=True, text=True).stdout.strip()

    tres_pattern = r'.*cpu=(?P<cores>[\d]+),mem=(?P<mem>[\d]+[KMGT]?).*'
    match = re.search(tres_pattern, result)
    if match is None:
        raise ValueError(f"Could not parse TRES from scontrol output: {result!r}")

    cores = int(match.group('cores'))
    mem = convert_si(match.group('mem'), to_unit='G', use_binary=True)
    return [cores, mem]
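
# Illustrative sketch of the return value (actual numbers depend on the job's TRES string):
#
#     cores, mem_gib = parse_scontrol()
#     # e.g. cores == 16, mem_gib == 64.0 for a job allocated "cpu=16,mem=64G"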

def as_path(s: str | Path) -> Path:
    """Coerce a string to a Path, passing existing Path objects through unchanged."""
    if not isinstance(s, Path):
        s = Path(s)
    return s

def prep_size_distribution(
        size_bins: int | str | List[int | str] = ['4 kiB', '4 MiB', '1 GiB', '10 GiB', '100 GiB', '1 TiB'],
        **kwargs
) -> Tuple[List[int], List[str]]:
    """Normalize size breakpoints to sorted, de-duplicated byte counts and build matching bin labels."""
    if not isinstance(size_bins, list):
        size_bins = [size_bins]

    # Convert human-readable sizes (e.g. '4 kiB') to integer byte counts
    size_bins = [as_bytes(s) if isinstance(s, str) else s for s in size_bins]

    # De-duplicate and sort the breakpoints
    size_bins = sorted(set(size_bins))
    # Drop 0, as it is implicit as the left-most break point
    size_bins = [s for s in size_bins if s > 0]

    size_labels = create_size_bin_labels(size_bins)

    return size_bins, size_labels
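
# Illustrative call (exact label text is produced by create_size_bin_labels):
#
#     bins, labels = prep_size_distribution(['4 kiB', '1 GiB'])
#     # bins   -> [4096, 1073741824], assuming as_bytes honors IEC (binary) suffixes
#     # labels -> one label per interval, i.e. len(bins) + 1 entries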

def calculate_size_distribution(
        sizes: pl.Series,
        size_bins: int | str | List[int | str] = ['4 kiB', '4 MiB', '1 GiB', '10 GiB', '100 GiB', '1 TiB'],
        **kwargs
) -> pl.Series:
    """Bin file sizes (in bytes) into labeled groups, returned as an ordered Enum Series."""
    size_bins, size_labels = prep_size_distribution(size_bins)

    size_grps = (
        sizes
        .cut(
            breaks=size_bins,
            labels=size_labels,
            **kwargs
        )
        # Round-trip through String so the result carries the fixed Enum category order
        .cast(pl.String)
        .cast(pl.Enum(size_labels))
    )

    return size_grps
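
# Sketch of intended use (assumes the Series holds raw byte counts):
#
#     sizes = pl.Series('size', [2_048, 3_000_000, 5 * 2**30])
#     calculate_size_distribution(sizes)  # Enum Series with one size-bin label per value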

def prep_age_distribution(
        acq: str | np.datetime64,
        age_breakpoints: int | List[int],
        time_unit: Literal['D', 'W']
) -> Tuple[List[np.datetime64], List[str]]:
    """Convert age breakpoints (days or weeks) into datetime cut points relative to the acquisition date, plus matching labels."""
    if not isinstance(age_breakpoints, list):
        age_breakpoints = [age_breakpoints]

    # De-duplicate, sort, and drop non-positive ages
    age_breakpoints = sorted(set(age_breakpoints))
    age_breakpoints = [t for t in age_breakpoints if t > 0]

    # Create age bin labels before converting to datetimes for easier parsing
    age_labels = create_timedelta_labels(age_breakpoints, time_unit)

    # Create age bins by subtracting the breakpoints from the acquisition date
    age_breakpoints = create_timedelta_breakpoints(as_datetime(acq), age_breakpoints, time_unit)

    return age_breakpoints, age_labels
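
# Illustrative call (breakpoint and label formats come from the .datetime helpers):
#
#     breaks, labels = prep_age_distribution('2024-01-31', [30, 90], 'D')
#     # breaks: datetime cut points at 30 and 90 days before the acquisition date
#     # labels: one label per age interval, formatted by create_timedelta_labels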

def calculate_age_distribution(
        timestamps: pl.Series,
        acq: str | np.datetime64,
        age_breakpoints: List[int] = [30, 60, 90, 180],
        time_unit: Literal['D', 'W'] = 'D',
        **kwargs
) -> pl.Series:
    """Bin timestamps into age groups relative to the acquisition date, returned as an ordered Enum Series."""
    age_breakpoints, age_labels = prep_age_distribution(acq, age_breakpoints, time_unit)

    age_grps = (
        timestamps
        .cut(
            breaks=age_breakpoints,
            labels=age_labels,
            **kwargs
        )
        # Round-trip through String so the result carries the fixed Enum category order
        .cast(pl.String)
        .cast(pl.Enum(age_labels))
    )

    return age_grps
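
# Sketch of intended use (column name and dates are illustrative):
#
#     mtimes = pl.Series('modify', ['2023-10-01', '2024-01-15']).str.to_datetime()
#     calculate_age_distribution(mtimes, acq='2024-01-31')  # Enum Series of age-bin labels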

def get_parquet_dataset_size(parquet_path: str | Path) -> int:
    """Sum the uncompressed byte size of all row groups across a directory of parquet files."""
    parquet_path = as_path(parquet_path)
    tot_size = 0

    for p in parquet_path.glob("*.parquet"):
        md = pq.read_metadata(p)
        for rg in range(md.num_row_groups):
            # total_byte_size is the uncompressed size of the row group, not its on-disk footprint
            tot_size += md.row_group(rg).total_byte_size

    return tot_size
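
# Illustrative call (the path is hypothetical):
#
#     get_parquet_dataset_size('/data/dataset/parquet')  # -> total uncompressed bytes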