Skip to content
Snippets Groups Projects
utils.py 2.77 KiB
Newer Older
import argparse
import os
import shlex
import sys
from pathlib import Path

def define_python_interpreter(python_path=None, conda_env=None):
    """Return a shell snippet that activates a Python environment.

    The snippet is plain shell text meant to be pasted into a generated
    script. Resolution order:

    1. ``conda_env`` if given: load the Anaconda3 module, then
       ``conda activate`` that environment.
    2. ``python_path`` if given: ``source`` the ``activate`` script located
       in the same directory as that interpreter (virtualenv layout).
    3. Otherwise the currently active conda environment, taken from the
       ``CONDA_PREFIX`` environment variable, if it is set and non-empty.
    4. Otherwise the ``activate`` script next to the running interpreter
       (``sys.executable``).

    Args:
        python_path: Path to a Python executable inside a virtualenv.
        conda_env: Name or prefix path of a conda environment.

    Returns:
        The shell snippet (lines separated by ``\\n``). Interpolated names and
        paths are quoted with ``shlex.quote``, so values containing spaces or
        shell metacharacters are passed literally. Plain values such as
        ``/opt/venv/bin/activate`` or ``myenv`` are emitted unchanged.
    """
    def _conda_snippet(env_name):
        # str() mirrors the original str.format interpolation for non-str input.
        return "module load Anaconda3\nconda activate " + shlex.quote(str(env_name))

    def _venv_snippet(interpreter):
        # absolute() rather than resolve(): venv interpreters are commonly
        # symlinks to a base Python, and resolving would point at the wrong bin/.
        activate = Path(interpreter).absolute().parent / 'activate'
        return "source " + shlex.quote(str(activate))

    if conda_env is not None:
        return _conda_snippet(conda_env)
    if python_path is not None:
        return _venv_snippet(python_path)

    active_conda = os.environ.get('CONDA_PREFIX')
    if active_conda:
        return _conda_snippet(active_conda)
    return _venv_snippet(sys.executable)

class CustomHelpFormatter(argparse.MetavarTypeHelpFormatter):
    """Help formatter that orders arguments by the title of their group.

    Built on ``argparse.MetavarTypeHelpFormatter``, which uses the name of each
    argument's ``type`` (e.g. ``int``, ``str``) as its metavar.
    """

    @staticmethod
    def _group_title(action):
        # A container without a (truthy) title gets the empty-string key,
        # which sorts ahead of every titled group.
        return action.container.title or ''

    def add_arguments(self, actions):
        # sorted() is stable: actions sharing a group title keep their
        # registration order.
        ordered = sorted(actions, key=self._group_title)
        super().add_arguments(ordered)

def batch_parser(
        cpus_per_task: int | None = None,
        gpus: int | None = None,
        partition: str | None = None,
        mem: str | None = None,
        time: str | None = '12:00:00',
        reservation: str | None = None,
        slurm_log_dir: str | Path | None = './out',
        **kwargs
) -> argparse.ArgumentParser:
    """Build a help-less parser holding the common Slurm options.

    The parser is created with ``add_help=False``. It is presumably meant to be
    used as a ``parents=[...]`` entry of another parser (confirm with callers).
    All options live in a single "Slurm Options" argument group.

    Args:
        cpus_per_task: Default for ``-c/--cpus-per-task``.
        gpus: Default for ``-g/--gpus``; on the command line only 0 or 1 is accepted.
        partition: Default for ``-p/--partition``.
        mem: Default for ``-m/--mem``.
        time: Default for ``-t/--time``.
        reservation: Default for ``--reservation``.
        slurm_log_dir: Default for ``--slurm-log-dir``. argparse applies
            ``type=Path`` to string defaults when parsing.
        **kwargs: Accepted and ignored (presumably so a larger options mapping
            can be splatted in unfiltered).

    ``-n/--ntasks`` always defaults to 1 and ``--batch`` to False.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(formatter_class=CustomHelpFormatter,
                                     add_help=False)
    slurm_group = parser.add_argument_group(title='Slurm Options')

    # (flags, add_argument keyword options), registered in this order.
    option_table = (
        (('--batch',), dict(action='store_true', default=False,
                            help="Convert as a batch array job.")),
        (('-n', '--ntasks'), dict(type=int, default=1)),
        (('-c', '--cpus-per-task'), dict(type=int, default=cpus_per_task)),
        (('-g', '--gpus'), dict(type=int, default=gpus, choices=[0, 1])),
        (('-p', '--partition'), dict(type=str, default=partition)),
        (('-t', '--time'), dict(type=str, default=time)),
        (('-m', '--mem'), dict(type=str, default=mem)),
        (('--reservation',), dict(type=str, default=reservation)),
        (('--slurm-log-dir',), dict(
            type=Path,
            default=slurm_log_dir,
            help='Output log directory. If the directory does not exist, '
                 'it will be created automatically.')),
    )
    for flags, options in option_table:
        slurm_group.add_argument(*flags, **options)
    return parser

def setup_slurm_logs(slurm_log_dir, log_basename):
    """Create the log directory and return stdout/stderr log-file patterns.

    Args:
        slurm_log_dir: Log directory, as ``str`` or ``Path``. A relative path is
            made absolute against the current working directory. Missing
            parents are created with default permissions. The directory itself,
            if this call creates it, is given mode ``0o2770``.
        log_basename: Prefix for the log file names.

    Returns:
        ``{'output_log': ..., 'error_log': ...}``: absolute path strings of the
        form ``<dir>/<log_basename>_%A_%a.out`` and ``..._%A_%a.err``.
        ``%A``/``%a`` are left unexpanded for Slurm (job-array master ID and
        task index in sbatch filename patterns).

    Raises:
        FileExistsError: if ``slurm_log_dir`` exists but is not a directory.
    """
    log_dir = Path(slurm_log_dir).absolute()
    log_dir.parent.mkdir(parents=True, exist_ok=True)
    try:
        log_dir.mkdir(mode=0o2770)
    except FileExistsError:
        if not log_dir.is_dir():
            raise
    else:
        # mkdir's mode is masked by the umask (e.g. 0o2770 -> 0o750 with umask
        # 022), and Linux mkdir(2) ignores the setgid bit. Apply the intended
        # group-shared mode explicitly, but only to the directory we just
        # created. Pre-existing directories keep their permissions.
        os.chmod(log_dir, 0o2770)

    return {
        'output_log': str(log_dir / f'{log_basename}_%A_%a.out'),
        'error_log': str(log_dir / f'{log_basename}_%A_%a.err'),
    }