diff --git a/src/mistral_inference/args.py b/src/mistral_inference/args.py
index a94a2c605977cda5323230a924a83e53adbc51bd..7ca4f43801aa13b5dd6a08c823d7d6ef2a186384 100644
--- a/src/mistral_inference/args.py
+++ b/src/mistral_inference/args.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional
 
 from simple_parsing.helpers import Serializable
 
@@ -39,12 +39,18 @@ class TransformerArgs(Serializable):
     moe: Optional[MoeArgs] = None
     # If this is set, we will load LoRA linear layers instead of linear layers.
     lora: Optional[LoraArgs] = None
+    sliding_window: Optional[int] | Optional[List[int]] = None
+    _sliding_window: Optional[int] | Optional[List[int]] = None
     model_type: str = "transformer"
 
     vision_encoder: Optional[VisionEncoderArgs] = None
 
     def __post_init__(self) -> None:
         assert self.model_type == "transformer", self.model_type
+        assert self.sliding_window is None or self._sliding_window is None
+
+        # hack for now so that vLLM is supported correctly
+        self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window
 
 
 @dataclass
diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py
index 93cfb1c102a4cb2199d408866a12ed75a919f4f9..6f8aa7d2c18186ea4c1e9294605659bf9f3a8674 100644
--- a/src/mistral_inference/cache.py
+++ b/src/mistral_inference/cache.py
@@ -10,13 +10,40 @@ from xformers.ops.fmha.attn_bias import (  # type: ignore
 )
 
 
+def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[int] | Optional[List[int]]) -> List[int]:
+    if sliding_window is None:
+        return n_layers * [max_seq_len]
+    elif isinstance(sliding_window, int):
+        return n_layers * [sliding_window]
+    else:
+        assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}"
+        assert (
+            n_layers % len(sliding_window) == 0
+        ), f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}"
+        num_repeats = n_layers // len(sliding_window)
+        return num_repeats * [w if w is not None else max_seq_len for w in sliding_window]
+
+
 @dataclass
 class CacheInputMetadata:
+    # # rope absolute positions
+    # positions: torch.Tensor
+    # # where tokens should go in the cache
+    # cache_positions: torch.Tensor
+
+    # # if prefill, use block diagonal causal mask
+    # # else use causal with padded key mask
+    # prefill: bool
+    # mask: AttentionBias
+    # seqlens: List[int]
     # rope absolute positions
     positions: torch.Tensor
+    # which elements in the sequences need to be cached
+    to_cache_mask: torch.Tensor
+    # how many elements are cached per sequence
+    cached_elements: torch.Tensor
     # where tokens should go in the cache
     cache_positions: torch.Tensor
-
     # if prefill, use block diagonal causal mask
     # else use causal with padded key mask
     prefill: bool
@@ -29,6 +56,17 @@ def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> List[torc
     return [v for pair in zip(l1, l2) for v in pair]
 
 
+def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor:
+    assert cache.ndim == 3  # (W, H, D)
+    position = seqlen % cache.shape[0]
+    if seqlen < cache.shape[0]:
+        return cache[:seqlen]
+    elif position == 0:
+        return cache
+    else:
+        return torch.cat([cache[position:], cache[:position]], dim=0)
+
+
 class CacheView:
     def __init__(
         self,
@@ -50,8 +88,8 @@ class CacheView:
         flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim)
         flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim)
 
-        flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk)
-        flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv)
+        flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask])
+        flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask])
 
     def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -69,9 +107,9 @@ class CacheView:
         xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens)  # type: ignore
         assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}"
 
-        # Retrieve cache
-        cache_k = [cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)]
-        cache_v = [cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)]
+        # Order elements in cache by position by unrotating
+        cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]
+        cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]
 
         interleaved_k = interleave_list(cache_k, list(xk))
         interleaved_v = interleave_list(cache_v, list(xv))
@@ -112,13 +150,22 @@ class BufferCache:
         max_seq_len: int,
         n_kv_heads: int,
         head_dim: int,
