Unverified commit 6428ccf9 authored by Patrick von Platen, committed by GitHub

Merge pull request #230 from mistralai/add_layer_wise_rotated_cache

Add per-layer sliding window
parents db6b4223 b9524508
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional
from simple_parsing.helpers import Serializable
@@ -39,12 +39,18 @@ class TransformerArgs(Serializable):
moe: Optional[MoeArgs] = None
# If this is set, we will load LoRA linear layers instead of linear layers.
lora: Optional[LoraArgs] = None
sliding_window: Optional[int] | Optional[List[int]] = None
_sliding_window: Optional[int] | Optional[List[int]] = None
model_type: str = "transformer"
vision_encoder: Optional[VisionEncoderArgs] = None
def __post_init__(self) -> None:
assert self.model_type == "transformer", self.model_type
assert self.sliding_window is None or self._sliding_window is None
# hack for now so that vLLM is supported correctly
self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window
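For illustration, a minimal sketch of the resolution logic above, using a hypothetical stand-in class (the real TransformerArgs has many more required fields):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _ArgsSketch:  # hypothetical stand-in, not the real TransformerArgs
    sliding_window: Optional[int] | Optional[List[int]] = None
    _sliding_window: Optional[int] | Optional[List[int]] = None

    def __post_init__(self) -> None:
        # at most one of the two spellings may be set
        assert self.sliding_window is None or self._sliding_window is None
        self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window


args = _ArgsSketch(_sliding_window=[4096, None])
assert args.sliding_window == [4096, None]  # the private spelling resolves onto the public field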
@dataclass
@@ -10,13 +10,40 @@ from xformers.ops.fmha.attn_bias import ( # type: ignore
)
def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[int] | Optional[List[int]]) -> List[int]:
if sliding_window is None:
return n_layers * [max_seq_len]
elif isinstance(sliding_window, int):
return n_layers * [sliding_window]
else:
assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}"
assert (
n_layers % len(sliding_window) == 0
), f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}"
num_repeats = n_layers // len(sliding_window)
return num_repeats * [w if w is not None else max_seq_len for w in sliding_window]
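For reference, here is what the helper above returns for each of its three branches; the layer counts and lengths are made up:

# no sliding window: every layer gets a full-length cache
get_cache_sizes(n_layers=4, max_seq_len=8192, sliding_window=None)         # [8192, 8192, 8192, 8192]

# single int: every layer shares the same window
get_cache_sizes(n_layers=4, max_seq_len=8192, sliding_window=4096)         # [4096, 4096, 4096, 4096]

# list: the pattern repeats across layers; None entries fall back to max_seq_len
get_cache_sizes(n_layers=4, max_seq_len=8192, sliding_window=[1024, None]) # [1024, 8192, 1024, 8192]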
@dataclass
class CacheInputMetadata:
# # rope absolute positions
# positions: torch.Tensor
# # where tokens should go in the cache
# cache_positions: torch.Tensor
# # if prefill, use block diagonal causal mask
# # else use causal with padded key mask
# prefill: bool
# mask: AttentionBias
# seqlens: List[int]
# rope absolute positions
positions: torch.Tensor
# which elements in the sequences need to be cached
to_cache_mask: torch.Tensor
# how many elements are cached per sequence
cached_elements: torch.Tensor
# where tokens should go in the cache
cache_positions: torch.Tensor
# if prefill, use block diagonal causal mask
# else use causal with padded key mask
prefill: bool
@@ -29,6 +56,17 @@ def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> List[torch.Tensor]:
return [v for pair in zip(l1, l2) for v in pair]
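The ordering matters for the attention masks built later: each sequence contributes its cached keys first, then its freshly computed ones. A tiny check with hypothetical tensors:

import torch

cached = [torch.zeros(2, 1, 1), torch.zeros(3, 1, 1)]  # per-sequence cache slices
fresh = [torch.ones(1, 1, 1), torch.ones(4, 1, 1)]     # per-sequence new keys
out = interleave_list(cached, fresh)
assert [t.shape[0] for t in out] == [2, 1, 3, 4]       # cache_0, new_0, cache_1, new_1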
def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor:
assert cache.ndim == 3 # (W, H, D)
position = seqlen % cache.shape[0]
if seqlen < cache.shape[0]:
return cache[:seqlen]
elif position == 0:
return cache
else:
return torch.cat([cache[position:], cache[:position]], dim=0)
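A small sanity check of unrotate on a hypothetical window of size 3 that has wrapped around once (the newest token landed back in slot 0):

import torch

# slot i holds the token whose absolute position p satisfies p % 3 == i,
# so after 4 tokens the buffer stores positions [3, 1, 2]
cache = torch.tensor([3.0, 1.0, 2.0]).view(3, 1, 1)  # (W, H, D)
assert unrotate(cache, seqlen=4).flatten().tolist() == [1.0, 2.0, 3.0]
# with fewer tokens than the window, the first seqlen slots are returned unchanged
assert unrotate(cache, seqlen=2).flatten().tolist() == [3.0, 1.0]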
class CacheView:
def __init__(
self,
@@ -50,8 +88,8 @@ class CacheView:
flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim)
flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim)
flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk)
flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv)
flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask])
flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask])
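Only the tokens that still fit in this layer's window are scattered into the flat cache view; a toy illustration of the boolean indexing, with hypothetical shapes:

import torch

# seqlen 5, window 3: just the last 3 new keys get written to the cache
xk = torch.randn(5, 1, 4)  # (tokens, n_kv_heads, head_dim)
to_cache_mask = torch.tensor([False, False, True, True, True])
assert xk[to_cache_mask].shape == (3, 1, 4)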
def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -69,9 +107,9 @@ class CacheView:
xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens) # type: ignore
assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}"
# Retrieve cache
cache_k = [cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)]
cache_v = [cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)]
# Order elements in cache by position by unrotating
cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]
cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]
interleaved_k = interleave_list(cache_k, list(xk))
interleaved_v = interleave_list(cache_v, list(xv))
@@ -112,13 +150,22 @@ class BufferCache:
max_seq_len: int,
n_kv_heads: int,
head_dim: int,
sliding_window: Optional[int] | Optional[List[int]] = None,
):
self.max_seq_len = max_seq_len
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.n_layers = n_layers
self.cache_sizes: List[int] = get_cache_sizes(n_layers, max_seq_len, sliding_window)
assert len(self.cache_sizes) == n_layers, f"Expected {n_layers} cache sizes, got {len(self.cache_sizes)}"
self.cache_k = {}
self.cache_v = {}
for i, cache_size in enumerate(self.cache_sizes):
self.cache_k[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))
self.cache_v[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))
self.cache_k = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
self.cache_v = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
# holds the valid length for each batch element in the cache
self.kv_seqlens: Optional[torch.Tensor] = None
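Assuming BufferCache from this file is importable, a small construction sketch (sizes are hypothetical) showing that each layer's buffer is now sized by its own window rather than by max_seq_len:

cache = BufferCache(
    n_layers=4,
    max_batch_size=1,
    max_seq_len=8192,
    n_kv_heads=1,
    head_dim=8,
    sliding_window=[1024, None],
)
assert cache.cache_sizes == [1024, 8192, 1024, 8192]
assert cache.cache_k[0].shape == (1, 1024, 1, 8)  # windowed layer
assert cache.cache_k[1].shape == (1, 8192, 1, 8)  # full-attention layer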
@@ -134,11 +181,12 @@ class BufferCache:
@property
def device(self) -> torch.device:
return self.cache_k.device
return self.cache_k[0].device
def to(self, device: torch.device, dtype: torch.dtype) -> "BufferCache":
self.cache_k = self.cache_k.to(device=device, dtype=dtype)
self.cache_v = self.cache_v.to(device=device, dtype=dtype)
for i in range(self.n_layers):
self.cache_k[i] = self.cache_k[i].to(device=device, dtype=dtype)
self.cache_v[i] = self.cache_v[i].to(device=device, dtype=dtype)
return self
@@ -146,55 +194,69 @@ class BufferCache:
assert self.kv_seqlens is not None
self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long)
def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata:
def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]:
"""
Get metadata about cache positions
input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
--> only cache last 3 tokens in each sequence
- to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
- cached_elements = [3 | 3 | 2]
--> absolute positions are used for rope
- positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
--> cache positions are the masked positions modulo sliding_window, plus batch_idx * sliding_window
- cache_positions = [2 0 1 | 5 3 4 | 6 7]
"""
metadata: List[CacheInputMetadata] = []
if self.kv_seqlens is None:
self.init_kvseqlens(len(seqlens))
assert isinstance(self.kv_seqlens, torch.Tensor)
assert self.kv_seqlens is not None
assert len(seqlens) == len(
self.kv_seqlens
), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
seqpos = self.kv_seqlens.tolist()
assert len(seqlens) > 0, seqlens
cached_elements = torch.tensor(seqlens, device=self.device, dtype=torch.long)
for cache_size in self.cache_sizes:
metadata.append(self._get_input_metadata_layer(cache_size, seqlens, seqpos))
return metadata
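To make the docstring example concrete, the key quantities can be recomputed standalone with the same arithmetic as _get_input_metadata_layer (window of 3, and the seqlens/seqpos from the docstring):

seqlens, seqpos, window = [5, 7, 2], [0, 1, 3], 3

masks = [[x >= seqlen - window for x in range(seqlen)] for seqlen in seqlens]
cached_elements = [sum(m) for m in masks]
positions = [p for pos, seqlen in zip(seqpos, seqlens) for p in range(pos, pos + seqlen)]
batch_idx = [i for i, seqlen in enumerate(seqlens) for _ in range(seqlen)]
to_cache = [keep for m in masks for keep in m]
cache_positions = [p % window + b * window for p, b, keep in zip(positions, batch_idx, to_cache) if keep]

assert cached_elements == [3, 3, 2]
assert cache_positions == [2, 0, 1, 5, 3, 4, 6, 7]  # matches the docstring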
def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata:
masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(
device=self.device, dtype=torch.long
)
batch_idx = torch.tensor(
sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []),
device=self.device,
dtype=torch.long,
sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long
)
cache_positions = positions + batch_idx * self.max_seq_len
cache_positions = positions % cache_size + batch_idx * cache_size
first_prefill = seqpos[0] == 0
subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
if first_prefill:
assert all([pos == 0 for pos in seqpos]), seqpos
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.max_seq_len)
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size)
elif subsequent_prefill:
assert self.kv_seqlens is not None
mask = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens,
kv_seqlen=[
s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
],
).make_local_attention_from_bottomright(self.max_seq_len)
).make_local_attention_from_bottomright(cache_size)
else:
mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=seqlens,
kv_padding=self.max_seq_len,
kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(),
kv_padding=cache_size,
kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist(),
)
return CacheInputMetadata(
positions=positions,
cache_positions=cache_positions,
to_cache_mask=to_cache_mask,
cached_elements=cached_elements,
cache_positions=cache_positions[to_cache_mask],
prefill=first_prefill or subsequent_prefill,
mask=mask,
seqlens=seqlens,
@@ -72,6 +72,7 @@ def generate(
cache_window,
model.args.n_kv_heads,
model.args.head_dim,
model.args.sliding_window,
)
cache.to(device=model.device, dtype=model.dtype)
cache.reset()
@@ -36,6 +36,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
args: TransformerArgs,
pipeline_rank: int = 0,
num_pipeline_ranks: int = 1,
softmax_fp32: bool = True,
):
super().__init__()
self.args = args
@@ -46,6 +47,8 @@ class Transformer(ModelBase, LoRALoaderMixin):
assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
self.pipeline_rank = pipeline_rank
self.num_pipeline_ranks = num_pipeline_ranks
self.softmax_fp32 = softmax_fp32
# Modules specific to some ranks:
self.tok_embeddings: Optional[nn.Embedding] = None
self.norm: Optional[RMSNorm] = None
@@ -150,12 +153,12 @@ class Transformer(ModelBase, LoRALoaderMixin):
(num_toks,) = input_ids.shape
assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
input_metadata: Union[CacheInputMetadata, SimpleInputMetadata]
input_metadata: List[CacheInputMetadata] | List[SimpleInputMetadata]
if cache is not None:
input_metadata = cache.get_input_metadata(seqlens)
else:
input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
input_metadata = [SimpleInputMetadata.from_seqlens(seqlens, self.device) for _ in range(len(self.layers))]
if self.pipeline_rank == 0:
assert self.tok_embeddings is not None
@@ -167,13 +170,15 @@ class Transformer(ModelBase, LoRALoaderMixin):
h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype)
torch.distributed.recv(h, src=self.pipeline_rank - 1)
freqs_cis = self.freqs_cis[input_metadata.positions]
# freqs_cis is always the same for every layer
freqs_cis = self.freqs_cis[input_metadata[0].positions]
for local_layer_id, layer in enumerate(self.layers.values()):
if cache is not None:
assert input_metadata is not None
assert isinstance(input_metadata, CacheInputMetadata)
cache_view = cache.get_view(local_layer_id, input_metadata)
cache_metadata = input_metadata[local_layer_id]
assert isinstance(cache_metadata, CacheInputMetadata)
cache_view = cache.get_view(local_layer_id, cache_metadata)
else:
cache_view = None
h = layer(h, freqs_cis, cache_view)
@@ -205,7 +210,11 @@ class Transformer(ModelBase, LoRALoaderMixin):
outs = self.output(h)
if self.num_pipeline_ranks > 1:
torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
return outs.float()
if self.softmax_fp32:
return outs.float()
else:
return outs
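The new softmax_fp32 flag only changes the dtype of the returned logits: when it is True (the default) they are upcast to float32, otherwise they stay in the model's compute dtype and callers can upcast themselves. A quick standalone illustration, where the random tensor is a stand-in for self.output(h):

import torch

outs = torch.randn(4, 32000, dtype=torch.bfloat16)  # stand-in for self.output(h)
softmax_fp32 = False
logits = outs.float() if softmax_fp32 else outs
assert logits.dtype == torch.bfloat16
# a numerically safe softmax can still be taken by upcasting at the call site
probs = torch.softmax(logits.float(), dim=-1)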
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
state_to_load = {}
@@ -257,6 +266,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
num_pipeline_ranks: int = 1,
device: Union[torch.device, str] = "cuda",
dtype: Optional[torch.dtype] = None,
softmax_fp32: bool = True,
) -> "Transformer":
with open(Path(folder) / "params.json", "r") as f:
model_args = TransformerArgs.from_dict(json.load(f))
@@ -270,6 +280,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
model_args,
pipeline_rank=pipeline_rank,
num_pipeline_ranks=num_pipeline_ranks,
softmax_fp32=softmax_fp32,
)
pt_model_file = Path(folder) / "consolidated.00.pth"
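Putting the new keyword together with the existing loader, a hedged usage sketch; the checkpoint path is hypothetical and the import path is assumed from the package layout:

import torch

from mistral_inference.transformer import Transformer  # assumed import path

model = Transformer.from_folder(
    "/path/to/model_folder",  # hypothetical folder containing params.json and consolidated.00.pth
    num_pipeline_ranks=1,
    device="cuda",
    dtype=torch.bfloat16,
    softmax_fp32=False,       # keep logits in bf16 instead of upcasting to fp32
)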