diff --git a/src/alphafold3/model/diffusion/model.py b/src/alphafold3/model/diffusion/model.py
index 157dcc3959339b317716f5e0acbebd4c2fce4b31..b6ab1495c13863e0b52a08cf0e8cf449cb0dcbf7 100644
--- a/src/alphafold3/model/diffusion/model.py
+++ b/src/alphafold3/model/diffusion/model.py
@@ -232,7 +232,7 @@ class Diffuser(hk.Module):
     return sample
 
   def __call__(
-      self, batch: features.BatchDict, key: bool = None
+      self, batch: features.BatchDict, key: jax.Array | None = None
   ) -> base_model.ModelResult:
     if key is None:
       key = hk.next_rng_key()