+        sliding_window: Optional[int] | Optional[List[int]] = None,
     ):
         self.max_seq_len = max_seq_len
         self.n_kv_heads = n_kv_heads
         self.head_dim = head_dim
+        self.n_layers = n_layers
+
+        self.cache_sizes: List[int] = get_cache_sizes(n_layers, max_seq_len, sliding_window)
+        assert len(self.cache_sizes) == n_layers, f"Expected {n_layers} cache sizes, got {len(self.cache_sizes)}"
+
+        self.cache_k = {}
+        self.cache_v = {}
+        for i, cache_size in enumerate(self.cache_sizes):
+            self.cache_k[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))
+            self.cache_v[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))
 
-        self.cache_k = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
-        self.cache_v = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
         # holds the valid length for each batch element in the cache
         self.kv_seqlens: Optional[torch.Tensor] = None
 
@@ -134,11 +181,12 @@ class BufferCache:
 
     @property
     def device(self) -> torch.device:
-        return self.cache_k.device
+        return self.cache_k[0].device
 
     def to(self, device: torch.device, dtype: torch.dtype) -> "BufferCache":
-        self.cache_k = self.cache_k.to(device=device, dtype=dtype)
-        self.cache_v = self.cache_v.to(device=device, dtype=dtype)
+        for i in range(self.n_layers):
+            self.cache_k[i] = self.cache_k[i].to(device=device, dtype=dtype)
+            self.cache_v[i] = self.cache_v[i].to(device=device, dtype=dtype)
 
         return self
 
@@ -146,55 +194,69 @@ class BufferCache:
         assert self.kv_seqlens is not None
         self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long)
 
-    def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata:
+    def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]:
         """
-        Get metadata about cache positions
+        input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
+        --> only cache last 3 tokens in each sequence
+        - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
+        - cached_elements = [3 | 3 | 2]
+        --> absolute positions are used for rope
+        - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
+        --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window
+        - cache_positions = [2 0 1 | 5 3 4 | 6 7]
         """
+        metadata: List[CacheInputMetadata] = []
+
         if self.kv_seqlens is None:
             self.init_kvseqlens(len(seqlens))
 
-        assert isinstance(self.kv_seqlens, torch.Tensor)
+        assert self.kv_seqlens is not None
         assert len(seqlens) == len(
             self.kv_seqlens
         ), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
         seqpos = self.kv_seqlens.tolist()
-
         assert len(seqlens) > 0, seqlens
-        cached_elements = torch.tensor(seqlens, device=self.device, dtype=torch.long)
 
+        for cache_size in self.cache_sizes:
+            metadata.append(self._get_input_metadata_layer(cache_size, seqlens, seqpos))
+
+        return metadata
+
+    def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata:
+        masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
+        to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
+        cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
         positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(
             device=self.device, dtype=torch.long
         )
-
         batch_idx = torch.tensor(
-            sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []),
-            device=self.device,
-            dtype=torch.long,
+            sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long
         )
-        cache_positions = positions + batch_idx * self.max_seq_len
-
+        cache_positions = positions % cache_size + batch_idx * cache_size
         first_prefill = seqpos[0] == 0
         subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
         if first_prefill:
             assert all([pos == 0 for pos in seqpos]), seqpos
-            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.max_seq_len)
+            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size)
         elif subsequent_prefill:
+            assert self.kv_seqlens is not None
             mask = BlockDiagonalMask.from_seqlens(
                 q_seqlen=seqlens,
                 kv_seqlen=[
-                    s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
+                    s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
                 ],
-            ).make_local_attention_from_bottomright(self.max_seq_len)
+            ).make_local_attention_from_bottomright(cache_size)
         else:
             mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                 q_seqlen=seqlens,
-                kv_padding=self.max_seq_len,
-                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(),
+                kv_padding=cache_size,
+                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist(),
             )
