From 025701f02fe745bd21df4f5f8196e2a6b621b85d Mon Sep 17 00:00:00 2001
From: Augustin Zidek <augustinzidek@google.com>
Date: Wed, 20 Nov 2024 14:15:10 +0000
Subject: [PATCH] Use model dir flag for run_alphafold_test

PiperOrigin-RevId: 698034193
---
 run_alphafold.py      | 16 ++++++++--------
 run_alphafold_test.py |  6 +++---
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/run_alphafold.py b/run_alphafold.py
index 4480afc..0c0230a 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -56,8 +56,8 @@ import numpy as np
 
 
 _HOME_DIR = pathlib.Path(os.environ.get('HOME'))
-DEFAULT_MODEL_DIR = _HOME_DIR / 'models'
-DEFAULT_DB_DIR = _HOME_DIR / 'public_databases'
+_DEFAULT_MODEL_DIR = _HOME_DIR / 'models'
+_DEFAULT_DB_DIR = _HOME_DIR / 'public_databases'
 
 
 # Input and output paths.
@@ -77,9 +77,9 @@ _OUTPUT_DIR = flags.DEFINE_string(
     'Path to a directory where the results will be saved.',
 )
 
-_MODEL_DIR = flags.DEFINE_string(
+MODEL_DIR = flags.DEFINE_string(
     'model_dir',
-    DEFAULT_MODEL_DIR.as_posix(),
+    _DEFAULT_MODEL_DIR.as_posix(),
     'Path to the model to use for inference.',
 )
 
@@ -137,9 +137,9 @@ _HMMBUILD_BINARY_PATH = flags.DEFINE_string(
 )
 
 # Database paths.
-_DB_DIR = flags.DEFINE_multi_string(
+DB_DIR = flags.DEFINE_multi_string(
     'db_dir',
-    (DEFAULT_DB_DIR.as_posix(),),
+    (_DEFAULT_DB_DIR.as_posix(),),
     'Path to the directory containing the databases. Can be specified multiple'
     ' times to search multiple directories in order.',
 )
@@ -635,7 +635,7 @@ def main(_):
   print('\n'.join(notice))
 
   if _RUN_DATA_PIPELINE.value:
-    expand_path = lambda x: replace_db_dir(x, _DB_DIR.value)
+    expand_path = lambda x: replace_db_dir(x, DB_DIR.value)
     data_pipeline_config = pipeline.DataPipelineConfig(
         jackhmmer_binary_path=_JACKHMMER_BINARY_PATH.value,
         nhmmer_binary_path=_NHMMER_BINARY_PATH.value,
@@ -673,7 +673,7 @@ def main(_):
             )
         ),
         device=devices[0],
-        model_dir=pathlib.Path(_MODEL_DIR.value),
+        model_dir=pathlib.Path(MODEL_DIR.value),
     )
   else:
     print('Skipping running model inference.')
diff --git a/run_alphafold_test.py b/run_alphafold_test.py
index b279b9e..28f376a 100644
--- a/run_alphafold_test.py
+++ b/run_alphafold_test.py
@@ -184,7 +184,7 @@ class InferenceTest(test_utils.StructureTestCase):
         model_class=run_alphafold.diffusion_model.Diffuser,
         config=run_alphafold.make_model_config(),
         device=jax.local_devices()[0],
-        model_dir=run_alphafold.DEFAULT_MODEL_DIR,
+        model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
     )
 
   def compare_golden(self, result_path: str) -> None:
@@ -301,7 +301,7 @@ class InferenceTest(test_utils.StructureTestCase):
               model_class=diffusion_model.Diffuser,
               config=run_alphafold.make_model_config(),
               device=jax.local_devices(backend='gpu')[0],
-              model_dir=pathlib.Path(run_alphafold.DEFAULT_MODEL_DIR),
+              model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
           ),
           output_dir='unused output dir',
       )
@@ -332,7 +332,7 @@ class InferenceTest(test_utils.StructureTestCase):
             model_class=diffusion_model.Diffuser,
             config=run_alphafold.make_model_config(),
             device=jax.local_devices(backend='gpu')[0],
-            model_dir=pathlib.Path(run_alphafold.DEFAULT_MODEL_DIR),
+            model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
         ),
         output_dir=output_dir,
         buckets=None if bucket is None else [bucket],
-- 
GitLab