Skip to content
Snippets Groups Projects
Commit 025701f0 authored by Augustin Zidek's avatar Augustin Zidek
Browse files

Use model dir flag for run_alphafold_test

PiperOrigin-RevId: 698034193
parent 65328c9c
No related branches found
No related tags found
1 merge request!1Cloned AlphaFold 3 repo into this one
......@@ -56,8 +56,8 @@ import numpy as np
_HOME_DIR = pathlib.Path(os.environ.get('HOME'))
DEFAULT_MODEL_DIR = _HOME_DIR / 'models'
DEFAULT_DB_DIR = _HOME_DIR / 'public_databases'
_DEFAULT_MODEL_DIR = _HOME_DIR / 'models'
_DEFAULT_DB_DIR = _HOME_DIR / 'public_databases'
# Input and output paths.
......@@ -77,9 +77,9 @@ _OUTPUT_DIR = flags.DEFINE_string(
'Path to a directory where the results will be saved.',
)
_MODEL_DIR = flags.DEFINE_string(
MODEL_DIR = flags.DEFINE_string(
'model_dir',
DEFAULT_MODEL_DIR.as_posix(),
_DEFAULT_MODEL_DIR.as_posix(),
'Path to the model to use for inference.',
)
......@@ -137,9 +137,9 @@ _HMMBUILD_BINARY_PATH = flags.DEFINE_string(
)
# Database paths.
_DB_DIR = flags.DEFINE_multi_string(
DB_DIR = flags.DEFINE_multi_string(
'db_dir',
(DEFAULT_DB_DIR.as_posix(),),
(_DEFAULT_DB_DIR.as_posix(),),
'Path to the directory containing the databases. Can be specified multiple'
' times to search multiple directories in order.',
)
......@@ -635,7 +635,7 @@ def main(_):
print('\n'.join(notice))
if _RUN_DATA_PIPELINE.value:
expand_path = lambda x: replace_db_dir(x, _DB_DIR.value)
expand_path = lambda x: replace_db_dir(x, DB_DIR.value)
data_pipeline_config = pipeline.DataPipelineConfig(
jackhmmer_binary_path=_JACKHMMER_BINARY_PATH.value,
nhmmer_binary_path=_NHMMER_BINARY_PATH.value,
......@@ -673,7 +673,7 @@ def main(_):
)
),
device=devices[0],
model_dir=pathlib.Path(_MODEL_DIR.value),
model_dir=pathlib.Path(MODEL_DIR.value),
)
else:
print('Skipping running model inference.')
......
......@@ -184,7 +184,7 @@ class InferenceTest(test_utils.StructureTestCase):
model_class=run_alphafold.diffusion_model.Diffuser,
config=run_alphafold.make_model_config(),
device=jax.local_devices()[0],
model_dir=run_alphafold.DEFAULT_MODEL_DIR,
model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
)
def compare_golden(self, result_path: str) -> None:
......@@ -301,7 +301,7 @@ class InferenceTest(test_utils.StructureTestCase):
model_class=diffusion_model.Diffuser,
config=run_alphafold.make_model_config(),
device=jax.local_devices(backend='gpu')[0],
model_dir=pathlib.Path(run_alphafold.DEFAULT_MODEL_DIR),
model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
),
output_dir='unused output dir',
)
......@@ -332,7 +332,7 @@ class InferenceTest(test_utils.StructureTestCase):
model_class=diffusion_model.Diffuser,
config=run_alphafold.make_model_config(),
device=jax.local_devices(backend='gpu')[0],
model_dir=pathlib.Path(run_alphafold.DEFAULT_MODEL_DIR),
model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
),
output_dir=output_dir,
buckets=None if bucket is None else [bucket],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment