diff --git a/src/alphafold3/model/diffusion/model.py b/src/alphafold3/model/diffusion/model.py index 157dcc3959339b317716f5e0acbebd4c2fce4b31..b6ab1495c13863e0b52a08cf0e8cf449cb0dcbf7 100644 --- a/src/alphafold3/model/diffusion/model.py +++ b/src/alphafold3/model/diffusion/model.py @@ -232,7 +232,7 @@ class Diffuser(hk.Module): return sample def __call__( - self, batch: features.BatchDict, key: bool = None + self, batch: features.BatchDict, key: jax.Array | None = None ) -> base_model.ModelResult: if key is None: key = hk.next_rng_key()