Commit 35996125 authored by Augustin Zidek

Improve error message for flash attention implementation not supported

PiperOrigin-RevId: 697937778
parent 7e3067b8
1 merge request: !1 Cloned AlphaFold 3 repo into this one
@@ -125,7 +125,8 @@ def dot_product_attention(
   if implementation == "triton":
     if not triton_utils.has_triton_support():
       raise ValueError(
-          "implementation='triton' is unsupported on this GPU generation."
+          "implementation='triton' for FlashAttention is unsupported on this"
+          " GPU generation. Please use implementation='xla' instead."
       )
     return attention_triton.TritonFlashAttention()(*args, **kwargs)
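
The improved message tells callers what to do when the Triton FlashAttention kernel is unavailable: fall back to implementation='xla'. Below is a minimal caller-side sketch of that fallback, assuming a dot_product_attention(query, key, value, ..., implementation=...) wrapper like the one being diffed here; the import path, array shapes, and exact signature are illustrative assumptions, not taken from this commit.

    # Sketch: try the Triton FlashAttention kernel, fall back to the XLA
    # implementation on GPU generations without Triton support, as the new
    # error message suggests. Import path and shapes are assumptions.
    import jax.numpy as jnp

    from alphafold3.jax.attention import attention  # assumed module path

    # Illustrative [batch, seq_len, num_heads, head_dim] inputs.
    q = k = v = jnp.zeros((1, 16, 8, 64), dtype=jnp.bfloat16)

    try:
      out = attention.dot_product_attention(q, k, v, implementation="triton")
    except ValueError:
      # Raised when implementation='triton' is unsupported on this GPU
      # generation; retry with the portable XLA path instead.
      out = attention.dot_product_attention(q, k, v, implementation="xla")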