diff --git a/src/alphafold3/jax/attention/attention.py b/src/alphafold3/jax/attention/attention.py
index 06c421b1c0577abcceaeabf93840b7b81bca3c0c..26320039000aa12c9aae34e8d8460334c7487642 100644
--- a/src/alphafold3/jax/attention/attention.py
+++ b/src/alphafold3/jax/attention/attention.py
@@ -125,7 +125,8 @@ def dot_product_attention(
   if implementation == "triton":
     if not triton_utils.has_triton_support():
       raise ValueError(
-          "implementation='triton' is unsupported on this GPU generation."
+          "implementation='triton' for FlashAttention is unsupported on this"
+          " GPU generation. Please use implementation='xla' instead."
       )
     return attention_triton.TritonFlashAttention()(*args, **kwargs)
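
A minimal usage sketch of the behavior this change surfaces: a caller that requests the Triton FlashAttention kernel and, on a GPU generation without Triton support, falls back to the XLA implementation that the new error message recommends. The tensor shapes, dtype, and the alphafold3.jax.attention.attention import path are assumptions for illustration and are not part of this diff.

# Hedged sketch, not part of the change itself.
import jax.numpy as jnp

from alphafold3.jax.attention import attention

# Assumed layout: [batch, seq_len, num_heads, head_dim].
q = k = v = jnp.zeros((1, 16, 4, 32), dtype=jnp.bfloat16)

try:
  out = attention.dot_product_attention(q, k, v, implementation="triton")
except ValueError:
  # On hardware without Triton support, follow the error message's advice
  # and request the XLA implementation explicitly.
  out = attention.dot_product_attention(q, k, v, implementation="xla")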