diff --git a/docs/known_issues.md b/docs/known_issues.md
index 23535a69e75f4ecd8f5b8dd1ad6efe7c00dac179..2f2372aaa0cf2d1f444475890d556f87fce01d6b 100644
--- a/docs/known_issues.md
+++ b/docs/known_issues.md
@@ -1,16 +1 @@
 # Known Issues
-
-## Numerical Accuracy above 5,120 Tokens
-
-AlphaFold 3 does not currently support inference on inputs larger than 5,120
-tokens. An error will be raised if the input is larger than this threshold.
-
-This is due to a numerical issue with the custom Pallas kernel implementing the
-Gated Linear Unit. The numerical issue only occurs at inputs above the 5,120
-tokens threshold, and results in degraded accuracy in the predicted structure.
-
-This numerical issue is unique to the single GPU configuration used in this
-repository, and does not affect the results in the
-[AlphaFold 3 paper](https://www.nature.com/articles/s41586-024-07487-w).
-
-We hope to resolve this issue soon and remove this check on input size.
diff --git a/docs/performance.md b/docs/performance.md
index c656c03555c92e361dcdb879f57728336907cdf8..935de538554ef8c70450f7567aae7b3bfbb3346a 100644
--- a/docs/performance.md
+++ b/docs/performance.md
@@ -94,6 +94,41 @@ V100 using the flag `--flash_attention_implementation=xla` in
 `run_alphafold.py`, this configuration has not been tested for numerical
 accuracy or throughput efficiency, so please proceed with caution.
 
+## Compilation Buckets
+
+To avoid excessive re-compilation of the model, AlphaFold 3 implements
+compilation buckets: ranges of input sizes that share a single compilation of
+the model.
+
+When featurising an input, AlphaFold 3 determines the smallest bucket the input
+fits into, then pads the input up to that bucket size. If a previously
+processed input fell into the same bucket, the model does not need to be
+re-compiled when running inference on the new input.
+
+The configuration of bucket sizes involves a trade-off: more buckets lead to
+more re-compilations of the model, but less wasted computation on padding.
+
+By default, the largest bucket size is 5,120 tokens. Processing inputs larger
+than this maximum bucket size triggers the creation of a new bucket for exactly
+that input size, and a re-compilation of the model. In this case, you may wish
+to redefine the compilation bucket sizes via the `--buckets` flag in
+`run_alphafold.py` to add larger bucket sizes. For example, suppose
+you are running inference on inputs with token sizes: `5132, 5280, 5342`. Using
+the default bucket sizes configured in `run_alphafold.py` will trigger three
+separate model compilations, one for each unique token size. If instead you pass
+in the following flag to `run_alphafold.py`
+
+```
+--buckets 256,512,768,1024,1280,1536,2048,2560,3072,3584,4096,4608,5120,5376
+```
+
+when running inference on the above three input sizes, the model will be
+compiled only once for the bucket size `5376`. **Note:** for this specific
+example with input sizes `5132, 5280, 5342`, passing in `--buckets 5376` is
+sufficient to achieve the desired compilation behaviour. The provided example
+with multiple buckets illustrates a more general solution suitable for diverse
+input sizes.
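+
+For reference, the bucket selection itself is straightforward. The sketch below
+(a hypothetical standalone re-implementation, not the library code itself)
+mirrors the logic of `calculate_bucket_size` in
+`src/alphafold3/model/pipeline/pipeline.py`:
+
+```python
+import bisect
+
+def pick_bucket(num_tokens: int, buckets: tuple[int, ...]) -> int:
+  """Returns the smallest bucket >= num_tokens, else num_tokens itself."""
+  bucket_idx = bisect.bisect_left(buckets, num_tokens)
+  if bucket_idx == len(buckets):
+    # Larger than every configured bucket: a new bucket of exactly this size
+    # is created, triggering a re-compilation of the model.
+    return num_tokens
+  return buckets[bucket_idx]
+
+print(pick_bucket(5132, (5120, 5376)))  # 5376: padded by 244 tokens.
+print(pick_bucket(6000, (5120, 5376)))  # 6000: new bucket, re-compilation.
+```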
+
 ## Additional Flags
 
 ### Compilation Time Workaround with XLA Flags
@@ -109,8 +144,8 @@ ENV XLA_FLAGS="--xla_gpu_enable_triton_gemm=false"
 ### GPU Memory
 
 The following environment variables (set by default in the `Dockerfile`) enable
-folding a single input of size up to 5,120 tokens on a single A100 with 80 GB of
-memory:
+folding a single input of size up to 5,120 tokens on a single A100 (80 GB) or a
+single H100 (80 GB):
 
 ```sh
 ENV XLA_PYTHON_CLIENT_PREALLOCATE=true
@@ -119,12 +154,12 @@ ENV XLA_CLIENT_MEM_FRACTION=0.95
 
 #### Unified Memory
 
-If you would like to run AlphaFold 3 on a GPU with less memory (an A100 with 40
-GB of memory, for instance), we recommend enabling unified memory. Enabling
-unified memory allows the program to spill GPU memory to host memory if there
-isn't enough space. This prevents an OOM, at the cost of making the program
-slower by accessing host memory instead of device memory. To learn more, check
-out the
+If you would like to run AlphaFold 3 on inputs larger than 5,120 tokens, or on
+a GPU with less memory (an A100 with 40 GB of memory, for instance), we
+recommend enabling unified memory, which allows the program to spill GPU memory
+to host memory if there isn't enough space. This prevents an out-of-memory
+error, at the cost of making the program slower by accessing host memory
+instead of device memory. To learn more, check out the
 [NVIDIA blog post](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/).
 
 You can enable unified memory by setting the following environment variables in
diff --git a/run_alphafold.py b/run_alphafold.py
index 7dd4710dac40832242239728e04db84728f9ee0e..69a1a988d2fee5ae12861d685c9cd5af28c7bf72 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -31,7 +31,7 @@ import string
 import textwrap
 import time
 import typing
-from typing import Final, Protocol, Self, TypeVar, overload
+from typing import Protocol, Self, TypeVar, overload
 
 from absl import app
 from absl import flags
@@ -203,27 +203,23 @@ _NHMMER_N_CPU = flags.DEFINE_integer(
     ' beyond 8 CPUs provides very little additional speedup.',
 )
 
-# Compilation cache
+# Compilation cache.
 _JAX_COMPILATION_CACHE_DIR = flags.DEFINE_string(
     'jax_compilation_cache_dir',
     None,
     'Path to a directory for the JAX compilation cache.',
 )
 
-_BUCKETS: Final[tuple[int, ...]] = (
-    256,
-    512,
-    768,
-    1024,
-    1280,
-    1536,
-    2048,
-    2560,
-    3072,
-    3584,
-    4096,
-    4608,
-    5120,
+# Compilation buckets.
+_BUCKETS = flags.DEFINE_list(
+    'buckets',
+    # pyformat: disable
+    ['256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072',
+     '3584', '4096', '4608', '5120'],
+    # pyformat: enable
+    'Strictly increasing order of token sizes for which to cache compilations.'
+    ' For any input with more tokens than the largest bucket size, a new bucket'
+    ' is created for exactly that number of tokens.',
 )
 
 
@@ -665,7 +661,7 @@ def main(_):
         data_pipeline_config=data_pipeline_config,
         model_runner=model_runner,
         output_dir=os.path.join(_OUTPUT_DIR.value, fold_input.sanitised_name()),
-        buckets=_BUCKETS,
+        buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
     )
 
   print(f'Done processing {len(fold_inputs)} fold inputs.')
diff --git a/src/alphafold3/jax/common/array_view.py b/src/alphafold3/jax/common/array_view.py
index 440db0cccf898694abc1c236f51e89ec7e60239a..c37c299cc080635bbba2fd3603c8130cce6107dc 100644
--- a/src/alphafold3/jax/common/array_view.py
+++ b/src/alphafold3/jax/common/array_view.py
@@ -18,6 +18,7 @@ from types import EllipsisType  # pylint: disable=g-importing-member
 from typing import Any, Self, TypeAlias, TypeVar
 
 import jax
+import jax.experimental
 from jax.experimental import pallas as pl
 import jax.numpy as jnp
 from jax.typing import ArrayLike  # pylint: disable=g-importing-member
@@ -91,11 +92,17 @@ class ArrayView:
   def T(self) -> Self:  # pylint: disable=invalid-name
     return self.transpose()
 
+  @property
+  def _index_dtype(self) -> jax.typing.DTypeLike:
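+    """Dtype for offset arithmetic: int64 if `base` is too large for int32."""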
+    i32_max = jnp.iinfo(jnp.int32).max
+    return jnp.int32 if (self.base.size <= i32_max) else jnp.int64
+
   @property
   def offsets(self) -> jax.Array:
     """Returns array of offsets into `base` for each element."""
-    idxs = jnp.indices(self.shape, sparse=True)
-    return self.offset + sum(s * idx for s, idx in zip(self.strides, idxs))
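+    # Compute offsets with x64 enabled so that int64 indices are not downcast
+    # to int32 for arrays too large to index with 32 bits.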
+    with jax.experimental.enable_x64():
+      idxs = jnp.indices(self.shape, sparse=True, dtype=self._index_dtype)
+      return self.offset + sum(s * idx for s, idx in zip(self.strides, idxs))
 
   def astype(self, dtype: jax.typing.DTypeLike) -> Self:
     return self._replace(base=self.base.astype(dtype))
@@ -255,29 +262,34 @@ class ArrayView:
 
     shape = []
     strides = []
-    offset = self.offset
-
-    for idx, dim, stride in zip(idxs, self.shape, self.strides, strict=True):
-      if isinstance(idx, int):
-        if not (-dim <= idx < dim):
-          raise ValueError("Slice index out of range.")
-        offset += stride * (idx % dim)
-      elif isinstance(idx, ScalarInt):
-        offset += stride * idx
-      elif isinstance(idx, slice):
-        start, stop, step = idx.indices(dim)
-        if step >= 0:
-          shape.append(pl.cdiv(stop - start, step))
+    with jax.experimental.enable_x64():
+
+      def as_index(x):
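+        # Cast traced array indices to the index dtype; Python ints pass
+        # through unchanged.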
+        return x.astype(self._index_dtype) if isinstance(x, jax.Array) else x
+
+      offset = as_index(self.offset)
+
+      for idx, dim, stride in zip(idxs, self.shape, self.strides, strict=True):
+        if isinstance(idx, int):
+          if not (-dim <= idx < dim):
+            raise ValueError("Slice index out of range.")
+          offset += stride * (idx % dim)
+        elif isinstance(idx, ScalarInt):
+          offset += stride * as_index(idx)
+        elif isinstance(idx, slice):
+          start, stop, step = idx.indices(dim)
+          if step >= 0:
+            shape.append(pl.cdiv(stop - start, step))
+          else:
+            shape.append(pl.cdiv(start - stop, -step))
+          strides.append(stride * step)
+          offset += stride * start
+        elif isinstance(idx, pl.Slice):
+          shape.append(idx.size)
+          strides.append(stride * idx.stride)
+          offset += stride * as_index(idx.start)
         else:
-          shape.append(pl.cdiv(start - stop, -step))
-        strides.append(stride * step)
-        offset += stride * start
-      elif isinstance(idx, pl.Slice):
-        shape.append(idx.size)
-        strides.append(stride * idx.stride)
-        offset += stride * idx.start
-      else:
-        raise ValueError(f"Unexpected indexer: {idx}")
+          raise ValueError(f"Unexpected indexer: {idx}")
 
     return self._replace(shape=shape, strides=strides, offset=offset)
 
diff --git a/src/alphafold3/jax/gated_linear_unit/block.py b/src/alphafold3/jax/gated_linear_unit/block.py
index 56406d5d47f98c13ba9b279e3c132c95b3082c20..43c7e79e7e160d9a0e070b73080c509daf063389 100644
--- a/src/alphafold3/jax/gated_linear_unit/block.py
+++ b/src/alphafold3/jax/gated_linear_unit/block.py
@@ -15,6 +15,7 @@ from typing import Any, TypeAlias
 
 from alphafold3.jax.common import array_view
 import jax
+import jax.experimental
 from jax.experimental import pallas as pl
 import jax.numpy as jnp
 import jaxtyping
@@ -43,7 +44,8 @@ def load_block(
     idx = ref[idx].offsets
     ref = ref.base
   other = None if mask is None else other
-  return pl.load(ref, idx, mask=mask, other=other, **kwargs)
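+  # Keep 64-bit index arithmetic enabled for the Pallas load so that large
+  # offsets computed via `ArrayView.offsets` are not downcast to int32.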
+  with jax.experimental.enable_x64():
+    return pl.load(ref, idx, mask=mask, other=other, **kwargs)
 
 
 @jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
@@ -62,7 +64,8 @@ def store_block(
   if isinstance(ref, array_view.ArrayView):
     idx = ref[idx].offsets
     ref = ref.base
-  pl.store(ref, idx, val.astype(ref.dtype), mask=mask, **kwargs)
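+  # As in `load_block`, keep 64-bit index arithmetic for the Pallas store.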
+  with jax.experimental.enable_x64():
+    pl.store(ref, idx, val.astype(ref.dtype), mask=mask, **kwargs)
 
 
 def in_bounds_mask(
diff --git a/src/alphafold3/model/pipeline/pipeline.py b/src/alphafold3/model/pipeline/pipeline.py
index fd19650c56cc1168fe2f9b1c65b14cc0e3dd0b37..539968a78ca19a254cb1f087b0ca6b03912578e0 100644
--- a/src/alphafold3/model/pipeline/pipeline.py
+++ b/src/alphafold3/model/pipeline/pipeline.py
@@ -48,10 +48,15 @@ def calculate_bucket_size(
   bucket_idx = bisect.bisect_left(buckets, num_tokens)
 
   if bucket_idx == len(buckets):
-    raise ValueError(
-        f'Number of tokens {num_tokens} is more than the largest currently'
-        f' supported bucket size {buckets[-1]}.'
+    logging.warning(
+        'Creating a new bucket of size %d since the input has more tokens than'
+        ' the largest bucket size %d. This may trigger a re-compilation of the'
+        ' model. Consider adding larger bucket sizes to avoid excessive'
+        ' re-compilation.',
+        num_tokens,
+        buckets[-1],
     )
+    return num_tokens
 
   return buckets[bucket_idx]
 
@@ -250,9 +255,19 @@ class WholePdbPipeline:
           f'({total_tokens} < {self._config.min_total_residues})'
       )
 
+    logging.info(
+        'Calculating bucket size for input with %d tokens.', total_tokens
+    )
     padded_token_length = calculate_bucket_size(
         total_tokens, self._config.buckets
     )
+    logging.info(
+        'Got bucket size %d for input with %d tokens, resulting in %d tokens'
+        ' of padding.',
+        padded_token_length,
+        total_tokens,
+        padded_token_length - total_tokens,
+    )
 
     # Padding shapes for all features.
     num_atoms = padded_token_length * self._config.average_num_atoms_per_token