From d37fbb54ec97a24dfa7342e7ad09c84c5ed910b3 Mon Sep 17 00:00:00 2001
From: Patrick von Platen <patrick.v.platen@gmail.com>
Date: Wed, 16 Oct 2024 15:34:03 +0200
Subject: [PATCH] Up

---
 src/mistral_inference/args.py        |  2 +-
 src/mistral_inference/cache.py       | 56 ++++++++++++++++------------
 src/mistral_inference/transformer.py |  7 ++--
 3 files changed, 37 insertions(+), 28 deletions(-)

diff --git a/src/mistral_inference/args.py b/src/mistral_inference/args.py
index 4be2470..7ca4f43 100644
--- a/src/mistral_inference/args.py
+++ b/src/mistral_inference/args.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional, List
+from typing import List, Optional
 
 from simple_parsing.helpers import Serializable
 
diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py
index 3cddb4f..22644c4 100644
--- a/src/mistral_inference/cache.py
+++ b/src/mistral_inference/cache.py
@@ -17,12 +17,13 @@ def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[in
         return n_layers * [sliding_window]
     else:
         assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}"
-        assert n_layers % len(sliding_window) == 0, f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}"
+        assert (
+            n_layers % len(sliding_window) == 0
+        ), f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}"
         num_repeats = n_layers // len(sliding_window)
         return num_repeats * [w if w is not None else max_seq_len for w in sliding_window]
 
 
-
 @dataclass
 class CacheInputMetadata:
     # # rope absolute positions
@@ -110,12 +111,11 @@ class CacheView:
         cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]
         cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]
 
-        interleaved_k = interleave_list(cache_k, xk)
-        interleaved_v = interleave_list(cache_v, xv)
+        interleaved_k = interleave_list(cache_k, list(xk))
+        interleaved_v = interleave_list(cache_v, list(xv))
 
         return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0)
 
-
     @property
     def max_seq_len(self) -> int:
         return self.cache_k.shape[1]
@@ -150,7 +150,7 @@ class BufferCache:
         max_seq_len: int,
         n_kv_heads: int,
         head_dim: int,
-        sliding_window: Optional[int] | Optional[List[int]] = None
+        sliding_window: Optional[int] | Optional[List[int]] = None,
     ):
         print(f"yeeeees {sliding_window}")
         self.max_seq_len = max_seq_len
@@ -197,20 +197,24 @@ class BufferCache:
 
     def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]:
         """
-            input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
-            --> only cache last 3 tokens in each sequence
-            - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
-            - cached_elements = [3 | 3 | 2]
-            --> absolute positions are used for rope
-            - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
-            --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window
-            - cache_positions = [2 0 1 | 5 3 4 | 6 7]
+        input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
+        --> only cache last 3 tokens in each sequence
+        - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
+        - cached_elements = [3 | 3 | 2]
+        --> absolute positions are used for rope
+        - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
+        --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window
+        - cache_positions = [2 0 1 | 5 3 4 | 6 7]
         """
         metadata: List[CacheInputMetadata] = []
 
         if self.kv_seqlens is None:
             self.init_kvseqlens(len(seqlens))
-        assert len(seqlens) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
+
+        assert self.kv_seqlens is not None
+        assert len(seqlens) == len(
+            self.kv_seqlens
+        ), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
         seqpos = self.kv_seqlens.tolist()
 
         assert len(seqlens) > 0, seqlens
@@ -220,30 +224,34 @@ class BufferCache:
         return metadata
 
     def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata:
-        masks = [
-            [x >= seqlen - cache_size for x in range(seqlen)]
-            for seqlen in seqlens
-        ]
+        masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
         to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
         cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
-        positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(device=self.device, dtype=torch.long)
-        batch_idx = torch.tensor(sum([[i]*seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long)
+        positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(
+            device=self.device, dtype=torch.long
+        )
+        batch_idx = torch.tensor(
+            sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long
+        )
         cache_positions = positions % cache_size + batch_idx * cache_size
         first_prefill = seqpos[0] == 0
         subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
         if first_prefill:
-            assert all([pos == 0 for pos in seqpos]), (seqpos)
+            assert all([pos == 0 for pos in seqpos]), seqpos
             mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size)
         elif subsequent_prefill:
+            assert self.kv_seqlens is not None
             mask = BlockDiagonalMask.from_seqlens(
                 q_seqlen=seqlens,
-                kv_seqlen=[s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)]
+                kv_seqlen=[
+                    s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
+                ],
             ).make_local_attention_from_bottomright(cache_size)
         else:
             mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                 q_seqlen=seqlens,
                 kv_padding=cache_size,
-                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist()
+                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist(),
             )
         return CacheInputMetadata(
             positions=positions,
diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
index c4aa897..a53195f 100644
--- a/src/mistral_inference/transformer.py
+++ b/src/mistral_inference/transformer.py
@@ -150,7 +150,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
         (num_toks,) = input_ids.shape
         assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
 
-        input_metadata: List[Union[CacheInputMetadata, SimpleInputMetadata]]
+        input_metadata: List[CacheInputMetadata] | List[SimpleInputMetadata]
 
         if cache is not None:
             input_metadata = cache.get_input_metadata(seqlens)
@@ -173,8 +173,9 @@ class Transformer(ModelBase, LoRALoaderMixin):
         for local_layer_id, layer in enumerate(self.layers.values()):
             if cache is not None:
                 assert input_metadata is not None
-                assert isinstance(input_metadata[local_layer_id], CacheInputMetadata)
-                cache_view = cache.get_view(local_layer_id, input_metadata[local_layer_id])
+                cache_metadata = input_metadata[local_layer_id]
+                assert isinstance(cache_metadata, CacheInputMetadata)
+                cache_view = cache.get_view(local_layer_id, cache_metadata)
             else:
                 cache_view = None
             h = layer(h, freqs_cis, cache_view)
-- 
GitLab
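
For context on the list-valued sliding_window handling that the first cache.py hunk reformats: a minimal, self-contained sketch of how the per-layer cache sizes expand, mirroring the else branch of get_cache_sizes shown above. The helper name expand_cache_sizes and the window values used in the example are stand-ins for illustration, not part of the library API.

from typing import List, Optional


def expand_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: List[Optional[int]]) -> List[int]:
    # The per-layer window pattern must divide n_layers evenly (the condition
    # guarded by the reformatted assert above); the pattern is then repeated
    # across layers, and None entries fall back to a full max_seq_len cache.
    assert n_layers % len(sliding_window) == 0, f"{n_layers} % {len(sliding_window)} != 0"
    num_repeats = n_layers // len(sliding_window)
    return num_repeats * [w if w is not None else max_seq_len for w in sliding_window]


# 4 layers with pattern [4096, None] and max_seq_len 8192 -> [4096, 8192, 4096, 8192]
print(expand_cache_sizes(4, 8192, [4096, None]))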