diff --git a/docs/known_issues.md b/docs/known_issues.md
index 23535a69e75f4ecd8f5b8dd1ad6efe7c00dac179..2f2372aaa0cf2d1f444475890d556f87fce01d6b 100644
--- a/docs/known_issues.md
+++ b/docs/known_issues.md
@@ -1,16 +1 @@
 # Known Issues
-
-## Numerical Accuracy above 5,120 Tokens
-
-AlphaFold 3 does not currently support inference on inputs larger than 5,120
-tokens. An error will be raised if the input is larger than this threshold.
-
-This is due to a numerical issue with the custom Pallas kernel implementing the
-Gated Linear Unit. The numerical issue only occurs at inputs above the 5,120
-tokens threshold, and results in degraded accuracy in the predicted structure.
-
-This numerical issue is unique to the single GPU configuration used in this
-repository, and does not affect the results in the
-[AlphaFold 3 paper](https://www.nature.com/articles/s41586-024-07487-w).
-
-We hope to resolve this issue soon and remove this check on input size.
diff --git a/docs/performance.md b/docs/performance.md
index c656c03555c92e361dcdb879f57728336907cdf8..935de538554ef8c70450f7567aae7b3bfbb3346a 100644
--- a/docs/performance.md
+++ b/docs/performance.md
@@ -94,6 +94,41 @@ V100 using the flag `--flash_attention_implementation=xla` in
 `run_alphafold.py`, this configuration has not been tested for numerical
 accuracy or throughput efficiency, so please proceed with caution.
 
+## Compilation Buckets
+
+To avoid excessive re-compilation of the model, AlphaFold 3 implements
+compilation buckets: ranges of input sizes that share a single compilation of
+the model.
+
+When featurising an input, AlphaFold 3 finds the smallest bucket the input fits
+into, then pads the input up to that bucket size. This avoids re-compiling the
+model when an input belongs to the same bucket as a previously processed input.
+
+The configuration of bucket sizes involves a trade-off: more buckets mean less
+padding, but also more re-compilations of the model.
+
+By default, the largest bucket size is 5,120 tokens. Processing an input larger
+than the largest bucket size triggers the creation of a new bucket for exactly
+that input size, and a re-compilation of the model. In this case, you may wish
+to redefine the compilation bucket sizes via the `--buckets` flag in
+`run_alphafold.py` to add larger bucket sizes. For example, suppose you are
+running inference on inputs with token sizes `5132, 5280, 5342`. Using the
+default bucket sizes configured in `run_alphafold.py` will trigger three
+separate model compilations, one for each unique token size. If you instead
+pass the following flag to `run_alphafold.py`
+
+```
+--buckets 256,512,768,1024,1280,1536,2048,2560,3072,3584,4096,4608,5120,5376
+```
+
+when running inference on the above three input sizes, the model will be
+compiled only once, for the bucket size `5376`. **Note:** for this specific
+example with input sizes `5132, 5280, 5342`, passing `--buckets 5376` would be
+sufficient. The longer list of buckets illustrates a more general configuration
+suitable for inputs of diverse sizes.
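+
+Bucket selection itself is straightforward. The sketch below mirrors the
+behaviour of `calculate_bucket_size` in
+`src/alphafold3/model/pipeline/pipeline.py`; the helper name `pick_bucket` and
+the example sizes are illustrative:
+
+```python
+import bisect
+
+
+def pick_bucket(num_tokens: int, buckets: tuple[int, ...]) -> int:
+  """Returns the smallest bucket that fits, or the exact input size."""
+  # Index of the leftmost bucket with size >= num_tokens.
+  bucket_idx = bisect.bisect_left(buckets, num_tokens)
+  if bucket_idx == len(buckets):
+    # Larger than every configured bucket: a new bucket is created for
+    # exactly this input size, triggering a re-compilation.
+    return num_tokens
+  return buckets[bucket_idx]
+
+
+assert pick_bucket(5132, (4608, 5120, 5376)) == 5376  # Padded to 5376.
+assert pick_bucket(6000, (4608, 5120, 5376)) == 6000  # New bucket, re-compile.
+```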
+
 ## Additional Flags
 
 ### Compilation Time Workaround with XLA Flags
@@ -109,8 +144,8 @@ ENV XLA_FLAGS="--xla_gpu_enable_triton_gemm=false"
 ### GPU Memory
 
 The following environment variables (set by default in the `Dockerfile`) enable
-folding a single input of size up to 5,120 tokens on a single A100 with 80 GB of
-memory:
+folding a single input of size up to 5,120 tokens on a single A100 (80 GB) or a
+single H100 (80 GB):
 
 ```sh
 ENV XLA_PYTHON_CLIENT_PREALLOCATE=true
@@ -119,12 +154,12 @@ ENV XLA_CLIENT_MEM_FRACTION=0.95
 
 #### Unified Memory
 
-If you would like to run AlphaFold 3 on a GPU with less memory (an A100 with 40
-GB of memory, for instance), we recommend enabling unified memory. Enabling
-unified memory allows the program to spill GPU memory to host memory if there
-isn't enough space. This prevents an OOM, at the cost of making the program
-slower by accessing host memory instead of device memory. To learn more, check
-out the
+If you would like to run AlphaFold 3 on inputs larger than 5,120 tokens, or on
+a GPU with less memory (an A100 with 40 GB of memory, for instance), we
+recommend enabling unified memory. Enabling unified memory allows the program
+to spill GPU memory to host memory if there isn't enough space. This prevents
+an out-of-memory error, at the cost of making the program slower by accessing
+host memory instead of device memory. To learn more, check out the
 [NVIDIA blog post](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/).
 
 You can enable unified memory by setting the following environment variables in
diff --git a/run_alphafold.py b/run_alphafold.py
index 7dd4710dac40832242239728e04db84728f9ee0e..69a1a988d2fee5ae12861d685c9cd5af28c7bf72 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -31,7 +31,7 @@ import string
 import textwrap
 import time
 import typing
-from typing import Final, Protocol, Self, TypeVar, overload
+from typing import Protocol, Self, TypeVar, overload
 
 from absl import app
 from absl import flags
@@ -203,27 +203,23 @@ _NHMMER_N_CPU = flags.DEFINE_integer(
     ' beyond 8 CPUs provides very little additional speedup.',
 )
 
-# Compilation cache
+# Compilation cache.
 _JAX_COMPILATION_CACHE_DIR = flags.DEFINE_string(
     'jax_compilation_cache_dir',
     None,
     'Path to a directory for the JAX compilation cache.',
 )
 
-_BUCKETS: Final[tuple[int, ...]] = (
-    256,
-    512,
-    768,
-    1024,
-    1280,
-    1536,
-    2048,
-    2560,
-    3072,
-    3584,
-    4096,
-    4608,
-    5120,
+# Compilation buckets.
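+# Inputs are padded up to the smallest bucket size that fits them, so inputs
+# falling into the same bucket share one compilation of the model (see
+# calculate_bucket_size in src/alphafold3/model/pipeline/pipeline.py).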
+_BUCKETS = flags.DEFINE_list(
+    'buckets',
+    # pyformat: disable
+    ['256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072',
+     '3584', '4096', '4608', '5120'],
+    # pyformat: enable
+    'Strictly increasing order of token sizes for which to cache compilations.'
+    ' For any input with more tokens than the largest bucket size, a new bucket'
+    ' is created for exactly that number of tokens.',
 )
 
 
@@ -665,7 +661,7 @@ def main(_):
       data_pipeline_config=data_pipeline_config,
       model_runner=model_runner,
       output_dir=os.path.join(_OUTPUT_DIR.value, fold_input.sanitised_name()),
-      buckets=_BUCKETS,
+      buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
     )
 
   print(f'Done processing {len(fold_inputs)} fold inputs.')
diff --git a/src/alphafold3/jax/common/array_view.py b/src/alphafold3/jax/common/array_view.py
index 440db0cccf898694abc1c236f51e89ec7e60239a..c37c299cc080635bbba2fd3603c8130cce6107dc 100644
--- a/src/alphafold3/jax/common/array_view.py
+++ b/src/alphafold3/jax/common/array_view.py
@@ -18,6 +18,7 @@ from types import EllipsisType  # pylint: disable=g-importing-member
 from typing import Any, Self, TypeAlias, TypeVar
 
 import jax
+import jax.experimental
 from jax.experimental import pallas as pl
 import jax.numpy as jnp
 from jax.typing import ArrayLike  # pylint: disable=g-importing-member
@@ -91,11 +92,17 @@ class ArrayView:
   def T(self) -> Self:  # pylint: disable=invalid-name
     return self.transpose()
 
+  @property
+  def _index_dtype(self) -> jax.typing.DTypeLike:
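+    # Offsets into `base` are computed in int32 while every element is
+    # addressable with it, and in int64 for larger buffers where int32
+    # arithmetic would overflow.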
+    i32_max = jnp.iinfo(jnp.int32).max
+    return jnp.int32 if (self.base.size <= i32_max) else jnp.int64
+
   @property
   def offsets(self) -> jax.Array:
     """Returns array of offsets into `base` for each element."""
-    idxs = jnp.indices(self.shape, sparse=True)
-    return self.offset + sum(s * idx for s, idx in zip(self.strides, idxs))
+    with jax.experimental.enable_x64():
+      idxs = jnp.indices(self.shape, sparse=True, dtype=self._index_dtype)
+      return self.offset + sum(s * idx for s, idx in zip(self.strides, idxs))
 
   def astype(self, dtype: jax.typing.DTypeLike) -> Self:
     return self._replace(base=self.base.astype(dtype))
@@ -255,29 +262,34 @@ class ArrayView:
 
     shape = []
     strides = []
-    offset = self.offset
-
-    for idx, dim, stride in zip(idxs, self.shape, self.strides, strict=True):
-      if isinstance(idx, int):
-        if not (-dim <= idx < dim):
-          raise ValueError("Slice index out of range.")
-        offset += stride * (idx % dim)
-      elif isinstance(idx, ScalarInt):
-        offset += stride * idx
-      elif isinstance(idx, slice):
-        start, stop, step = idx.indices(dim)
-        if step >= 0:
-          shape.append(pl.cdiv(stop - start, step))
+    with jax.experimental.enable_x64():
+
+      def as_index(x):
+        return x.astype(self._index_dtype) if isinstance(x, jax.Array) else x
+
+      offset = as_index(self.offset)
+
+      for idx, dim, stride in zip(idxs, self.shape, self.strides, strict=True):
+        if isinstance(idx, int):
+          if not (-dim <= idx < dim):
+            raise ValueError("Slice index out of range.")
+          offset += stride * (idx % dim)
+        elif isinstance(idx, ScalarInt):
+          offset += stride * as_index(idx)
+        elif isinstance(idx, slice):
+          start, stop, step = idx.indices(dim)
+          if step >= 0:
+            shape.append(pl.cdiv(stop - start, step))
+          else:
+            shape.append(pl.cdiv(start - stop, -step))
+          strides.append(stride * step)
+          offset += stride * start
+        elif isinstance(idx, pl.Slice):
+          shape.append(idx.size)
+          strides.append(stride * idx.stride)
+          offset += stride * as_index(idx.start)
         else:
-          shape.append(pl.cdiv(start - stop, -step))
-        strides.append(stride * step)
-        offset += stride * start
-      elif isinstance(idx, pl.Slice):
-        shape.append(idx.size)
-        strides.append(stride * idx.stride)
-        offset += stride * idx.start
-      else:
-        raise ValueError(f"Unexpected indexer: {idx}")
+          raise ValueError(f"Unexpected indexer: {idx}")
 
     return self._replace(shape=shape, strides=strides, offset=offset)
diff --git a/src/alphafold3/jax/gated_linear_unit/block.py b/src/alphafold3/jax/gated_linear_unit/block.py
index 56406d5d47f98c13ba9b279e3c132c95b3082c20..43c7e79e7e160d9a0e070b73080c509daf063389 100644
--- a/src/alphafold3/jax/gated_linear_unit/block.py
+++ b/src/alphafold3/jax/gated_linear_unit/block.py
@@ -15,6 +15,7 @@ from typing import Any, TypeAlias
 
 from alphafold3.jax.common import array_view
 import jax
+import jax.experimental
 from jax.experimental import pallas as pl
 import jax.numpy as jnp
 import jaxtyping
@@ -43,7 +44,8 @@ def load_block(
     idx = ref[idx].offsets
     ref = ref.base
   other = None if mask is None else other
-  return pl.load(ref, idx, mask=mask, other=other, **kwargs)
+  with jax.experimental.enable_x64():
+    return pl.load(ref, idx, mask=mask, other=other, **kwargs)
 
 
 @jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
@@ -62,7 +64,8 @@ def store_block(
   if isinstance(ref, array_view.ArrayView):
     idx = ref[idx].offsets
     ref = ref.base
-  pl.store(ref, idx, val.astype(ref.dtype), mask=mask, **kwargs)
+  with jax.experimental.enable_x64():
+    pl.store(ref, idx, val.astype(ref.dtype), mask=mask, **kwargs)
 
 
 def in_bounds_mask(
diff --git a/src/alphafold3/model/pipeline/pipeline.py b/src/alphafold3/model/pipeline/pipeline.py
index fd19650c56cc1168fe2f9b1c65b14cc0e3dd0b37..539968a78ca19a254cb1f087b0ca6b03912578e0 100644
--- a/src/alphafold3/model/pipeline/pipeline.py
+++ b/src/alphafold3/model/pipeline/pipeline.py
@@ -48,10 +48,15 @@ def calculate_bucket_size(
   bucket_idx = bisect.bisect_left(buckets, num_tokens)
 
   if bucket_idx == len(buckets):
-    raise ValueError(
-        f'Number of tokens {num_tokens} is more than the largest currently'
-        f' supported bucket size {buckets[-1]}.'
+    logging.warning(
+        'Creating a new bucket of size %d since the input has more tokens than'
+        ' the largest bucket size %d. This may trigger a re-compilation of the'
+        ' model. Consider adding larger bucket sizes to avoid excessive'
+        ' re-compilation.',
+        num_tokens,
+        buckets[-1],
     )
+    return num_tokens
 
   return buckets[bucket_idx]
 
@@ -250,9 +255,19 @@
           f'({total_tokens} < {self._config.min_total_residues})'
       )
 
+    logging.info(
+        'Calculating bucket size for input with %d tokens.', total_tokens
+    )
     padded_token_length = calculate_bucket_size(
         total_tokens, self._config.buckets
     )
+    logging.info(
+        'Got bucket size %d for input with %d tokens, resulting in %d tokens'
+        ' of padding.',
+        padded_token_length,
+        total_tokens,
+        padded_token_length - total_tokens,
+    )
 
     # Padding shapes for all features.
     num_atoms = padded_token_length * self._config.average_num_atoms_per_token