From 35996125a0f0948ff8aa24d4ac05aba1a8dbf982 Mon Sep 17 00:00:00 2001
From: Augustin Zidek <augustinzidek@google.com>
Date: Tue, 19 Nov 2024 11:45:38 +0000
Subject: [PATCH] Improve error message for flash attention implementation not
 supported

PiperOrigin-RevId: 697937778
---
 src/alphafold3/jax/attention/attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/alphafold3/jax/attention/attention.py b/src/alphafold3/jax/attention/attention.py
index 06c421b..2632003 100644
--- a/src/alphafold3/jax/attention/attention.py
+++ b/src/alphafold3/jax/attention/attention.py
@@ -125,7 +125,8 @@ def dot_product_attention(
   if implementation == "triton":
     if not triton_utils.has_triton_support():
       raise ValueError(
-          "implementation='triton' is unsupported on this GPU generation."
+          "implementation='triton' for FlashAttention is unsupported on this"
+          " GPU generation. Please use implementation='xla' instead."
       )
     return attention_triton.TritonFlashAttention()(*args, **kwargs)
 
-- 
GitLab