From 35996125a0f0948ff8aa24d4ac05aba1a8dbf982 Mon Sep 17 00:00:00 2001
From: Augustin Zidek <augustinzidek@google.com>
Date: Tue, 19 Nov 2024 11:45:38 +0000
Subject: [PATCH] Improve error message for unsupported flash attention
 implementation

PiperOrigin-RevId: 697937778
---
 src/alphafold3/jax/attention/attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/alphafold3/jax/attention/attention.py b/src/alphafold3/jax/attention/attention.py
index 06c421b..2632003 100644
--- a/src/alphafold3/jax/attention/attention.py
+++ b/src/alphafold3/jax/attention/attention.py
@@ -125,7 +125,8 @@ def dot_product_attention(
   if implementation == "triton":
     if not triton_utils.has_triton_support():
       raise ValueError(
-          "implementation='triton' is unsupported on this GPU generation."
+          "implementation='triton' for FlashAttention is unsupported on this"
+          " GPU generation. Please use implementation='xla' instead."
       )
     return attention_triton.TritonFlashAttention()(*args, **kwargs)
 
--
GitLab
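
Note: for context, below is a minimal caller-side sketch of the fallback the
new error message recommends. It assumes dot_product_attention accepts
query/key/value arrays followed by an `implementation` keyword taking
"triton" or "xla" (the patched code forwards *args/**kwargs to the Triton
backend); the import path mirrors the patched file, and the array shapes are
illustrative only, not taken from the patch.

  import jax
  import jax.numpy as jnp

  from alphafold3.jax.attention import attention

  # Hypothetical [batch, seq_len, num_heads, head_dim] inputs.
  rng = jax.random.PRNGKey(0)
  q = jax.random.normal(rng, (1, 128, 4, 64), dtype=jnp.bfloat16)
  k = jax.random.normal(rng, (1, 128, 4, 64), dtype=jnp.bfloat16)
  v = jax.random.normal(rng, (1, 128, 4, 64), dtype=jnp.bfloat16)

  try:
    # Prefer the Triton FlashAttention kernel where the GPU supports it.
    out = attention.dot_product_attention(q, k, v, implementation="triton")
  except ValueError:
    # Older GPU generations lack Triton support; fall back to the pure-XLA
    # path, as the improved error message suggests.
    out = attention.dot_product_attention(q, k, v, implementation="xla")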