diff --git a/src/mistral_inference/args.py b/src/mistral_inference/args.py
index 4be24709f3a1d1c724c34ac407d975d62741c8ba..7ca4f43801aa13b5dd6a08c823d7d6ef2a186384 100644
--- a/src/mistral_inference/args.py
+++ b/src/mistral_inference/args.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional, List
+from typing import List, Optional
 
 from simple_parsing.helpers import Serializable
 
diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py
index 3cddb4fa09a44990870b1a8cfba7b31a8264ae40..22644c4bea4bbe0da9da6f4596d815eb5ada2d75 100644
--- a/src/mistral_inference/cache.py
+++ b/src/mistral_inference/cache.py
@@ -17,12 +17,13 @@ def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[in
         return n_layers * [sliding_window]
     else:
         assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}"
-        assert n_layers % len(sliding_window) == 0, f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}"
+        assert (
+            n_layers % len(sliding_window) == 0
+        ), f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}"
         num_repeats = n_layers // len(sliding_window)
         return num_repeats * [w if w is not None else max_seq_len for w in sliding_window]
 
 
-
 @dataclass
 class CacheInputMetadata:
     # # rope absolute positions
@@ -110,12 +111,11 @@ class CacheView:
         cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]
         cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]
 
-        interleaved_k = interleave_list(cache_k, xk)
-        interleaved_v = interleave_list(cache_v, xv)
+        interleaved_k = interleave_list(cache_k, list(xk))
+        interleaved_v = interleave_list(cache_v, list(xv))
 
         return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0)
 
-
     @property
     def max_seq_len(self) -> int:
         return self.cache_k.shape[1]
@@ -150,7 +150,6 @@ class BufferCache:
         max_seq_len: int,
         n_kv_heads: int,
         head_dim: int,
-        sliding_window: Optional[int] | Optional[List[int]] = None
+        sliding_window: Optional[int | List[int]] = None,
     ):
         print(f"yeeeees {sliding_window}")
         self.max_seq_len = max_seq_len
@@ -197,20 +197,24 @@ class BufferCache:
 
     def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]:
         """
-            input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
-            --> only cache last 3 tokens in each sequence
-            - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
-            - cached_elements = [3 | 3 | 2]
-            --> absolute positions are used for rope
-            - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
-            --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window
-            - cache_positions = [2 0 1 | 5 3 4 | 6 7]
+        input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
+        --> only cache last 3 tokens in each sequence
+        - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
+        - cached_elements = [3 | 3 | 2]
+        --> absolute positions are used for rope
+        - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
+        --> cache positions are the cache-masked positions modulo sliding_window, plus batch_idx * sliding_window
+        - cache_positions = [2 0 1 | 5 3 4 | 6 7]
         """
         metadata: List[CacheInputMetadata] = []
 
         if self.kv_seqlens is None:
             self.init_kvseqlens(len(seqlens))
-        assert len(seqlens) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
+
+        assert self.kv_seqlens is not None
+        assert len(seqlens) == len(
+            self.kv_seqlens
+        ), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
         seqpos = self.kv_seqlens.tolist()
         assert len(seqlens) > 0, seqlens
 
@@ -220,30 +224,34 @@ class BufferCache:
         return metadata
 
     def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata:
-        masks = [
-            [x >= seqlen - cache_size for x in range(seqlen)]
-            for seqlen in seqlens
-        ]
+        masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
         to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
         cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
-        positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(device=self.device, dtype=torch.long)
-        batch_idx = torch.tensor(sum([[i]*seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long)
+        positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(
+            device=self.device, dtype=torch.long
+        )
+        batch_idx = torch.tensor(
+            sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long
+        )
         cache_positions = positions % cache_size + batch_idx * cache_size
         first_prefill = seqpos[0] == 0
         subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
         if first_prefill:
-            assert all([pos == 0 for pos in seqpos]), (seqpos)
+            assert all([pos == 0 for pos in seqpos]), seqpos
             mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size)
         elif subsequent_prefill:
+            assert self.kv_seqlens is not None
             mask = BlockDiagonalMask.from_seqlens(
                 q_seqlen=seqlens,
-                kv_seqlen=[s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)]
+                kv_seqlen=[
+                    s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
+                ],
             ).make_local_attention_from_bottomright(cache_size)
         else:
             mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                 q_seqlen=seqlens,
                 kv_padding=cache_size,
-                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist()
+                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist(),
             )
         return CacheInputMetadata(
             positions=positions,
diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
index c4aa897341a394aaff3be0fb1376e3848d4e1f18..a53195fcf5af03c133bd10d1c676b11f370fb38c 100644
--- a/src/mistral_inference/transformer.py
+++ b/src/mistral_inference/transformer.py
@@ -150,7 +150,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
         (num_toks,) = input_ids.shape
         assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
 
-        input_metadata: List[Union[CacheInputMetadata, SimpleInputMetadata]]
+        input_metadata: List[CacheInputMetadata] | List[SimpleInputMetadata]
 
         if cache is not None:
             input_metadata = cache.get_input_metadata(seqlens)
@@ -173,8 +173,9 @@ class Transformer(ModelBase, LoRALoaderMixin):
         for local_layer_id, layer in enumerate(self.layers.values()):
             if cache is not None:
                 assert input_metadata is not None
-                assert isinstance(input_metadata[local_layer_id], CacheInputMetadata)
-                cache_view = cache.get_view(local_layer_id, input_metadata[local_layer_id])
+                cache_metadata = input_metadata[local_layer_id]
+                assert isinstance(cache_metadata, CacheInputMetadata)
+                cache_view = cache.get_view(local_layer_id, cache_metadata)
             else:
                 cache_view = None
             h = layer(h, freqs_cis, cache_view)
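
Note (illustration, not part of the patch): the reformatted get_input_metadata docstring above works through seqlens [5, 7, 2], seqpos [0, 1, 3], sliding_window 3. The minimal sketch below replays only the mask and cache-position arithmetic from _get_input_metadata_layer for that example, assuming cache_size equals the sliding window and running on CPU; variable names mirror the diff, but the snippet itself is a standalone illustration rather than code from this change.

import torch

cache_size = 3       # per-layer sliding window
seqlens = [5, 7, 2]  # new tokens per sequence in the batch
seqpos = [0, 1, 3]   # tokens already processed per sequence (kv_seqlens)

# Keep only the last `cache_size` tokens of each sequence.
masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
to_cache_mask = torch.tensor(sum(masks, []), dtype=torch.bool)

# Absolute positions (used for RoPE) and each token's batch index.
positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)])
batch_idx = torch.tensor(sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []))

# Rotating-buffer slot per token: position modulo the window, offset into the
# sequence's own segment of the flat cache.
cache_positions = positions % cache_size + batch_idx * cache_size

print(to_cache_mask.int().tolist())
# -> [0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1]   (docstring: 0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1)
print(cache_positions[to_cache_mask].tolist())
# -> [2, 0, 1, 5, 3, 4, 6, 7]                     (docstring: 2 0 1 | 5 3 4 | 6 7)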