From b618d2d57632b84c971ed1aae12d091fbcc7c275 Mon Sep 17 00:00:00 2001
From: Augustin Zidek <augustinzidek@google.com>
Date: Mon, 18 Nov 2024 16:19:04 +0000
Subject: [PATCH] Fix type annotation

PiperOrigin-RevId: 697633892
---
 src/alphafold3/model/diffusion/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/alphafold3/model/diffusion/model.py b/src/alphafold3/model/diffusion/model.py
index 157dcc3..b6ab149 100644
--- a/src/alphafold3/model/diffusion/model.py
+++ b/src/alphafold3/model/diffusion/model.py
@@ -232,7 +232,7 @@ class Diffuser(hk.Module):
     return sample
 
   def __call__(
-      self, batch: features.BatchDict, key: bool = None
+      self, batch: features.BatchDict, key: jax.Array | None = None
   ) -> base_model.ModelResult:
     if key is None:
       key = hk.next_rng_key()
-- 
GitLab