-
         return CacheInputMetadata(
             positions=positions,
-            cache_positions=cache_positions,
+            to_cache_mask=to_cache_mask,
+            cached_elements=cached_elements,
+            cache_positions=cache_positions[to_cache_mask],
             prefill=first_prefill or subsequent_prefill,
             mask=mask,
             seqlens=seqlens,
diff --git a/src/mistral_inference/generate.py b/src/mistral_inference/generate.py
index 1e906b3c136f259bce0a6f7b8ec1704a86797bda..bc6112dbef7873d9ab29d8a6061add62d13c0672 100644
--- a/src/mistral_inference/generate.py
+++ b/src/mistral_inference/generate.py
@@ -72,6 +72,7 @@ def generate(
         cache_window,
         model.args.n_kv_heads,
         model.args.head_dim,
+        model.args.sliding_window,
     )
     cache.to(device=model.device, dtype=model.dtype)
     cache.reset()
diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
index 9c9aebec3ad86ff3c50945652f218bd6c5229cdd..cb782ddb0c34cb9905c1e7111c24db2b010b3a7c 100644
--- a/src/mistral_inference/transformer.py
+++ b/src/mistral_inference/transformer.py
@@ -36,6 +36,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
         args: TransformerArgs,
         pipeline_rank: int = 0,
         num_pipeline_ranks: int = 1,
+        softmax_fp32: bool = True,
     ):
         super().__init__()
         self.args = args
@@ -46,6 +47,8 @@ class Transformer(ModelBase, LoRALoaderMixin):
         assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
         self.pipeline_rank = pipeline_rank
         self.num_pipeline_ranks = num_pipeline_ranks
+        self.softmax_fp32 = softmax_fp32
+
         # Modules specific to some ranks:
         self.tok_embeddings: Optional[nn.Embedding] = None
         self.norm: Optional[RMSNorm] = None
@@ -150,12 +153,12 @@ class Transformer(ModelBase, LoRALoaderMixin):
         (num_toks,) = input_ids.shape
         assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
 
-        input_metadata: Union[CacheInputMetadata, SimpleInputMetadata]
+        input_metadata: List[CacheInputMetadata] | List[SimpleInputMetadata]
 
         if cache is not None:
             input_metadata = cache.get_input_metadata(seqlens)
         else:
-            input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
+            input_metadata = [SimpleInputMetadata.from_seqlens(seqlens, self.device) for _ in range(len(self.layers))]
 
         if self.pipeline_rank == 0:
             assert self.tok_embeddings is not None
@@ -167,13 +170,15 @@ class Transformer(ModelBase, LoRALoaderMixin):
             h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype)
             torch.distributed.recv(h, src=self.pipeline_rank - 1)
 
-        freqs_cis = self.freqs_cis[input_metadata.positions]
+        # freqs_cis is always the same for every layer
+        freqs_cis = self.freqs_cis[input_metadata[0].positions]
 
         for local_layer_id, layer in enumerate(self.layers.values()):
             if cache is not None:
                 assert input_metadata is not None
-                assert isinstance(input_metadata, CacheInputMetadata)
-                cache_view = cache.get_view(local_layer_id, input_metadata)
+                cache_metadata = input_metadata[local_layer_id]
+                assert isinstance(cache_metadata, CacheInputMetadata)
+                cache_view = cache.get_view(local_layer_id, cache_metadata)
             else:
                 cache_view = None
             h = layer(h, freqs_cis, cache_view)
@@ -205,7 +210,11 @@ class Transformer(ModelBase, LoRALoaderMixin):
             outs = self.output(h)
         if self.num_pipeline_ranks > 1:
             torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
-        return outs.float()
+
+        if self.softmax_fp32:
+            return outs.float()
+        else:
+            return outs
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
         state_to_load = {}
@@ -257,6 +266,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
         num_pipeline_ranks: int = 1,
         device: Union[torch.device, str] = "cuda",
         dtype: Optional[torch.dtype] = None,
+        softmax_fp32: bool = True,
     ) -> "Transformer":
         with open(Path(folder) / "params.json", "r") as f:
             model_args = TransformerArgs.from_dict(json.load(f))
@@ -270,6 +280,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
                 model_args,
                 pipeline_rank=pipeline_rank,
                 num_pipeline_ranks=num_pipeline_ranks,
+                softmax_fp32=softmax_fp32,
             )
 
         pt_model_file = Path(folder) / "consolidated.00.pth"