diff --git a/README.md b/README.md
index 79ec2f4e6bdecc7b363a2ad36bff2ed083b5b2de..090dc45af7e51a9bebf2f1a57140001c8526e682 100644
--- a/README.md
+++ b/README.md
@@ -155,7 +155,7 @@ You can continue chatting afterwards, *e.g.* with *"Translate it to Python"*.
 - *Instruction Following*:
 
 ```py
-from mistral_inference.model import Transformer
+from mistral_inference.transformer import Transformer
 from mistral_inference.generate import generate
 
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
@@ -228,7 +228,7 @@ pip install --upgrade mistral-common
 You can simulate a code completion in-filling as follows.
 
 ```py
-from mistral_inference.model import Transformer
+from mistral_inference.transformer import Transformer
 from mistral_inference.generate import generate
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from mistral_common.tokens.instruct.request import FIMRequest
diff --git a/moe_one_file_ref.py b/moe_one_file_ref.py
index aa9c4bb09ee7629f58d1b90db6ad7dfd7daed38c..542388e3c0b2e2d8edcc2bcf2edfc97a7b30fc11 100644
--- a/moe_one_file_ref.py
+++ b/moe_one_file_ref.py
@@ -22,7 +22,7 @@ class MoeArgs(Serializable):
 
 
 @dataclass
-class ModelArgs(Serializable):
+class TransformerArgs(Serializable):
     dim: int
     n_layers: int
     head_dim: int
@@ -80,7 +80,7 @@ def apply_rotary_emb(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.args = args
 
@@ -144,9 +144,7 @@ class Attention(nn.Module):
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
         # Update cache
-        scatter_pos = positions[None, :, None, None].repeat(
-            bsz, 1, self.n_kv_heads, self.args.head_dim
-        )
+        scatter_pos = positions[None, :, None, None].repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
         cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk)
         cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv)
 
@@ -179,7 +177,7 @@ class Attention(nn.Module):
 
 
 class FeedForward(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
         self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
@@ -214,9 +212,7 @@ class MoeLayer(nn.Module):
     def forward(self, inputs: torch.Tensor):
         inputs_squashed = inputs.view(-1, inputs.shape[-1])
         gate_logits = self.gate(inputs_squashed)
-        weights, selected_experts = torch.topk(
-            gate_logits, self.args.num_experts_per_tok
-        )
+        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
         weights = nn.functional.softmax(
             weights,
             dim=1,
@@ -225,14 +221,12 @@ class MoeLayer(nn.Module):
         results = torch.zeros_like(inputs_squashed)
         for i, expert in enumerate(self.experts):
             batch_idx, nth_expert = torch.where(selected_experts == i)
-            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
-                inputs_squashed[batch_idx]
-            )
+            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs_squashed[batch_idx])
         return results.view_as(inputs)
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
@@ -270,7 +264,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
 class Transformer(nn.Module):
     def __init__(
         self,
-        args: ModelArgs,
+        args: TransformerArgs,
         pipeline_rank: int = 0,
         num_pipeline_ranks: int = 1,
     ):
@@ -316,13 +310,9 @@ class Transformer(nn.Module):
         # from the module's dtype means we cannot register it as a buffer
         if self._precomputed_freqs_cis is None:
             theta = self.args.rope_theta or 1000000.0
-            self._precomputed_freqs_cis = precompute_freqs_cis(
-                self.args.head_dim, 128_000, theta
-            )
+            self._precomputed_freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta)
         if self._precomputed_freqs_cis.device != self.device:
-            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(
-                device=self.device
-            )
+            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device)
         return self._precomputed_freqs_cis
 
     def forward(
@@ -341,9 +331,7 @@ class Transformer(nn.Module):
             assert h.shape == (bsz, seqlen, self.args.dim)
             assert h.dtype == self.dtype
         else:
-            h = torch.empty(
-                bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype
-            )
+            h = torch.empty(bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype)
             torch.distributed.recv(h, src=self.pipeline_rank - 1)
 
         mask: Optional[torch.Tensor] = None
@@ -361,9 +349,7 @@ class Transformer(nn.Module):
 
         if self.pipeline_rank < self.num_pipeline_ranks - 1:
             torch.distributed.send(h, dst=self.pipeline_rank + 1)
-            outs = torch.empty(
-                *h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype
-            )
+            outs = torch.empty(*h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype)
         else:
             assert self.output is not None
             assert self.norm is not None
@@ -422,7 +408,7 @@ class Transformer(nn.Module):
         dtype=torch.float16,
     ) -> "Transformer":
         with open(folder / "params.json", "r") as f:
-            model_args = ModelArgs.from_dict(json.load(f))
+            model_args = TransformerArgs.from_dict(json.load(f))
         model_args.max_batch_size = max_batch_size
         model_args.max_seq_len = max_seq_len
         if num_pipeline_ranks > 1:
@@ -457,9 +443,7 @@ class Transformer(nn.Module):
 
 
 def load_tokenizer(model_path: Path) -> MistralTokenizer:
-    tokenizer = [
-        f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")
-    ]
+    tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")]
     assert (
         len(tokenizer) == 1
     ), f"Multiple tokenizers {', '.join(tokenizer)} found in `model_path`, make sure to only have one tokenizer"
@@ -470,12 +454,8 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer:
 
 
 @torch.no_grad()
-def generate(
-    prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int
-):
-    encoded_prompts = [
-        tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts
-    ]
+def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int):
+    encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts]
     prompt_lens = [len(x) for x in encoded_prompts]
     min_prompt_len = min(prompt_lens)
     max_prompt_len = max(prompt_lens)
@@ -498,23 +478,17 @@ def generate(
     # decode
     generated = []
     all_logprobs = [
-        logprobs[:, :-1, :]
-        .gather(2, input_tokens[:, 1:min_prompt_len, None])
-        .squeeze(-1),
+        logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1),
     ]
     for cur_pos in range(min_prompt_len, max_tokens):
         next_token = torch.argmax(logprobs[:, -1, :], dim=-1)
         if cur_pos < input_mask.shape[1]:
-            next_token = torch.where(
-                input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token
-            )
+            next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token)
         all_logprobs.append(
             logprobs[:, -1, :].gather(1, next_token[:, None]),
         )
         generated.append(next_token[:, None])
-        logits = model.forward(
-            next_token[:, None], torch.LongTensor([cur_pos]).to(next_token)
-        )
+        logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))
         logprobs = nn.functional.log_softmax(logits, dim=-1)
 
     all_logprobs_merged = torch.cat(all_logprobs, 1)
diff --git a/one_file_ref.py b/one_file_ref.py
index b654e78196fd1885f2608b9f4dc660bc33746cff..a848d7392f04404e36e4cf1782a1c60cab423ba9 100644
--- a/one_file_ref.py
+++ b/one_file_ref.py
@@ -14,7 +14,7 @@ from torch import nn
 
 
 @dataclass
-class ModelArgs(Serializable):
+class TransformerArgs(Serializable):
     dim: int
     n_layers: int
     head_dim: int
@@ -31,9 +31,7 @@ class ModelArgs(Serializable):
     max_batch_size: int = 0
 
 
-def repeat_kv(
-    keys: torch.Tensor, values: torch.Tensor, repeats: int
-) -> Tuple[torch.Tensor]:
+def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int) -> Tuple[torch.Tensor]:
     keys = torch.repeat_interleave(keys, repeats=repeats, dim=2)
     values = torch.repeat_interleave(values, repeats=repeats, dim=2)
     return keys, values
@@ -68,7 +66,7 @@ def apply_rotary_emb(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.args = args
 
@@ -118,9 +116,7 @@ class Attention(nn.Module):
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
         # cache
-        scatter_pos = positions[None, :, None, None].repeat(
-            bsz, 1, self.n_kv_heads, self.args.head_dim
-        )
+        scatter_pos = positions[None, :, None, None].repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
         self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk)
         self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv)
 
@@ -152,7 +148,7 @@ class Attention(nn.Module):
 
 
 class FeedForward(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
 
         self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
@@ -178,7 +174,7 @@ class RMSNorm(torch.nn.Module):
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
@@ -210,7 +206,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
 
 
 class Transformer(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: TransformerArgs):
         super().__init__()
         self.args = args
         self.vocab_size = args.vocab_size
@@ -219,18 +215,14 @@ class Transformer(nn.Module):
 
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
 
-        self.layers = torch.nn.ModuleList(
-            [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        )
+        self.layers = torch.nn.ModuleList([TransformerBlock(args=args) for _ in range(args.n_layers)])
 
         self.norm = RMSNorm(args.dim, eps=args.norm_eps)
 
         self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
 
         theta = self.args.rope_theta or 1000000.0
-        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta).to(
-            "cuda"
-        )
+        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta).to("cuda")
 
     def forward(
         self,
@@ -259,11 +251,9 @@ class Transformer(nn.Module):
         return self.output(self.norm(h)).float()
 
     @staticmethod
-    def from_folder(
-        folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16
-    ):
+    def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16):
         with open(Path(folder) / "params.json", "r") as f:
-            model_args = ModelArgs.from_dict(json.load(f))
+            model_args = TransformerArgs.from_dict(json.load(f))
         model_args.max_batch_size = max_batch_size
 
         model = Transformer(model_args)
@@ -288,9 +278,7 @@ class Transformer(nn.Module):
 
 
 def load_tokenizer(model_path: Path) -> MistralTokenizer:
-    tokenizer = [
-        f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")
-    ]
+    tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")]
     assert (
         len(tokenizer) > 0
     ), f"No tokenizer found in {model_path}, make sure to place a `tokenizer.model.[v1,v2,v3]` file in {model_path}."
@@ -304,12 +292,8 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer:
 
 
 @torch.no_grad()
-def generate(
-    prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int
-):
-    encoded_prompts = [
-        tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts
-    ]
+def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int):
+    encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts]
     prompt_lens = [len(x) for x in encoded_prompts]
     min_prompt_len = min(prompt_lens)
     max_prompt_len = max(prompt_lens)
@@ -333,24 +317,18 @@ def generate(
     # decode
     generated = []
     all_logprobs = [
-        logprobs[:, :-1, :]
-        .gather(2, input_tokens[:, 1:min_prompt_len, None])
-        .squeeze(-1),
+        logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1),
     ]
     cur_pos = min_prompt_len
     for _ in range(max_tokens):
         next_token = torch.argmax(logprobs[:, -1, :], dim=-1)
         if cur_pos < input_mask.shape[1]:
-            next_token = torch.where(
-                input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token
-            )
+            next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token)
         all_logprobs.append(
             logprobs[:, -1, :].gather(1, next_token[:, None]),
         )
         generated.append(next_token[:, None])
-        logits = model.forward(
-            next_token[:, None], torch.LongTensor([cur_pos]).to(next_token)
-        )
+        logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))
         logprobs = nn.functional.log_softmax(logits, dim=-1)
         cur_pos += 1
 
diff --git a/src/mistral_inference/args.py b/src/mistral_inference/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbf9f810dae61f6de0093e8b70f183514d6b4fd9
--- /dev/null
+++ b/src/mistral_inference/args.py
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+from typing import Optional
+
+from simple_parsing.helpers import Serializable
+
+from mistral_inference.lora import LoraArgs
+from mistral_inference.moe import MoeArgs
+
+
+@dataclass
+class TransformerArgs(Serializable):
+    dim: int
+    n_layers: int
+    head_dim: int
+    hidden_dim: int
+    n_heads: int
+    n_kv_heads: int
+    norm_eps: float
+    vocab_size: int
+
+    max_batch_size: int = 0
+
+    # For rotary embeddings. If not set, defaults to 1000000.0.
+    rope_theta: Optional[float] = None
+    # If this is set, we will use MoE layers instead of dense layers.
+    moe: Optional[MoeArgs] = None
+    # If this is set, we will load LoRA linear layers instead of linear layers.
+    lora: Optional[LoraArgs] = None
+    model_type: str = "transformer"
+
+    def __post_init__(self):
+        assert self.model_type == "transformer", self.model_type
+
+
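+# Mamba-specific model arguments; field names mirror what is passed to mamba_ssm's MambaConfig (see mamba.py).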
+@dataclass
+class MambaArgs(Serializable):
+    dim: int
+    n_layers: int
+    vocab_size: int
+    n_groups: int
+    rms_norm: bool
+    residual_in_fp32: bool
+    fused_add_norm: bool
+    pad_vocab_size_multiple: int
+    tie_embeddings: bool
+    model_type: str = "mamba"
+
+    def __post_init__(self):
+        assert self.model_type == "mamba", self.model_type
diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py
index a3b47253e831b11f61e882f763a97a251db0d136..93cfb1c102a4cb2199d408866a12ed75a919f4f9 100644
--- a/src/mistral_inference/cache.py
+++ b/src/mistral_inference/cache.py
@@ -24,9 +24,7 @@ class CacheInputMetadata:
     seqlens: List[int]
 
 
-def interleave_list(
-    l1: List[torch.Tensor], l2: List[torch.Tensor]
-) -> List[torch.Tensor]:
+def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> List[torch.Tensor]:
     assert len(l1) == len(l2)
     return [v for pair in zip(l1, l2) for v in pair]
 
@@ -55,9 +53,7 @@ class CacheView:
         flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk)
         flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv)
 
-    def interleave_kv(
-        self, xk: torch.Tensor, xv: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         This is a naive implementation and not optimized for speed.
         """
@@ -71,17 +67,11 @@ class CacheView:
         # Make it a list of [(T, H, D)]
         xk: Tuple[torch.Tensor] = torch.split(xk, self.metadata.seqlens)  # type: ignore
         xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens)  # type: ignore
-        assert len(xk) == len(
-            self.kv_seqlens
-        ), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}"
+        assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}"
 
         # Retrieve cache
-        cache_k = [
-            cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)
-        ]
-        cache_v = [
-            cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)
-        ]
+        cache_k = [cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)]
+        cache_v = [cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)]
 
         interleaved_k = interleave_list(cache_k, list(xk))
         interleaved_v = interleave_list(cache_v, list(xv))
@@ -127,28 +117,20 @@ class BufferCache:
         self.n_kv_heads = n_kv_heads
         self.head_dim = head_dim
 
-        self.cache_k = torch.empty(
-            (n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim)
-        )
-        self.cache_v = torch.empty(
-            (n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim)
-        )
+        self.cache_k = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
+        self.cache_v = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
         # holds the valid length for each batch element in the cache
         self.kv_seqlens: Optional[torch.Tensor] = None
 
     def get_view(self, layer_id: int, metadata: CacheInputMetadata) -> CacheView:
         assert self.kv_seqlens is not None
-        return CacheView(
-            self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens
-        )
+        return CacheView(self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens)
 
     def reset(self) -> None:
         self.kv_seqlens = None
 
     def init_kvseqlens(self, batch_size: int) -> None:
-        self.kv_seqlens = torch.zeros(
-            (batch_size,), device=self.device, dtype=torch.long
-        )
+        self.kv_seqlens = torch.zeros((batch_size,), device=self.device, dtype=torch.long)
 
     @property
     def device(self) -> torch.device:
@@ -180,9 +162,9 @@ class BufferCache:
         assert len(seqlens) > 0, seqlens
         cached_elements = torch.tensor(seqlens, device=self.device, dtype=torch.long)
 
-        positions = torch.cat(
-            [torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]
-        ).to(device=self.device, dtype=torch.long)
+        positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(
+            device=self.device, dtype=torch.long
+        )
 
         batch_idx = torch.tensor(
             sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []),
@@ -195,24 +177,19 @@ class BufferCache:
         subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
         if first_prefill:
             assert all([pos == 0 for pos in seqpos]), seqpos
-            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(
-                self.max_seq_len
-            )
+            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.max_seq_len)
         elif subsequent_prefill:
             mask = BlockDiagonalMask.from_seqlens(
                 q_seqlen=seqlens,
                 kv_seqlen=[
-                    s + cached_s.clamp(max=self.max_seq_len).item()
-                    for (s, cached_s) in zip(seqlens, self.kv_seqlens)
+                    s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
                 ],
             ).make_local_attention_from_bottomright(self.max_seq_len)
         else:
             mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                 q_seqlen=seqlens,
                 kv_padding=self.max_seq_len,
-                kv_seqlen=(self.kv_seqlens + cached_elements)
-                .clamp(max=self.max_seq_len)
-                .tolist(),
+                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(),
             )
 
         return CacheInputMetadata(
diff --git a/src/mistral_inference/generate.py b/src/mistral_inference/generate.py
index c7f0bca14142385e9dbc75ca6e16ee883d17a237..c9e35c5dbf6a1c0e14f00692880145f62c2f7744 100644
--- a/src/mistral_inference/generate.py
+++ b/src/mistral_inference/generate.py
@@ -3,7 +3,40 @@ from typing import List, Optional, Tuple
 import torch
 
 from mistral_inference.cache import BufferCache
-from mistral_inference.model import Transformer
+from mistral_inference.mamba import Mamba
+from mistral_inference.transformer import Transformer
+
+
+@torch.inference_mode()
+def generate_mamba(
+    encoded_prompts: List[List[int]],
+    model: Mamba,
+    *,
+    max_tokens: int,
+    temperature: float,
+    chunk_size: Optional[int] = None,
+    eos_id: Optional[int] = None,
+) -> Tuple[List[List[int]], List[List[float]]]:
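+    # All prompts must already share the same length (the caller left-pads them) so they stack into a single tensor.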
+    input_ids = torch.tensor(encoded_prompts, device=model.device)
+    output = model.model.generate(
+        input_ids=input_ids,
+        max_length=input_ids.shape[-1] + max_tokens,
+        cg=True,
+        return_dict_in_generate=True,
+        output_scores=True,
+        enable_timing=False,
+        eos_token_id=eos_id,
+        temperature=temperature,
+        top_p=0.8,
+    )
+    generated_tokens = output.sequences[:, input_ids.shape[-1] :].tolist()
+
+    _logprobs: List[List[float]] = [[] for _ in range(len(generated_tokens))]
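+    # output.scores holds one score tensor per generated step; collect, per sequence, the score of each sampled token.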
+    for seq_idx, batch_score in enumerate(output.scores):
+        for batch_idx, score in enumerate(batch_score.tolist()):
+            _logprobs[batch_idx].append(score[generated_tokens[batch_idx][seq_idx]])
+
+    return generated_tokens, _logprobs
 
 
 @torch.inference_mode()
@@ -14,7 +47,7 @@ def generate(
     max_tokens: int,
     temperature: float,
     chunk_size: Optional[int] = None,
-    eos_id: Optional[int] = None
+    eos_id: Optional[int] = None,
 ) -> Tuple[List[List[int]], List[List[float]]]:
     model = model.eval()
     B, V = len(encoded_prompts), model.args.vocab_size
@@ -57,26 +90,16 @@ def generate(
             # Pass > 1
             last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1)
             for i_seq in range(B):
-                logprobs[i_seq].append(
-                    last_token_logits[i_seq, prompt_chunks[i_seq][0]].item()
-                )
+                logprobs[i_seq].append(last_token_logits[i_seq, prompt_chunks[i_seq][0]].item())
 
         offset = 0
         for i_seq, sequence in enumerate(prompt_chunks):
-            logprobs[i_seq].extend(
-                [
-                    logits[offset + i, sequence[i + 1]].item()
-                    for i in range(len(sequence) - 1)
-                ]
-            )
+            logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)])
             offset += len(sequence)
 
         last_token_prelogits = prelogits.index_select(
             0,
-            torch.tensor(
-                [len(p) for p in prompt_chunks], device=prelogits.device
-            ).cumsum(dim=0)
-            - 1,
+            torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1,
         )
         assert last_token_prelogits.shape == (B, V)
 
diff --git a/src/mistral_inference/lora.py b/src/mistral_inference/lora.py
index 2ab978cf06c20e06727e4efdd868626fa04fe0c5..30924290038d4a7d9803000032f620d955a01dd5 100644
--- a/src/mistral_inference/lora.py
+++ b/src/mistral_inference/lora.py
@@ -1,7 +1,7 @@
 import logging
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, NamedTuple, Union
+from typing import Any, Dict, NamedTuple, Union
 
 import safetensors.torch
 import torch
@@ -14,7 +14,7 @@ class LoraArgs(Serializable):
     rank: int
     scaling: float
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         assert self.rank > 0
         assert self.scaling > 0.0
 
@@ -63,16 +63,17 @@ class LoRALinear(nn.Module):
         self.linear = nn.Linear(self.in_features, self.out_features, bias=self.bias)
 
         # make sure no LoRA weights are marked as "missing" in load_state_dict
-        def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple):
+        def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple) -> None:
             incompatible_keys.missing_keys[:] = []  # type: ignore
 
         self.register_load_state_dict_post_hook(ignore_missing_keys)
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         lora = self.lora_B(self.lora_A(x))
-        return self.linear(x) + lora * self.scaling
+        result: torch.Tensor = self.linear(x) + lora * self.scaling
+        return result
 
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_state_dict(self, state_dict: Dict[str, Any], prefix: str, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
         key_name = prefix + "weight"
 
         # full checkpoint
@@ -82,18 +83,14 @@ class LoRALinear(nn.Module):
             # load frozen weights
             state_dict = {
                 "linear.weight": w_ref,
-                "lora_A.weight": torch.zeros_like(
-                    self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype
-                ),
-                "lora_B.weight": torch.zeros_like(
-                    self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype
-                ),
+                "lora_A.weight": torch.zeros_like(self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype),
+                "lora_B.weight": torch.zeros_like(self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype),
             }
             self.load_state_dict(state_dict, assign=True, strict=True)
 
 
 class LoRALoaderMixin:
-    def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0):
+    def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0) -> None:
         """Loads LoRA checkpoint"""
 
         lora_path = Path(lora_path)
@@ -103,47 +100,39 @@ class LoRALoaderMixin:
 
         self._load_lora_state_dict(state_dict, scaling=scaling)
 
-    def _load_lora_state_dict(
-        self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0
-    ):
+    def _load_lora_state_dict(self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0) -> None:
         """Loads LoRA state_dict"""
-
         lora_dtypes = set([p.dtype for p in lora_state_dict.values()])
         assert (
             len(lora_dtypes) == 1
         ), f"LoRA weights have multiple different dtypes {lora_dtypes}. All weights need to have the same dtype"
         lora_dtype = lora_dtypes.pop()
-        assert (
-            lora_dtype == self.dtype
-        ), f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}"
+        assert lora_dtype == self.dtype, f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}"  # type: ignore[attr-defined]
         assert all("lora" in key for key in lora_state_dict.keys())
 
         # move tensors to device
-        lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()}
+        lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()}  # type: ignore[attr-defined]
 
-        state_dict = self.state_dict()
+        state_dict = self.state_dict()  # type: ignore[attr-defined]
 
-        if self.args.lora is None:
+        if self.args.lora is None:  # type: ignore[attr-defined]
             logging.info("Loading and merging LoRA weights...")
 
             # replace every nn.Linear with a LoRALinear with 'meta' device except the output layer
-            named_modules = dict(self.named_modules())
+            named_modules = dict(self.named_modules())  # type: ignore[attr-defined]
             for name, module in named_modules.items():
                 if isinstance(module, nn.Linear) and name != "output":
                     layer_id = name.split(".")[1]
-                    if layer_id not in self.layers:
+                    if layer_id not in self.layers:  # type: ignore[attr-defined]
                         logging.debug(
                             "Skipping parameter %s at pipeline rank %d",
                             name,
-                            self.pipeline_rank,
+                            self.pipeline_rank,  # type: ignore[attr-defined]
                         )
-                    else:
+                    elif (name + ".lora_B.weight") in lora_state_dict:
                         weight = (
                             module.weight
-                            + (
-                                lora_state_dict[name + ".lora_B.weight"]
-                                @ lora_state_dict[name + ".lora_A.weight"]
-                            )
+                            + (lora_state_dict[name + ".lora_B.weight"] @ lora_state_dict[name + ".lora_A.weight"])
                             * scaling
                         )
 
@@ -154,13 +143,13 @@ class LoRALoaderMixin:
                 state_dict.update(lora_state_dict)
 
                 layer_id = k.split(".")[1]
-                if layer_id in self.layers:
+                if layer_id in self.layers:  # type: ignore[attr-defined]
                     state_dict[k] = v
                 else:
                     logging.debug(
                         "Skipping parameter %s at pipeline rank %d",
                         k,
-                        self.pipeline_rank,
+                        self.pipeline_rank,  # type: ignore[attr-defined]
                     )
 
-        self.load_state_dict(state_dict, strict=True)
+        self.load_state_dict(state_dict, strict=True)  # type: ignore[attr-defined]
diff --git a/src/mistral_inference/main.py b/src/mistral_inference/main.py
index a5ef3a0ddfd5eb126f885316a1e94969b809e0c8..9cafec8da072636a5a5c267c899f1b46907fbc07 100644
--- a/src/mistral_inference/main.py
+++ b/src/mistral_inference/main.py
@@ -1,7 +1,9 @@
+import json
 import logging
 import os
+import warnings
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Type, Union
 
 import fire  # type: ignore
 import torch
@@ -11,8 +13,9 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.base import Tokenizer
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 
-from mistral_inference.generate import generate
-from mistral_inference.model import Transformer
+from mistral_inference.generate import generate, generate_mamba
+from mistral_inference.mamba import Mamba
+from mistral_inference.transformer import Transformer
 
 
 def is_torchrun() -> bool:
@@ -21,9 +24,7 @@ def is_torchrun() -> bool:
 
 
 def load_tokenizer(model_path: Path) -> MistralTokenizer:
-    tokenizer = [
-        f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")
-    ]
+    tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")]
     assert (
         len(tokenizer) > 0
     ), f"No tokenizer found in {model_path}, make sure to place a `tokenizer.model.[v1,v2,v3]` file in {model_path}."
@@ -33,13 +34,28 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer:
 
     mistral_tokenizer = MistralTokenizer.from_file(str(model_path / tokenizer[0]))
 
-    logging.info(
-        f"Loaded tokenizer of type {mistral_tokenizer.instruct_tokenizer.__class__}"
-    )
+    logging.info(f"Loaded tokenizer of type {mistral_tokenizer.instruct_tokenizer.__class__}")
 
     return mistral_tokenizer
 
 
+def get_model_cls(model_path: str) -> Union[Type[Mamba], Type[Transformer]]:
+    with open(Path(model_path) / "params.json", "r") as f:
+        args_dict = json.load(f)
+
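+    # Dispatch on the optional "model_type" field in params.json (defaults to "transformer").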
+    return {"mamba": Mamba, "transformer": Transformer}[args_dict.get("model_type", "transformer")]  # type: ignore[return-value]
+
+
+def pad_and_convert_to_tensor(list_of_lists: List[List[int]], pad_id: int) -> List[List[int]]:
+    # Determine the length of the longest list
+    max_len = max(len(lst) for lst in list_of_lists)
+
+    # Left pad each list to the maximum length
+    padded_lists = [[pad_id] * (max_len - len(lst)) + lst for lst in list_of_lists]
+
+    return padded_lists
+
+
 def interactive(
     model_path: str,
     max_tokens: int = 35,
@@ -61,13 +77,12 @@ def interactive(
     mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path))
     tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
 
-    transformer = Transformer.from_folder(
-        Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks
-    )
+    model_cls = get_model_cls(model_path)
+    model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks)
 
     # load LoRA
     if lora_path is not None:
-        transformer.load_lora(Path(lora_path))
+        model.load_lora(Path(lora_path))
 
     prompt: str = ""
     messages: List[UserMessage | AssistantMessage] = []
@@ -80,9 +95,7 @@ def interactive(
                 messages += [UserMessage(content=user_input)]
                 chat_completion_request = ChatCompletionRequest(messages=messages)
 
-                tokens = mistral_tokenizer.encode_chat_completion(
-                    chat_completion_request
-                ).tokens
+                tokens = mistral_tokenizer.encode_chat_completion(chat_completion_request).tokens
             else:
                 prompt += user_input
 
@@ -98,9 +111,10 @@ def interactive(
         if not should_print:
             tokens = int(length_tensor.item()) * [0]
 
-        generated_tokens, _ = generate(
+        generate_fn = generate if isinstance(model, Transformer) else generate_mamba
+        generated_tokens, _ = generate_fn(  # type: ignore[operator]
             [tokens],
-            transformer,
+            model,
             max_tokens=max_tokens,
             temperature=temperature,
             eos_id=tokenizer.eos_id,
@@ -134,12 +148,11 @@ def demo(
         should_print = True
         num_pipeline_ranks = 1
 
-    transformer = Transformer.from_folder(
-        Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks
-    )
+    model_cls = get_model_cls(model_path)
+    model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks)
     # load LoRA
     if lora_path is not None:
-        transformer.load_lora(Path(lora_path))
+        model.load_lora(Path(lora_path))
 
     mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path))
     tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
@@ -150,21 +163,30 @@ def demo(
         "This is a third test, mistral AI is very good at testing. ",
     ]
 
-    encoded_prompts = [
-        tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts
-    ]
+    encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts]
+
+    if isinstance(model, Transformer):
+        generate_fn = generate
+    else:
+        generate_fn = generate_mamba  # type: ignore[assignment]
+        warnings.warn(
+            "Batched generation is not correctly supported at the moment and therefore might lead to worse results "
+            "as compared to non-batched generation. "
+            "See https://github.com/state-spaces/mamba/issues/66#issuecomment-1862349718 for more information."
+        )
+        encoded_prompts = pad_and_convert_to_tensor(encoded_prompts, mistral_tokenizer.instruct_tokenizer.BOS)  # type: ignore[attr-defined]
 
-    generated_tokens, _logprobs = generate(
+    generated_tokens, _logprobs = generate_fn(
         encoded_prompts,
-        transformer,
+        model,  # type: ignore[arg-type]
         max_tokens=max_tokens,
         temperature=temperature,
         eos_id=tokenizer.eos_id,
     )
 
     generated_words = []
-    for i, x in enumerate(encoded_prompts):
-        generated_words.append(tokenizer.decode(x + generated_tokens[i]))
+    for i, x in enumerate(generated_tokens):
+        generated_words.append(tokenizer.decode(encoded_prompts[i] + x))
 
     res = generated_words
 
diff --git a/src/mistral_inference/mamba.py b/src/mistral_inference/mamba.py
new file mode 100644
index 0000000000000000000000000000000000000000..02745e3d27183e0b6a6b942e235b6c19a10d7c11
--- /dev/null
+++ b/src/mistral_inference/mamba.py
@@ -0,0 +1,83 @@
+import json
+from pathlib import Path
+from typing import List, Optional, Union
+
+import safetensors
+import torch
+import torch.nn as nn
+
+from mistral_inference.args import MambaArgs
+from mistral_inference.cache import BufferCache
+from mistral_inference.model import ModelBase
+
+_is_mamba_installed = False
+try:
+    from mamba_ssm.models.config_mamba import MambaConfig
+    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+    _is_mamba_installed = True
+except ImportError:
+    _is_mamba_installed = False
+
+
+class Mamba(ModelBase, nn.Module):
+    def __init__(self, args: MambaArgs):
+        super().__init__()
+        self.args = args
+        assert _is_mamba_installed, "Mamba is not installed. Please install it using `pip install mamba-ssm`."
+
+        # make sure naming is consistent with `mamba_ssm`
+        config = MambaConfig(
+            d_model=args.dim,
+            n_layer=args.n_layers,
+            vocab_size=args.vocab_size,
+            ssm_cfg={"ngroups": args.n_groups, "layer": "Mamba2"},
+            attn_layer_idx=[],
+            attn_cfg={},
+            rms_norm=args.rms_norm,
+            residual_in_fp32=args.residual_in_fp32,
+            fused_add_norm=args.fused_add_norm,
+            pad_vocab_size_multiple=args.pad_vocab_size_multiple,
+            tie_embeddings=args.tie_embeddings,
+        )
+        self.model = MambaLMHeadModel(config)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        seqlens: List[int],  # not supported for now
+        cache: Optional[BufferCache] = None,  # not supported for now
+    ) -> torch.Tensor:
+        lm_output = self.model(input_ids)
+        result: torch.Tensor = lm_output.logits
+        return result
+
+    @staticmethod
+    def from_folder(
+        folder: Union[Path, str],
+        max_batch_size: int = 1,
+        num_pipeline_ranks: int = 1,
+        device: Union[torch.device, str] = "cuda",
+        dtype: Optional[torch.dtype] = None,
+    ) -> "Mamba":
+        with open(Path(folder) / "params.json", "r") as f:
+            model_args = MambaArgs.from_dict(json.load(f))
+
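+        # Instantiate on the meta device so no memory is allocated until the checkpoint weights are assigned below.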
+        with torch.device("meta"):
+            model = Mamba(model_args)
+
+        model_file = Path(folder) / "consolidated.safetensors"
+
+        assert model_file.exists(), f"Make sure {model_file} exists."
+        loaded = safetensors.torch.load_file(str(model_file))
+
+        model.load_state_dict(loaded, assign=True, strict=True)
+        return model.to(device=device, dtype=dtype)
diff --git a/src/mistral_inference/model.py b/src/mistral_inference/model.py
index f5ebd7a3739f446b3677cb486b6c0bc96147d79c..be41fcdbe8117f3dd0a4504d6c4ce08e460037d0 100644
--- a/src/mistral_inference/model.py
+++ b/src/mistral_inference/model.py
@@ -1,409 +1,43 @@
-import json
-import logging
-import math
-from dataclasses import dataclass
-from functools import partial
+from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, List, Mapping, Optional, Tuple, Union
+from typing import List, Optional, Union
 
-import safetensors.torch
 import torch
-from simple_parsing.helpers import Serializable
-from torch import nn
-from xformers.ops.fmha import memory_efficient_attention  # type: ignore
+import torch.nn as nn
 
-from mistral_inference.cache import (
-    BufferCache,
-    CacheInputMetadata,
-    CacheView,
-)
-from mistral_inference.lora import LoraArgs, LoRALinear, LoRALoaderMixin
-from mistral_inference.moe import MoeArgs, MoeLayer
-from mistral_inference.rope import apply_rotary_emb, precompute_freqs_cis
+from mistral_inference.cache import BufferCache
 
 
-@dataclass
-class ModelArgs(Serializable):
-    dim: int
-    n_layers: int
-    head_dim: int
-    hidden_dim: int
-    n_heads: int
-    n_kv_heads: int
-    norm_eps: float
-    vocab_size: int
-
-    max_batch_size: int = 0
-
-    # For rotary embeddings. If not set, will be inferred
-    rope_theta: Optional[float] = None
-    # If this is set, we will use MoE layers instead of dense layers.
-    moe: Optional[MoeArgs] = None
-    # If this is set, we will load LoRA linear layers instead of linear layers.
-    lora: Optional[LoraArgs] = None
-
-
-@dataclass
-class SimpleInputMetadata:
-    # rope absolute positions
-    positions: torch.Tensor
-
-    @staticmethod
-    def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata":
-        return SimpleInputMetadata(
-            positions=torch.cat([torch.arange(0, seqlen) for seqlen in seqlens]).to(
-                device=device, dtype=torch.long
-            )
-        )
-
-
-def repeat_kv(
-    keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
-    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
-    return keys, values
-
-
-def maybe_lora(args: ModelArgs) -> Union[nn.Linear, LoRALinear]:
-    if args.lora is None:
-        return nn.Linear
-    else:
-        return partial(LoRALinear, rank=args.lora.rank, scaling=args.lora.scaling)
-
-
-class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-
-        self.n_heads: int = args.n_heads
-        self.head_dim: int = args.head_dim
-        self.n_kv_heads: int = args.n_kv_heads
-
-        self.repeats = self.n_heads // self.n_kv_heads
-
-        self.scale = self.args.head_dim**-0.5
-
-        MaybeLora = maybe_lora(args)
-        self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False)
-        self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-        self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-        self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cis: torch.Tensor,
-        cache: Optional[CacheView],
-    ) -> torch.Tensor:
-        seqlen_sum, _ = x.shape
-
-        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
-        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
-        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
-        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-
-        if cache is None:
-            key, val = xk, xv
-        elif cache.prefill:
-            key, val = cache.interleave_kv(xk, xv)
-            cache.update(xk, xv)
-        else:
-            cache.update(xk, xv)
-            key, val = cache.key, cache.value
-            key = key.view(
-                seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim
-            )
-            val = val.view(
-                seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim
-            )
-
-        # Repeat keys and values to match number of query heads
-        key, val = repeat_kv(key, val, self.repeats, dim=1)
-
-        # xformers requires (B=1, S, H, D)
-        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
-        output = memory_efficient_attention(
-            xq, key, val, None if cache is None else cache.mask
-        )
-        output = output.view(seqlen_sum, self.n_heads * self.head_dim)
-
-        assert isinstance(output, torch.Tensor)
-
-        return self.wo(output)  # type: ignore
-
-
-class FeedForward(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-
-        MaybeLora = maybe_lora(args)
-        self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False)
-        self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False)
-        self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))  # type: ignore
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
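+# Abstract interface shared by the model backends (Transformer, Mamba): dtype, device, forward and from_folder.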
+class ModelBase(nn.Module, ABC):
+    def __init__(self) -> None:
         super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor) -> torch.Tensor:
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-
-class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.n_heads = args.n_heads
-        self.dim = args.dim
-        self.attention = Attention(args)
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.args = args
-
-        self.feed_forward: nn.Module
-        if args.moe is not None:
-            self.feed_forward = MoeLayer(
-                experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],
-                gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),
-                moe_args=args.moe,
-            )
-        else:
-            self.feed_forward = FeedForward(args=args)
-
-    def forward(
-        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]
-    ) -> torch.Tensor:
-        r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
-        h = x + r
-        r = self.feed_forward.forward(self.ffn_norm(h))
-        out = h + r
-        return out
-
-
-class Transformer(nn.Module, LoRALoaderMixin):
-    def __init__(
-        self,
-        args: ModelArgs,
-        pipeline_rank: int = 0,
-        num_pipeline_ranks: int = 1,
-    ):
-        super().__init__()
-        self.args = args
-        self.vocab_size = args.vocab_size
-        self.n_layers = args.n_layers
-        self._precomputed_freqs_cis: Optional[torch.Tensor] = None
-        assert self.vocab_size > 0
-        assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
-        self.pipeline_rank = pipeline_rank
-        self.num_pipeline_ranks = num_pipeline_ranks
-        # Modules specific to some ranks:
-        self.tok_embeddings: Optional[nn.Embedding] = None
-        self.norm: Optional[RMSNorm] = None
-        self.output: Optional[nn.Linear] = None
-        if pipeline_rank == 0:
-            self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
-        if pipeline_rank == num_pipeline_ranks - 1:
-            self.norm = RMSNorm(args.dim, eps=args.norm_eps)
-            self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
-        # Initialize all layers but slice off those not of this rank.
-        layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks)
-        offset = self.pipeline_rank * num_layers_per_rank
-        end = min(self.n_layers, offset + num_layers_per_rank)
-        self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)})
-        self.n_local_layers = len(self.layers)
 
     @property
+    @abstractmethod
     def dtype(self) -> torch.dtype:
-        return next(self.parameters()).dtype
+        pass
 
     @property
+    @abstractmethod
     def device(self) -> torch.device:
-        return next(self.parameters()).device
-
-    @property
-    def freqs_cis(self) -> torch.Tensor:
-        # We cache freqs_cis but need to take care that it is on the right device
-        # and has the right dtype (complex64). The fact that the dtype is different
-        # from the module's  dtype means we cannot register it as a buffer
-        if self._precomputed_freqs_cis is None:
-            # default to 10**6
-            theta = self.args.rope_theta or 1000000.0
-            self._precomputed_freqs_cis = precompute_freqs_cis(
-                self.args.head_dim, 128_000, theta
-            )
-
-        if self._precomputed_freqs_cis.device != self.device:
-            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(
-                device=self.device
-            )
-        return self._precomputed_freqs_cis
-
-    def forward_partial(
-        self,
-        input_ids: torch.Tensor,
-        seqlens: List[int],
-        cache: Optional[BufferCache] = None,
-    ) -> torch.Tensor:
-        """Local forward pass.
-
-        If doing pipeline parallelism, this will return the activations of the last layer of this stage.
-        For the last stage, this will return the normalized final embeddings.
-        """
-        assert (
-            len(seqlens) <= self.args.max_batch_size
-        ), f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
-        (num_toks,) = input_ids.shape
-        assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
-
-        input_metadata: Union[CacheInputMetadata, SimpleInputMetadata]
-
-        if cache is not None:
-            input_metadata = cache.get_input_metadata(seqlens)
-        else:
-            input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
-
-        if self.pipeline_rank == 0:
-            assert self.tok_embeddings is not None
-            h = self.tok_embeddings(input_ids)
-        else:
-            h = torch.empty(
-                num_toks, self.args.dim, device=self.device, dtype=self.dtype
-            )
-            torch.distributed.recv(h, src=self.pipeline_rank - 1)
-
-        freqs_cis = self.freqs_cis[input_metadata.positions]
-
-        for local_layer_id, layer in enumerate(self.layers.values()):
-            if cache is not None:
-                assert input_metadata is not None
-                assert isinstance(input_metadata, CacheInputMetadata)
-                cache_view = cache.get_view(local_layer_id, input_metadata)
-            else:
-                cache_view = None
-            h = layer(h, freqs_cis, cache_view)
-
-        if cache is not None:
-            cache.update_seqlens(seqlens)
-        if self.pipeline_rank < self.num_pipeline_ranks - 1:
-            torch.distributed.send(h, dst=self.pipeline_rank + 1)
-            return h  # type: ignore
-        else:
-            # Last rank has a final normalization step.
-            assert self.norm is not None
-            return self.norm(h)  # type: ignore
+        pass
 
+    @abstractmethod
     def forward(
         self,
         input_ids: torch.Tensor,
-        seqlens: List[int],
-        cache: Optional[BufferCache] = None,
+        seqlens: List[int],  # not supported for now
+        cache: Optional[BufferCache] = None,  # not supported for now
     ) -> torch.Tensor:
-        h = self.forward_partial(input_ids, seqlens, cache=cache)
-        if self.pipeline_rank < self.num_pipeline_ranks - 1:
-            # ignore the intermediate activations as we'll get the final output from
-            # the last stage
-            outs = torch.empty(
-                h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype
-            )
-        else:
-            assert self.output is not None
-            outs = self.output(h)
-        if self.num_pipeline_ranks > 1:
-            torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
-        return outs.float()
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
-    ) -> None:
-        state_to_load = {}
-        skipped = set([])
-        for k, v in state_dict.items():
-            if k.startswith("tok_embeddings"):
-                if self.pipeline_rank == 0:
-                    state_to_load[k] = v
-                else:
-                    logging.debug(
-                        "Skipping parameter %s at pipeline rank %d",
-                        k,
-                        self.pipeline_rank,
-                    )
-                    skipped.add(k)
-            elif k.startswith("norm") or k.startswith("output"):
-                if self.pipeline_rank == self.num_pipeline_ranks - 1:
-                    state_to_load[k] = v
-                else:
-                    logging.debug(
-                        "Skipping parameter %s at pipeline rank %d",
-                        k,
-                        self.pipeline_rank,
-                    )
-                    skipped.add(k)
-            elif k.startswith("layers"):
-                layer_id = k.split(".")[1]
-                if layer_id in self.layers:
-                    state_to_load[k] = v
-                else:
-                    logging.debug(
-                        "Skipping parameter %s at pipeline rank %d",
-                        k,
-                        self.pipeline_rank,
-                    )
-                    skipped.add(k)
-            else:
-                raise ValueError(f"Unexpected key {k}")
-        assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys()))
-        super().load_state_dict(state_to_load, strict=strict, assign=assign)
+        pass
 
     @staticmethod
+    @abstractmethod
     def from_folder(
         folder: Union[Path, str],
         max_batch_size: int = 1,
         num_pipeline_ranks: int = 1,
         device: Union[torch.device, str] = "cuda",
         dtype: Optional[torch.dtype] = None,
-    ) -> "Transformer":
-        with open(Path(folder) / "params.json", "r") as f:
-            model_args = ModelArgs.from_dict(json.load(f))
-        model_args.max_batch_size = max_batch_size
-        if num_pipeline_ranks > 1:
-            pipeline_rank = torch.distributed.get_rank()
-        else:
-            pipeline_rank = 0
-        with torch.device("meta"):
-            model = Transformer(
-                model_args,
-                pipeline_rank=pipeline_rank,
-                num_pipeline_ranks=num_pipeline_ranks,
-            )
-
-        pt_model_file = Path(folder) / "consolidated.00.pth"
-        safetensors_model_file = Path(folder) / "consolidated.safetensors"
-
-        assert (
-            pt_model_file.exists() or safetensors_model_file.exists()
-        ), f"Make sure either {pt_model_file} or {safetensors_model_file} exists"
-        assert not (
-            pt_model_file.exists() and safetensors_model_file.exists()
-        ), f"Both {pt_model_file} and {safetensors_model_file} cannot exist"
-
-        if pt_model_file.exists():
-            loaded = torch.load(str(pt_model_file), mmap=True)
-        else:
-            loaded = safetensors.torch.load_file(str(safetensors_model_file))
-
-        model.load_state_dict(loaded, assign=True, strict=True)
-
-        return model.to(device=device, dtype=dtype)
+    ) -> "ModelBase":
+        pass
diff --git a/src/mistral_inference/moe.py b/src/mistral_inference/moe.py
index 043776b2c989bba5eb8889f0cfc6575a335ef9d7..9ce8a8a9c62e9a8edff797c0262b393c30ea38e7 100644
--- a/src/mistral_inference/moe.py
+++ b/src/mistral_inference/moe.py
@@ -23,14 +23,10 @@ class MoeLayer(nn.Module):
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         gate_logits = self.gate(inputs)
-        weights, selected_experts = torch.topk(
-            gate_logits, self.args.num_experts_per_tok
-        )
+        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
         weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
         results = torch.zeros_like(inputs)
         for i, expert in enumerate(self.experts):
             batch_idx, nth_expert = torch.where(selected_experts == i)
-            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
-                inputs[batch_idx]
-            )
+            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx])
         return results
diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..57546845bdd1e1847ccdc5b40ca19c7874b1e787
--- /dev/null
+++ b/src/mistral_inference/transformer.py
@@ -0,0 +1,367 @@
+import json
+import logging
+import math
+from dataclasses import dataclass
+from functools import partial
+from pathlib import Path
+from typing import Any, List, Mapping, Optional, Tuple, Type, Union
+
+import safetensors.torch
+import torch
+from torch import nn
+from xformers.ops.fmha import memory_efficient_attention  # type: ignore
+
+from mistral_inference.args import TransformerArgs
+from mistral_inference.cache import (
+    BufferCache,
+    CacheInputMetadata,
+    CacheView,
+)
+from mistral_inference.lora import LoRALinear, LoRALoaderMixin
+from mistral_inference.model import ModelBase
+from mistral_inference.moe import MoeLayer
+from mistral_inference.rope import apply_rotary_emb, precompute_freqs_cis
+
+
+@dataclass
+class SimpleInputMetadata:
+    # rope absolute positions
+    positions: torch.Tensor
+
+    @staticmethod
+    def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata":
+        return SimpleInputMetadata(
+            positions=torch.cat([torch.arange(0, seqlen) for seqlen in seqlens]).to(device=device, dtype=torch.long)
+        )
+
+
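+# Grouped-query attention helper: each key/value head serves n_heads // n_kv_heads
+# query heads, so keys and values are repeated along the head dimension.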
+def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
+    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
+    return keys, values
+
+
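+# Pick the linear-layer constructor: plain nn.Linear, or LoRALinear partially
+# applied with the configured rank and scaling when LoRA is enabled.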
+def maybe_lora(args: TransformerArgs) -> Union[Type[nn.Linear], partial[LoRALinear]]:
+    if args.lora is None:
+        return nn.Linear
+    else:
+        return partial(LoRALinear, rank=args.lora.rank, scaling=args.lora.scaling)
+
+
+class Attention(nn.Module):
+    def __init__(self, args: TransformerArgs):
+        super().__init__()
+        self.args = args
+
+        self.n_heads: int = args.n_heads
+        self.head_dim: int = args.head_dim
+        self.n_kv_heads: int = args.n_kv_heads
+
+        self.repeats = self.n_heads // self.n_kv_heads
+
+        self.scale = self.args.head_dim**-0.5
+
+        MaybeLora = maybe_lora(args)
+        self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False)
+        self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
+        self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
+        self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        cache: Optional[CacheView],
+    ) -> torch.Tensor:
+        seqlen_sum, _ = x.shape
+
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
+        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
+        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
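+        # No cache: attend over the current tokens only. Prefill: merge the new
+        # keys/values with the cached prefix. Decode: append to the cache and attend
+        # over the full cached sequence.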
+        if cache is None:
+            key, val = xk, xv
+        elif cache.prefill:
+            key, val = cache.interleave_kv(xk, xv)
+            cache.update(xk, xv)
+        else:
+            cache.update(xk, xv)
+            key, val = cache.key, cache.value
+            key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
+            val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
+
+        # Repeat keys and values to match number of query heads
+        key, val = repeat_kv(key, val, self.repeats, dim=1)
+
+        # xformers requires (B=1, S, H, D)
+        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
+        output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)
+        output = output.view(seqlen_sum, self.n_heads * self.head_dim)
+
+        assert isinstance(output, torch.Tensor)
+
+        return self.wo(output)  # type: ignore
+
+
+class FeedForward(nn.Module):
+    def __init__(self, args: TransformerArgs):
+        super().__init__()
+
+        MaybeLora = maybe_lora(args)
+        self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False)
+        self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False)
+        self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))  # type: ignore
+
+
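+# Root-mean-square layer norm: x * rsqrt(mean(x^2) + eps), scaled by a learned weight.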
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: TransformerArgs):
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.attention = Attention(args)
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.args = args
+
+        self.feed_forward: nn.Module
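+        # Sparse mixture-of-experts feed-forward when MoE is configured, dense MLP otherwise.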
+        if args.moe is not None:
+            self.feed_forward = MoeLayer(
+                experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],
+                gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),
+                moe_args=args.moe,
+            )
+        else:
+            self.feed_forward = FeedForward(args=args)
+
+    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]) -> torch.Tensor:
+        r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
+        h = x + r
+        r = self.feed_forward.forward(self.ffn_norm(h))
+        out = h + r
+        return out
+
+
+class Transformer(ModelBase, LoRALoaderMixin):
+    def __init__(
+        self,
+        args: TransformerArgs,
+        pipeline_rank: int = 0,
+        num_pipeline_ranks: int = 1,
+    ):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.n_layers = args.n_layers
+        self._precomputed_freqs_cis: Optional[torch.Tensor] = None
+        assert self.vocab_size > 0
+        assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
+        self.pipeline_rank = pipeline_rank
+        self.num_pipeline_ranks = num_pipeline_ranks
+        # Modules specific to some ranks:
+        self.tok_embeddings: Optional[nn.Embedding] = None
+        self.norm: Optional[RMSNorm] = None
+        self.output: Optional[nn.Linear] = None
+        if pipeline_rank == 0:
+            self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+        if pipeline_rank == num_pipeline_ranks - 1:
+            self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+            self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+        # Initialize all layers but slice off those not of this rank.
+        layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
+        num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks)
+        offset = self.pipeline_rank * num_layers_per_rank
+        end = min(self.n_layers, offset + num_layers_per_rank)
+        self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)})
+        self.n_local_layers = len(self.layers)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    @property
+    def freqs_cis(self) -> torch.Tensor:
+        # We cache freqs_cis but need to take care that it is on the right device
+        # and has the right dtype (complex64). The fact that the dtype is different
+        # from the module's dtype means we cannot register it as a buffer.
+        if self._precomputed_freqs_cis is None:
+            # default to 10**6
+            theta = self.args.rope_theta or 1000000.0
+            self._precomputed_freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta)
+
+        if self._precomputed_freqs_cis.device != self.device:
+            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device)
+        return self._precomputed_freqs_cis
+
+    def forward_partial(
+        self,
+        input_ids: torch.Tensor,
+        seqlens: List[int],
+        cache: Optional[BufferCache] = None,
+    ) -> torch.Tensor:
+        """Local forward pass.
+
+        If doing pipeline parallelism, this will return the activations of the last layer of this stage.
+        For the last stage, this will return the normalized final embeddings.
+        """
+        assert (
+            len(seqlens) <= self.args.max_batch_size
+        ), f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
+        (num_toks,) = input_ids.shape
+        assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
+
+        input_metadata: Union[CacheInputMetadata, SimpleInputMetadata]
+
+        if cache is not None:
+            input_metadata = cache.get_input_metadata(seqlens)
+        else:
+            input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
+
+        if self.pipeline_rank == 0:
+            assert self.tok_embeddings is not None
+            h = self.tok_embeddings(input_ids)
+        else:
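+            # Non-first ranks receive their input activations from the previous pipeline stage.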
+            h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype)
+            torch.distributed.recv(h, src=self.pipeline_rank - 1)
+
+        freqs_cis = self.freqs_cis[input_metadata.positions]
+
+        for local_layer_id, layer in enumerate(self.layers.values()):
+            if cache is not None:
+                assert input_metadata is not None
+                assert isinstance(input_metadata, CacheInputMetadata)
+                cache_view = cache.get_view(local_layer_id, input_metadata)
+            else:
+                cache_view = None
+            h = layer(h, freqs_cis, cache_view)
+
+        if cache is not None:
+            cache.update_seqlens(seqlens)
+        if self.pipeline_rank < self.num_pipeline_ranks - 1:
+            torch.distributed.send(h, dst=self.pipeline_rank + 1)
+            return h  # type: ignore
+        else:
+            # Last rank has a final normalization step.
+            assert self.norm is not None
+            return self.norm(h)  # type: ignore
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        seqlens: List[int],
+        cache: Optional[BufferCache] = None,
+    ) -> torch.Tensor:
+        h = self.forward_partial(input_ids, seqlens, cache=cache)
+        if self.pipeline_rank < self.num_pipeline_ranks - 1:
+            # ignore the intermediate activations as we'll get the final output from
+            # the last stage
+            outs = torch.empty(h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype)
+        else:
+            assert self.output is not None
+            outs = self.output(h)
+        if self.num_pipeline_ranks > 1:
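+            # The last pipeline rank broadcasts the final logits so every rank returns them.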
+            torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
+        return outs.float()
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
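+        # Each pipeline rank loads only the parameters it owns: token embeddings on
+        # rank 0, norm/output on the last rank, and its local slice of layers.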
+        state_to_load = {}
+        skipped = set()
+        for k, v in state_dict.items():
+            if k.startswith("tok_embeddings"):
+                if self.pipeline_rank == 0:
+                    state_to_load[k] = v
+                else:
+                    logging.debug(
+                        "Skipping parameter %s at pipeline rank %d",
+                        k,
+                        self.pipeline_rank,
+                    )
+                    skipped.add(k)
+            elif k.startswith("norm") or k.startswith("output"):
+                if self.pipeline_rank == self.num_pipeline_ranks - 1:
+                    state_to_load[k] = v
+                else:
+                    logging.debug(
+                        "Skipping parameter %s at pipeline rank %d",
+                        k,
+                        self.pipeline_rank,
+                    )
+                    skipped.add(k)
+            elif k.startswith("layers"):
+                layer_id = k.split(".")[1]
+                if layer_id in self.layers:
+                    state_to_load[k] = v
+                else:
+                    logging.debug(
+                        "Skipping parameter %s at pipeline rank %d",
+                        k,
+                        self.pipeline_rank,
+                    )
+                    skipped.add(k)
+            else:
+                raise ValueError(f"Unexpected key {k}")
+        assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys()))
+        super().load_state_dict(state_to_load, strict=strict, assign=assign)
+
+    @staticmethod
+    def from_folder(
+        folder: Union[Path, str],
+        max_batch_size: int = 1,
+        num_pipeline_ranks: int = 1,
+        device: Union[torch.device, str] = "cuda",
+        dtype: Optional[torch.dtype] = None,
+    ) -> "Transformer":
+        with open(Path(folder) / "params.json", "r") as f:
+            model_args = TransformerArgs.from_dict(json.load(f))
+        model_args.max_batch_size = max_batch_size
+        if num_pipeline_ranks > 1:
+            pipeline_rank = torch.distributed.get_rank()
+        else:
+            pipeline_rank = 0
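+        # Build the model on the meta device so no memory is allocated until weights
+        # are assigned from the checkpoint below.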
+        with torch.device("meta"):
+            model = Transformer(
+                model_args,
+                pipeline_rank=pipeline_rank,
+                num_pipeline_ranks=num_pipeline_ranks,
+            )
+
+        pt_model_file = Path(folder) / "consolidated.00.pth"
+        safetensors_model_file = Path(folder) / "consolidated.safetensors"
+
+        assert (
+            pt_model_file.exists() or safetensors_model_file.exists()
+        ), f"Make sure either {pt_model_file} or {safetensors_model_file} exists"
+        assert not (
+            pt_model_file.exists() and safetensors_model_file.exists()
+        ), f"Both {pt_model_file} and {safetensors_model_file} cannot exist"
+
+        if pt_model_file.exists():
+            loaded = torch.load(str(pt_model_file), mmap=True)
+        else:
+            loaded = safetensors.torch.load_file(str(safetensors_model_file))
+
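+        # assign=True replaces the meta-device parameters with the loaded tensors
+        # rather than copying into preallocated storage.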
+        model.load_state_dict(loaded, assign=True, strict=True)
+
+        return model.to(device=device, dtype=dtype)
diff --git a/tests/test_generate.py b/tests/test_generate.py
index bf9f1174c6b190f50f9aec29b5b397cb23cb606e..00bfad4efd209d33a0dea0938236f245ab37a7ae 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -1,8 +1,10 @@
 from typing import List
 
 import torch
+from mistral_inference.generate import generate_mamba
 from mistral_inference.main import generate
-from mistral_inference.model import ModelArgs, Transformer
+from mistral_inference.mamba import Mamba, MambaArgs
+from mistral_inference.transformer import Transformer, TransformerArgs
 
 
 class DebugTokenizer:
@@ -29,11 +31,11 @@ class DebugTokenizer:
         return " ".join([str(x) for x in t])
 
 
-def test_generation():
+def test_generation_transformer():
     torch.manual_seed(42)
 
     sequences = ["1 2 3 4 5 6 7", "0 1 2", "12 13 14", "2 4 34"]
-    args = ModelArgs(
+    args = TransformerArgs(
         dim=512,
         n_layers=1,
         head_dim=128,
@@ -53,30 +55,51 @@ def test_generation():
     # concat generated and prompt
     encoded = [e + t for e, t in zip(encoded, toks)]
 
-    generated, all_logprobs_new = generate(
-        encoded, model, temperature=0.0, max_tokens=0
-    )
+    generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0)
 
     assert generated == []
 
     # Verify that logprobs are the same
     assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new)
     for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
-        assert all(
-            [abs(x - y) < 1e-5 for x, y in zip(lp_old, lp_new)]
-        ), f"\n{lp_old}\n{lp_new}"
+        assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
 
     print("All tests passed.")
 
 
-def test_chunks():
+def test_generation_mamba():
+    torch.manual_seed(42)
+
+    sequences = ["1 2 3 4 5 6 7"]
+    args = MambaArgs(
+        dim=512,
+        n_layers=1,
+        n_groups=1,
+        rms_norm=True,
+        residual_in_fp32=True,
+        fused_add_norm=True,
+        pad_vocab_size_multiple=1,
+        tie_embeddings=False,
+        vocab_size=32768,
+    )
+    model = Mamba(args).to("cuda", dtype=torch.float32)
+    tokenizer = DebugTokenizer()
+
+    encoded = [tokenizer.encode(s, bos=True) for s in sequences]
+    toks, all_logprobs_old = generate_mamba(encoded, model, temperature=0.0, max_tokens=7)
+
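+    # Greedy decoding from a seeded, randomly initialized model is deterministic,
+    # so the generated token ids act as a regression snapshot.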
+    assert len(toks[0]) == 7
+    assert toks == [[25574, 14821, 11843, 23698, 12735, 23522, 27542]]
+
+
+def test_chunks_transformer():
     torch.manual_seed(42)
 
     sequences = [
         " ".join([str(i) for i in range(7)]),
         " ".join([str(i) for i in range(9, 0, -1)]),
     ]
-    args = ModelArgs(
+    args = TransformerArgs(
         dim=512,
         n_layers=1,
         head_dim=128,
@@ -96,17 +119,8 @@ def test_chunks():
     # concat generated and prompt
     encoded = [e + t for e, t in zip(encoded, toks)]
 
-    generated, all_logprobs_new = generate(
-        encoded, model, temperature=0.0, max_tokens=0, chunk_size=5
-    )
+    generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0, chunk_size=5)
     assert len(generated) == 0
 
     for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
-        assert all(
-            [abs(x - y) < 1e-5 for x, y in zip(lp_old, lp_new)]
-        ), f"\n{lp_old}\n{lp_new}"
-
-
-if __name__ == "__main__":
-    test_generation()
-    test_chunks()
+        assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
diff --git a/tutorials/getting_started.ipynb b/tutorials/getting_started.ipynb
index a3e70b17b2ddcb77ad3ae5ff64f46218cdac227d..747b807cef2f3ffe7fae7b7fb94440d31126d448 100644
--- a/tutorials/getting_started.ipynb
+++ b/tutorials/getting_started.ipynb
@@ -82,7 +82,7 @@
    "source": [
     "import os \n",
     "\n",
-    "from mistral_inference.model import Transformer\n",
+    "from mistral_inference.transformer import Transformer\n",
     "from mistral_inference.generate import generate\n",
     "\n",
     "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",