From 025701f02fe745bd21df4f5f8196e2a6b621b85d Mon Sep 17 00:00:00 2001 From: Augustin Zidek <augustinzidek@google.com> Date: Wed, 20 Nov 2024 14:15:10 +0000 Subject: [PATCH] Use model dir flag for run_alphafold_test PiperOrigin-RevId: 698034193 --- run_alphafold.py | 16 ++++++++-------- run_alphafold_test.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/run_alphafold.py b/run_alphafold.py index 4480afc..0c0230a 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -56,8 +56,8 @@ import numpy as np _HOME_DIR = pathlib.Path(os.environ.get('HOME')) -DEFAULT_MODEL_DIR = _HOME_DIR / 'models' -DEFAULT_DB_DIR = _HOME_DIR / 'public_databases' +_DEFAULT_MODEL_DIR = _HOME_DIR / 'models' +_DEFAULT_DB_DIR = _HOME_DIR / 'public_databases' # Input and output paths. @@ -77,9 +77,9 @@ _OUTPUT_DIR = flags.DEFINE_string( 'Path to a directory where the results will be saved.', ) -_MODEL_DIR = flags.DEFINE_string( +MODEL_DIR = flags.DEFINE_string( 'model_dir', - DEFAULT_MODEL_DIR.as_posix(), + _DEFAULT_MODEL_DIR.as_posix(), 'Path to the model to use for inference.', ) @@ -137,9 +137,9 @@ _HMMBUILD_BINARY_PATH = flags.DEFINE_string( ) # Database paths. -_DB_DIR = flags.DEFINE_multi_string( +DB_DIR = flags.DEFINE_multi_string( 'db_dir', - (DEFAULT_DB_DIR.as_posix(),), + (_DEFAULT_DB_DIR.as_posix(),), 'Path to the directory containing the databases. Can be specified multiple' ' times to search multiple directories in order.', ) @@ -635,7 +635,7 @@ def main(_): print('\n'.join(notice)) if _RUN_DATA_PIPELINE.value: - expand_path = lambda x: replace_db_dir(x, _DB_DIR.value) + expand_path = lambda x: replace_db_dir(x, DB_DIR.value) data_pipeline_config = pipeline.DataPipelineConfig( jackhmmer_binary_path=_JACKHMMER_BINARY_PATH.value, nhmmer_binary_path=_NHMMER_BINARY_PATH.value, @@ -673,7 +673,7 @@ def main(_): ) ), device=devices[0], - model_dir=pathlib.Path(_MODEL_DIR.value), + model_dir=pathlib.Path(MODEL_DIR.value), ) else: print('Skipping running model inference.') diff --git a/run_alphafold_test.py b/run_alphafold_test.py index b279b9e..28f376a 100644 --- a/run_alphafold_test.py +++ b/run_alphafold_test.py @@ -184,7 +184,7 @@ class InferenceTest(test_utils.StructureTestCase): model_class=run_alphafold.diffusion_model.Diffuser, config=run_alphafold.make_model_config(), device=jax.local_devices()[0], - model_dir=run_alphafold.DEFAULT_MODEL_DIR, + model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value), ) def compare_golden(self, result_path: str) -> None: @@ -301,7 +301,7 @@ class InferenceTest(test_utils.StructureTestCase): model_class=diffusion_model.Diffuser, config=run_alphafold.make_model_config(), device=jax.local_devices(backend='gpu')[0], - model_dir=pathlib.Path(run_alphafold.DEFAULT_MODEL_DIR), + model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value), ), output_dir='unused output dir', ) @@ -332,7 +332,7 @@ class InferenceTest(test_utils.StructureTestCase): model_class=diffusion_model.Diffuser, config=run_alphafold.make_model_config(), device=jax.local_devices(backend='gpu')[0], - model_dir=pathlib.Path(run_alphafold.DEFAULT_MODEL_DIR), + model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value), ), output_dir=output_dir, buckets=None if bucket is None else [bucket], -- GitLab