Commit 93465219 authored by Patrick von Platen

WIP

parent db6b4223
from dataclasses import dataclass
from typing import Optional
from typing import Optional, List
from simple_parsing.helpers import Serializable
@@ -39,12 +39,18 @@ class TransformerArgs(Serializable):
moe: Optional[MoeArgs] = None
# If this is set, we will load LoRA linear layers instead of linear layers.
lora: Optional[LoraArgs] = None
sliding_window: Optional[int] | Optional[List[int]] = None
_sliding_window: Optional[int] | Optional[List[int]] = None
model_type: str = "transformer"
vision_encoder: Optional[VisionEncoderArgs] = None
def __post_init__(self) -> None:
assert self.model_type == "transformer", self.model_type
assert self.sliding_window is None or self._sliding_window is None
# hack for now so that vLLM is supported correctly
self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window
@dataclass
...
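The `_sliding_window` alias appears intended to keep configs written for other runtimes loadable (per the vLLM comment in the hunk above). A minimal standalone sketch of the coalescing behaviour, using a hypothetical `_WindowArgs` stand-in rather than the real `TransformerArgs`:

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class _WindowArgs:
    # Hypothetical stand-in mirroring the two fields added above.
    sliding_window: Optional[Union[int, List[int]]] = None
    _sliding_window: Optional[Union[int, List[int]]] = None

    def __post_init__(self) -> None:
        # At most one of the two aliases may be set; whichever is present wins.
        assert self.sliding_window is None or self._sliding_window is None
        self.sliding_window = (
            self.sliding_window if self.sliding_window is not None else self._sliding_window
        )

assert _WindowArgs(_sliding_window=[4096, None]).sliding_window == [4096, None]
assert _WindowArgs(sliding_window=4096).sliding_window == 4096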
@@ -10,13 +10,39 @@ from xformers.ops.fmha.attn_bias import ( # type: ignore
)
def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[int] | Optional[List[int]]) -> List[int]:
if sliding_window is None:
return n_layers * [max_seq_len]
elif isinstance(sliding_window, int):
return n_layers * [sliding_window]
else:
assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}"
assert n_layers % len(sliding_window) == 0, f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}"
num_repeats = n_layers // len(sliding_window)
return num_repeats * [w if w is not None else max_seq_len for w in sliding_window]
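A quick usage sketch of get_cache_sizes (window values chosen arbitrarily): an int applies to every layer, a list is tiled across the layers, and a None entry falls back to max_seq_len for that layer, i.e. no sliding window there:

assert get_cache_sizes(n_layers=4, max_seq_len=8192, sliding_window=None) == [8192, 8192, 8192, 8192]
assert get_cache_sizes(n_layers=4, max_seq_len=8192, sliding_window=1024) == [1024, 1024, 1024, 1024]
assert get_cache_sizes(n_layers=4, max_seq_len=8192, sliding_window=[1024, None]) == [1024, 8192, 1024, 8192]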
@dataclass
class CacheInputMetadata:
# # rope absolute positions
# positions: torch.Tensor
# # where tokens should go in the cache
# cache_positions: torch.Tensor
# # if prefill, use block diagonal causal mask
# # else use causal with padded key mask
# prefill: bool
# mask: AttentionBias
# seqlens: List[int]
# rope absolute positions
positions: torch.Tensor
# which elements in the sequences need to be cached
to_cache_mask: torch.Tensor
# how many elements are cached per sequence
cached_elements: torch.Tensor
# where tokens should go in the cache
cache_positions: torch.Tensor
# if prefill, use block diagonal causal mask
# else use causal with padded key mask
prefill: bool
@@ -29,6 +55,17 @@ def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> List[torch.Tensor]:
return [v for pair in zip(l1, l2) for v in pair]
def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor:
assert cache.ndim == 3 # (W, H, D)
position = seqlen % cache.shape[0]
if seqlen < cache.shape[0]:
return cache[:seqlen]
elif position == 0:
return cache
else:
return torch.cat([cache[position:], cache[:position]], dim=0)
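A small worked example of unrotate on a toy ring buffer (W=4, H=D=1); the slot values are the absolute positions written into them, so the asserts simply check that chronological order is restored:

import torch

W = 4
cache = torch.zeros(W, 1, 1)
for p in range(6):
    cache[p % W] = p          # slots end up holding positions [4, 5, 2, 3]

# With seqlen >= W the buffer has wrapped; unrotate restores chronological order.
assert unrotate(cache, seqlen=6).flatten().tolist() == [2.0, 3.0, 4.0, 5.0]

# With fewer tokens than W the buffer has not rotated yet, so the prefix is returned as-is.
fresh = torch.cat([torch.arange(3.0).view(3, 1, 1), torch.zeros(1, 1, 1)])
assert unrotate(fresh, seqlen=3).flatten().tolist() == [0.0, 1.0, 2.0]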
class CacheView:
def __init__(
self,
@@ -50,8 +87,8 @@ class CacheView:
flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim)
flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim)
flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk)
flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv)
flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask])
flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask])
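A self-contained sketch of what the masked index_copy_ above does, with toy shapes (n_kv_heads = head_dim = 1) and the to_cache_mask / cache_positions values from the get_input_metadata docstring further down (seqlens [5, 7, 2], window 3); the xk values are just the global token indices:

import torch

window, batch = 3, 3
cache_k = torch.zeros(batch, window, 1, 1)
xk = torch.arange(14.0).view(14, 1, 1)                      # 5 + 7 + 2 = 14 new key vectors
to_cache_mask = torch.tensor([0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.bool)
cache_positions = torch.tensor([2, 0, 1, 5, 3, 4, 6, 7])

# Only the last `window` tokens of each sequence are scattered into the flat ring buffer.
flat_cache_k = cache_k.view(-1, 1, 1)
flat_cache_k.index_copy_(0, cache_positions, xk[to_cache_mask])

assert cache_k[0].flatten().tolist() == [3.0, 4.0, 2.0]     # sequence 0: tokens 2..4 land in slots [2, 0, 1]
assert cache_k[2].flatten().tolist() == [12.0, 13.0, 0.0]   # sequence 2: tokens 12, 13; third slot unused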
def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -69,15 +106,16 @@ class CacheView:
xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens) # type: ignore
assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}"
# Retrieve cache
cache_k = [cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)]
cache_v = [cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)]
# Order elements in cache by position by unrotating
cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]
cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]
interleaved_k = interleave_list(cache_k, list(xk))
interleaved_v = interleave_list(cache_v, list(xv))
interleaved_k = interleave_list(cache_k, xk)
interleaved_v = interleave_list(cache_v, xv)
return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0)
@property
def max_seq_len(self) -> int:
return self.cache_k.shape[1]
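An illustrative check of the ordering interleave_kv relies on (toy shapes, assuming the interleave_list defined above): cached and freshly computed chunks alternate per sequence, and torch.cat then produces the flat key/value layout that is used together with the block-diagonal masks built in get_input_metadata below:

import torch

cache_chunks = [torch.full((2, 1, 1), 0.0), torch.full((3, 1, 1), 1.0)]   # unrotated per-sequence cache
new_chunks = [torch.full((1, 1, 1), 2.0), torch.full((1, 1, 1), 3.0)]     # new keys, one per sequence
interleaved = interleave_list(cache_chunks, new_chunks)                   # [cache_0, new_0, cache_1, new_1]
assert torch.cat(interleaved, dim=0).flatten().tolist() == [0.0, 0.0, 2.0, 1.0, 1.0, 1.0, 3.0]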
@@ -112,13 +150,23 @@ class BufferCache:
max_seq_len: int,
n_kv_heads: int,
head_dim: int,
sliding_window: Optional[int] | Optional[List[int]] = None
):
print(f"yeeeees {sliding_window}")
self.max_seq_len = max_seq_len
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.n_layers = n_layers
self.cache_sizes: List[int] = get_cache_sizes(n_layers, max_seq_len, sliding_window)
assert len(self.cache_sizes) == n_layers, f"Expected {n_layers} cache sizes, got {len(self.cache_sizes)}"
self.cache_k = {}
self.cache_v = {}
for i, cache_size in enumerate(self.cache_sizes):
self.cache_k[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))
self.cache_v[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))
self.cache_k = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
self.cache_v = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
# holds the valid length for each batch element in the cache
self.kv_seqlens: Optional[torch.Tensor] = None
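The point of replacing the single (n_layers, batch, max_seq_len, ...) tensors with per-layer buffers is that sliding-window layers only allocate `window` slots each. Rough, purely illustrative arithmetic (the model sizes here are made up, not taken from this commit):

# 32 layers, batch 1, 8 KV heads, head_dim 128, fp16, max_seq_len 32768, windows [4096, None]
bytes_uniform = 32 * 1 * 32768 * 8 * 128 * 2                  # old single K tensor: ~2.1 GB
sizes = get_cache_sizes(32, 32768, [4096, None])              # [4096, 32768, 4096, 32768, ...]
bytes_per_layer = sum(1 * s * 8 * 128 * 2 for s in sizes)     # ~1.2 GB, roughly 44% smaller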
@@ -134,11 +182,12 @@ class BufferCache:
@property
def device(self) -> torch.device:
return self.cache_k.device
return self.cache_k[0].device
def to(self, device: torch.device, dtype: torch.dtype) -> "BufferCache":
self.cache_k = self.cache_k.to(device=device, dtype=dtype)
self.cache_v = self.cache_v.to(device=device, dtype=dtype)
for i in range(self.n_layers):
self.cache_k[i] = self.cache_k[i].to(device=device, dtype=dtype)
self.cache_v[i] = self.cache_v[i].to(device=device, dtype=dtype)
return self
@@ -146,55 +195,61 @@ class BufferCache:
assert self.kv_seqlens is not None
self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long)
def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata:
def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]:
"""
Get metadata about cache positions
input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
--> only cache last 3 tokens in each sequence
- to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
- cached_elements = [3 | 3 | 2]
--> absolute positions are used for rope
- positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
--> cache positions are the masked positions taken modulo sliding_window, offset by batch_idx * sliding_window
- cache_positions = [2 0 1 | 5 3 4 | 6 7]
"""
metadata: List[CacheInputMetadata] = []
if self.kv_seqlens is None:
self.init_kvseqlens(len(seqlens))
assert isinstance(self.kv_seqlens, torch.Tensor)
assert len(seqlens) == len(
self.kv_seqlens
), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
assert len(seqlens) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
seqpos = self.kv_seqlens.tolist()
assert len(seqlens) > 0, seqlens
cached_elements = torch.tensor(seqlens, device=self.device, dtype=torch.long)
positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(
device=self.device, dtype=torch.long
)
batch_idx = torch.tensor(
sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []),
device=self.device,
dtype=torch.long,
)
cache_positions = positions + batch_idx * self.max_seq_len
for cache_size in self.cache_sizes:
metadata.append(self._get_input_metadata_layer(cache_size, seqlens, seqpos))
return metadata
def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata:
masks = [
[x >= seqlen - cache_size for x in range(seqlen)]
for seqlen in seqlens
]
to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(device=self.device, dtype=torch.long)
batch_idx = torch.tensor(sum([[i]*seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long)
cache_positions = positions % cache_size + batch_idx * cache_size
first_prefill = seqpos[0] == 0
subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
if first_prefill:
assert all([pos == 0 for pos in seqpos]), seqpos
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.max_seq_len)
assert all([pos == 0 for pos in seqpos]), (seqpos)
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size)
elif subsequent_prefill:
mask = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens,
kv_seqlen=[
s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
],
).make_local_attention_from_bottomright(self.max_seq_len)
kv_seqlen=[s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)]
).make_local_attention_from_bottomright(cache_size)
else:
mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=seqlens,
kv_padding=self.max_seq_len,
kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(),
kv_padding=cache_size,
kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist()
)
return CacheInputMetadata(
positions=positions,
cache_positions=cache_positions,
to_cache_mask=to_cache_mask,
cached_elements=cached_elements,
cache_positions=cache_positions[to_cache_mask],
prefill=first_prefill or subsequent_prefill,
mask=mask,
seqlens=seqlens,
...
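A standalone reproduction of the docstring example above (seqlens [5, 7, 2], seqpos [0, 1, 3], cache_size 3), using the same tensor ops as _get_input_metadata_layer but without needing a BufferCache instance:

import torch

seqlens, seqpos, cache_size = [5, 7, 2], [0, 1, 3], 3

masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
to_cache_mask = torch.tensor(sum(masks, []), dtype=torch.bool)
positions = torch.cat([torch.arange(p, p + s) for p, s in zip(seqpos, seqlens)])
batch_idx = torch.tensor(sum([[i] * s for i, s in enumerate(seqlens)], []))
cache_positions = positions % cache_size + batch_idx * cache_size

assert [sum(m) for m in masks] == [3, 3, 2]                                  # cached_elements
assert positions.tolist() == [0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 3, 4]
assert cache_positions[to_cache_mask].tolist() == [2, 0, 1, 5, 3, 4, 6, 7]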
@@ -72,6 +72,7 @@ def generate(
cache_window,
model.args.n_kv_heads,
model.args.head_dim,
model.args.sliding_window,
)
cache.to(device=model.device, dtype=model.dtype)
cache.reset()
...
@@ -150,12 +150,12 @@ class Transformer(ModelBase, LoRALoaderMixin):
(num_toks,) = input_ids.shape
assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
input_metadata: Union[CacheInputMetadata, SimpleInputMetadata]
input_metadata: List[Union[CacheInputMetadata, SimpleInputMetadata]]
if cache is not None:
input_metadata = cache.get_input_metadata(seqlens)
else:
input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
input_metadata = [SimpleInputMetadata.from_seqlens(seqlens, self.device) for _ in range(len(self.layers))]
if self.pipeline_rank == 0:
assert self.tok_embeddings is not None
@@ -167,13 +167,14 @@ class Transformer(ModelBase, LoRALoaderMixin):
h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype)
torch.distributed.recv(h, src=self.pipeline_rank - 1)
freqs_cis = self.freqs_cis[input_metadata.positions]
# freqs_cis is always the same for every layer
freqs_cis = self.freqs_cis[input_metadata[0].positions]
for local_layer_id, layer in enumerate(self.layers.values()):
if cache is not None:
assert input_metadata is not None
assert isinstance(input_metadata, CacheInputMetadata)
cache_view = cache.get_view(local_layer_id, input_metadata)
assert isinstance(input_metadata[local_layer_id], CacheInputMetadata)
cache_view = cache.get_view(local_layer_id, input_metadata[local_layer_id])
else:
cache_view = None
h = layer(h, freqs_cis, cache_view)
...