diff --git a/pyproject.toml b/pyproject.toml
index 5528e1790fc6f73a39060711302796e0fd0aa74c..c3895493f39e07b4665f83a4f86ccae481c6982d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistral_inference"
-version = "1.3.1"
+version = "1.4.0"
 description = ""
 authors = ["bam4d <bam4d@mistral.ai>"]
 readme = "README.md"
@@ -27,8 +27,9 @@ python = "^3.9.10"
 xformers = ">=0.0.24"
 simple-parsing = ">=0.1.5"
 fire = ">=0.6.0"
-mistral_common = "^1.3.0"
+mistral_common = ">=1.4.0"
 safetensors = ">=0.4.0"
+pillow = ">=10.3.0"
 
 [tool.poetry.group.dev.dependencies]
 types-protobuf = "4.24.0.20240129"
diff --git a/src/mistral_inference/__init__.py b/src/mistral_inference/__init__.py
index 9c73af26be70465839a5f43818dbab3f5c35571f..3e8d9f94621c6b29efab723e119a73a0dbe15089 100644
--- a/src/mistral_inference/__init__.py
+++ b/src/mistral_inference/__init__.py
@@ -1 +1 @@
-__version__ = "1.3.1"
+__version__ = "1.4.0"
diff --git a/src/mistral_inference/args.py b/src/mistral_inference/args.py
index dbf9f810dae61f6de0093e8b70f183514d6b4fd9..a94a2c605977cda5323230a924a83e53adbc51bd 100644
--- a/src/mistral_inference/args.py
+++ b/src/mistral_inference/args.py
@@ -7,6 +7,19 @@ from mistral_inference.lora import LoraArgs
 from mistral_inference.moe import MoeArgs
 
 
+@dataclass
+class VisionEncoderArgs:
+    hidden_size: int
+    num_channels: int
+    image_size: int
+    patch_size: int
+    intermediate_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    rope_theta: float = 1e4  # for rope-2D
+    image_token_id: int = 10
+
+
 @dataclass
 class TransformerArgs(Serializable):
     dim: int
@@ -28,7 +41,9 @@ class TransformerArgs(Serializable):
     lora: Optional[LoraArgs] = None
     model_type: str = "transformer"
 
-    def __post_init__(self):
+    vision_encoder: Optional[VisionEncoderArgs] = None
+
+    def __post_init__(self) -> None:
         assert self.model_type == "transformer", self.model_type
 
 
@@ -45,5 +60,5 @@ class MambaArgs(Serializable):
     tie_embeddings: bool
     model_type: str = "mamba"
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         assert self.model_type == "mamba", self.model_type
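
The VisionEncoderArgs added above plugs into TransformerArgs through the new vision_encoder field. A minimal construction sketch (the numeric values are illustrative, not taken from a released checkpoint):

    # Sketch only: values are illustrative, not a released checkpoint config.
    from mistral_inference.args import TransformerArgs, VisionEncoderArgs

    vision = VisionEncoderArgs(
        hidden_size=1024,        # width of the vision transformer
        num_channels=3,          # RGB input
        image_size=1024,         # maximum image side length in pixels
        patch_size=16,           # each 16x16 patch becomes one token
        intermediate_size=4096,  # FFN width inside the vision blocks
        num_hidden_layers=24,
        num_attention_heads=16,
        rope_theta=1e4,          # base for the 2D rotary embeddings
        image_token_id=10,       # id that marks image-patch positions in the prompt
    )

    args = TransformerArgs(
        dim=4096, n_layers=32, head_dim=128, hidden_dim=14336,
        n_heads=32, n_kv_heads=8, norm_eps=1e-5, vocab_size=131072,
        max_batch_size=1, vision_encoder=vision,
    )
    assert args.vision_encoder is not None  # Transformer builds the encoder + adapter on rank 0
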
diff --git a/src/mistral_inference/generate.py b/src/mistral_inference/generate.py
index 4d2dbe88e130453b2f7503f204bbc0d431198ccc..1e906b3c136f259bce0a6f7b8ec1704a86797bda 100644
--- a/src/mistral_inference/generate.py
+++ b/src/mistral_inference/generate.py
@@ -1,5 +1,6 @@
 from typing import List, Optional, Tuple
 
+import numpy as np
 import torch
 
 from mistral_inference.cache import BufferCache
@@ -43,12 +44,21 @@ def generate_mamba(
 def generate(
     encoded_prompts: List[List[int]],
     model: Transformer,
+    images: List[List[np.ndarray]] = [],
     *,
     max_tokens: int,
     temperature: float,
     chunk_size: Optional[int] = None,
     eos_id: Optional[int] = None,
 ) -> Tuple[List[List[int]], List[List[float]]]:
+    images_torch: List[List[torch.Tensor]] = []
+    if images:
+        assert chunk_size is None
+        images_torch = [
+            [torch.tensor(im, device=model.device, dtype=model.dtype) for im in images_for_sample]
+            for images_for_sample in images
+        ]
+
     model = model.eval()
     B, V = len(encoded_prompts), model.args.vocab_size
 
@@ -75,12 +85,15 @@ def generate(
     if chunk_size is None:
         chunk_size = max_prompt_len
 
+    flattened_images: List[torch.Tensor] = sum(images_torch, [])
+
     # Encode prompt by chunks
     for s in range(0, max_prompt_len, chunk_size):
         prompt_chunks = [p[s : s + chunk_size] for p in encoded_prompts]
         assert all(len(p) > 0 for p in prompt_chunks)
         prelogits = model.forward(
             torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
+            images=flattened_images,
             seqlens=[len(p) for p in prompt_chunks],
             cache=cache,
         )
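
The new images argument to generate is a per-prompt list of (C, H, W) arrays; prompts without images pass an empty list, and the per-prompt lists are flattened in prompt order before the single prefill pass. A standalone sketch of that bookkeeping (shapes and the float32 dtype are illustrative; generate() itself uses model.device / model.dtype):

    import numpy as np
    import torch

    images = [
        [np.zeros((3, 32, 32), dtype=np.float32)],  # prompt 0: one image
        [],                                         # prompt 1: text only
        [np.zeros((3, 16, 48), dtype=np.float32),   # prompt 2: two images of
         np.zeros((3, 48, 16), dtype=np.float32)],  #           different sizes
    ]

    images_torch = [
        [torch.tensor(im, dtype=torch.float32) for im in images_for_sample]
        for images_for_sample in images
    ]
    flattened = sum(images_torch, [])  # all images of prompt 0, then prompt 1, then prompt 2
    assert len(flattened) == 3 and flattened[1].shape == (3, 16, 48)
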
diff --git a/src/mistral_inference/main.py b/src/mistral_inference/main.py
index 74743e8b63ba451c2234b9bc242a08be7a6d7810..28b8f7a72a922be6dc5a983c80a72aaf0e7ec433 100644
--- a/src/mistral_inference/main.py
+++ b/src/mistral_inference/main.py
@@ -3,19 +3,31 @@ import logging
 import os
 import warnings
 from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Tuple, Type, Union
 
 import fire  # type: ignore
 import torch
 import torch.distributed as dist
-from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
+from mistral_common.protocol.instruct.messages import (
+    AssistantMessage,
+    ContentChunk,
+    ImageChunk,
+    ImageURLChunk,
+    TextChunk,
+    UserMessage,
+)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.base import Tokenizer
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
-from mistral_common.tokens.tokenizers.tekken import Tekkenizer, SpecialTokenPolicy
 from mistral_common.tokens.tokenizers.sentencepiece import is_sentencepiece
-from mistral_common.tokens.tokenizers.tekken import is_tekken
-
+from mistral_common.tokens.tokenizers.tekken import (
+    SpecialTokenPolicy,
+    Tekkenizer,
+    is_tekken,
+)
+from PIL import Image
+
+from mistral_inference.args import TransformerArgs
 from mistral_inference.generate import generate, generate_mamba
 from mistral_inference.mamba import Mamba
 from mistral_inference.transformer import Transformer
@@ -62,6 +74,31 @@ def pad_and_convert_to_tensor(list_of_lists: List[List[int]], pad_id: int) -> Li
     return padded_lists
 
 
+def _get_multimodal_input() -> Tuple[UserMessage, bool]:
+    chunks: List[ContentChunk] = []
+
+    response = input("Text prompt: ")
+    if response:
+        chunks.append(TextChunk(text=response))
+
+    print("[You can input zero, one or more images now.]")
+    while True:
+        did_something = False
+        response = input("Image path or url [Leave empty and press enter to finish image input]: ")
+        if response:
+            if Path(response).is_file():
+                chunks.append(ImageChunk(image=Image.open(response)))
+            else:
+                assert response.startswith("http"), f"{response} does not seem to be a valid url."
+                chunks.append(ImageURLChunk(image_url=response))
+            did_something = True
+
+        if not did_something:
+            break
+
+    return UserMessage(content=chunks), not chunks
+
+
 def interactive(
     model_path: str,
     max_tokens: int = 35,
@@ -85,6 +122,10 @@ def interactive(
 
     model_cls = get_model_cls(model_path)
     model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks)
+    is_multimodal = isinstance(model.args, TransformerArgs) and model.args.vision_encoder is not None
+
+    if is_multimodal:
+        assert instruct, "Multimodal models should only be used in instruct mode"
 
     # load LoRA
     if lora_path is not None:
@@ -95,17 +136,27 @@ def interactive(
 
     while True:
         if should_print:
-            user_input = input("Prompt: ")
+            if not is_multimodal:
+                user_input = input("Prompt: ")
 
             if instruct:
-                messages += [UserMessage(content=user_input)]
+                if is_multimodal:
+                    mm_input, finished = _get_multimodal_input()
+                    if finished:
+                        break
+                    messages += [mm_input]
+                else:
+                    messages += [UserMessage(content=user_input)]
                 chat_completion_request = ChatCompletionRequest(messages=messages)
 
-                tokens = mistral_tokenizer.encode_chat_completion(chat_completion_request).tokens
+                tokenized = mistral_tokenizer.encode_chat_completion(chat_completion_request)
+                tokens = tokenized.tokens
+                images = tokenized.images
             else:
                 prompt += user_input
 
                 tokens = tokenizer.encode(prompt, bos=True, eos=False)
+                images = []
 
             length_tensor = torch.tensor([len(tokens)], dtype=torch.int)
         else:
@@ -121,6 +172,7 @@ def interactive(
         generated_tokens, _ = generate_fn(  # type: ignore[operator]
             [tokens],
             model,
+            [images],
             max_tokens=max_tokens,
             temperature=temperature,
             eos_id=tokenizer.eos_id,
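
For reference, the request that the interactive loop assembles can also be built directly with mistral_common; a sketch assuming a multimodal tokenizer file is available (the tokenizer path and image URL are illustrative):

    from mistral_common.protocol.instruct.messages import ImageURLChunk, TextChunk, UserMessage
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    mistral_tokenizer = MistralTokenizer.from_file("/path/to/model_dir/tekken.json")

    request = ChatCompletionRequest(
        messages=[
            UserMessage(
                content=[
                    TextChunk(text="Describe this image."),
                    ImageURLChunk(image_url="https://example.com/cat.png"),
                ]
            )
        ]
    )
    tokenized = mistral_tokenizer.encode_chat_completion(request)
    tokens, images = tokenized.tokens, tokenized.images  # fed to generate() as above
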
diff --git a/src/mistral_inference/rope.py b/src/mistral_inference/rope.py
index 054724302062aa867b26b26ce817f37a9d83cae1..29749ff83669165698c6d08d201794d786e777cf 100644
--- a/src/mistral_inference/rope.py
+++ b/src/mistral_inference/rope.py
@@ -18,6 +18,34 @@ def apply_rotary_emb(
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = freqs_cis[:, None, :]
-    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def precompute_freqs_cis_2d(
+    dim: int,
+    height: int,
+    width: int,
+    theta: float,
+) -> torch.Tensor:
+    """
+    freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by
+        (height, width) position tuples
+    """
+    # (dim / 2) frequency bases
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+
+    h = torch.arange(height, device=freqs.device)
+    w = torch.arange(width, device=freqs.device)
+
+    freqs_h = torch.outer(h, freqs[::2]).float()
+    freqs_w = torch.outer(w, freqs[1::2]).float()
+    freqs_2d = torch.cat(
+        [
+            freqs_h[:, None, :].repeat(1, width, 1),
+            freqs_w[None, :, :].repeat(height, 1, 1),
+        ],
+        dim=-1,
+    )
+    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
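
A quick shape check of precompute_freqs_cis_2d and of how the vision encoder indexes it by (row, col) patch positions (dim and grid size are illustrative):

    import torch
    from mistral_inference.rope import precompute_freqs_cis_2d

    dim, height, width = 64, 8, 8  # head_dim and an 8x8 patch grid
    freqs = precompute_freqs_cis_2d(dim=dim, height=height, width=width, theta=1e4)
    assert freqs.shape == (height, width, dim // 2) and freqs.is_complex()

    # Each patch looks up its rotation by (row, col); patches in the same row share
    # the "height" half of the frequencies and differ in the "width" half.
    positions = torch.tensor([[0, 0], [0, 3], [5, 3]])
    per_token = freqs[positions[:, 0], positions[:, 1]]
    assert per_token.shape == (3, dim // 2)
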
diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
index 57546845bdd1e1847ccdc5b40ca19c7874b1e787..9c9aebec3ad86ff3c50945652f218bd6c5229cdd 100644
--- a/src/mistral_inference/transformer.py
+++ b/src/mistral_inference/transformer.py
@@ -2,25 +2,20 @@ import json
 import logging
 import math
 from dataclasses import dataclass
-from functools import partial
 from pathlib import Path
-from typing import Any, List, Mapping, Optional, Tuple, Type, Union
+from typing import Any, List, Mapping, Optional, Union
 
 import safetensors.torch
 import torch
 from torch import nn
-from xformers.ops.fmha import memory_efficient_attention  # type: ignore
 
 from mistral_inference.args import TransformerArgs
-from mistral_inference.cache import (
-    BufferCache,
-    CacheInputMetadata,
-    CacheView,
-)
-from mistral_inference.lora import LoRALinear, LoRALoaderMixin
+from mistral_inference.cache import BufferCache, CacheInputMetadata
+from mistral_inference.lora import LoRALoaderMixin
 from mistral_inference.model import ModelBase
-from mistral_inference.moe import MoeLayer
-from mistral_inference.rope import apply_rotary_emb, precompute_freqs_cis
+from mistral_inference.rope import precompute_freqs_cis
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
+from mistral_inference.vision_encoder import VisionLanguageAdapter, VisionTransformer
 
 
 @dataclass
@@ -35,131 +30,6 @@ class SimpleInputMetadata:
         )
 
 
-def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
-    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
-    return keys, values
-
-
-def maybe_lora(args: TransformerArgs) -> Union[Type[nn.Linear], partial[LoRALinear]]:
-    if args.lora is None:
-        return nn.Linear
-    else:
-        return partial(LoRALinear, rank=args.lora.rank, scaling=args.lora.scaling)
-
-
-class Attention(nn.Module):
-    def __init__(self, args: TransformerArgs):
-        super().__init__()
-        self.args = args
-
-        self.n_heads: int = args.n_heads
-        self.head_dim: int = args.head_dim
-        self.n_kv_heads: int = args.n_kv_heads
-
-        self.repeats = self.n_heads // self.n_kv_heads
-
-        self.scale = self.args.head_dim**-0.5
-
-        MaybeLora = maybe_lora(args)
-        self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False)
-        self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-        self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-        self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cis: torch.Tensor,
-        cache: Optional[CacheView],
-    ) -> torch.Tensor:
-        seqlen_sum, _ = x.shape
-
-        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
-        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
-        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
-        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-
-        if cache is None:
-            key, val = xk, xv
-        elif cache.prefill:
-            key, val = cache.interleave_kv(xk, xv)
-            cache.update(xk, xv)
-        else:
-            cache.update(xk, xv)
-            key, val = cache.key, cache.value
-            key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
-            val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
-
-        # Repeat keys and values to match number of query heads
-        key, val = repeat_kv(key, val, self.repeats, dim=1)
-
-        # xformers requires (B=1, S, H, D)
-        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
-        output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)
-        output = output.view(seqlen_sum, self.n_heads * self.head_dim)
-
-        assert isinstance(output, torch.Tensor)
-
-        return self.wo(output)  # type: ignore
-
-
-class FeedForward(nn.Module):
-    def __init__(self, args: TransformerArgs):
-        super().__init__()
-
-        MaybeLora = maybe_lora(args)
-        self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False)
-        self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False)
-        self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))  # type: ignore
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor) -> torch.Tensor:
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-
-class TransformerBlock(nn.Module):
-    def __init__(self, args: TransformerArgs):
-        super().__init__()
-        self.n_heads = args.n_heads
-        self.dim = args.dim
-        self.attention = Attention(args)
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.args = args
-
-        self.feed_forward: nn.Module
-        if args.moe is not None:
-            self.feed_forward = MoeLayer(
-                experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],
-                gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),
-                moe_args=args.moe,
-            )
-        else:
-            self.feed_forward = FeedForward(args=args)
-
-    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]) -> torch.Tensor:
-        r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
-        h = x + r
-        r = self.feed_forward.forward(self.ffn_norm(h))
-        out = h + r
-        return out
-
-
 class Transformer(ModelBase, LoRALoaderMixin):
     def __init__(
         self,
@@ -182,11 +52,29 @@ class Transformer(ModelBase, LoRALoaderMixin):
         self.output: Optional[nn.Linear] = None
         if pipeline_rank == 0:
             self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+
+            self.vision_encoder: Optional[VisionTransformer] = None
+            self.vision_language_adapter: Optional[VisionLanguageAdapter] = None
+            if args.vision_encoder is not None:
+                self.vision_encoder = VisionTransformer(args.vision_encoder)
+                self.vision_language_adapter = VisionLanguageAdapter(args.vision_encoder.hidden_size, args.dim)
         if pipeline_rank == num_pipeline_ranks - 1:
             self.norm = RMSNorm(args.dim, eps=args.norm_eps)
             self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
         # Initialize all layers but slice off those not of this rank.
-        layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
+        layers = [
+            TransformerBlock(
+                dim=args.dim,
+                hidden_dim=args.hidden_dim,
+                n_heads=args.n_heads,
+                n_kv_heads=args.n_kv_heads,
+                head_dim=args.head_dim,
+                norm_eps=args.norm_eps,
+                lora=args.lora,
+                moe=args.moe,
+            )
+            for _ in range(args.n_layers)
+        ]
         num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks)
         offset = self.pipeline_rank * num_layers_per_rank
         end = min(self.n_layers, offset + num_layers_per_rank)
@@ -215,11 +103,41 @@ class Transformer(ModelBase, LoRALoaderMixin):
             self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device)
         return self._precomputed_freqs_cis
 
+    def embed_vision_language_features(self, input_ids: torch.Tensor, images: List[torch.Tensor]) -> torch.Tensor:
+        assert self.tok_embeddings is not None
+        assert self.vision_encoder is not None
+        assert self.vision_language_adapter is not None
+        assert self.args.vision_encoder is not None
+
+        text_locations = input_ids != self.args.vision_encoder.image_token_id
+        image_locations = input_ids == self.args.vision_encoder.image_token_id
+        text_features = self.tok_embeddings(input_ids[text_locations])
+        image_features = self.vision_language_adapter(self.vision_encoder(images))
+
+        seq_len = input_ids.shape[0]
+        N_txt, D_txt = text_features.shape
+        N_img, D_img = image_features.shape
+
+        assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}"
+        assert (
+            seq_len == N_txt + N_img
+        ), f"seq_len {seq_len} should be equal to N_txt + N_img {(N_txt, N_img, image_locations.sum().item())}"
+
+        combined_features = torch.empty(
+            (seq_len, D_txt),
+            dtype=text_features.dtype,
+            device=text_features.device,
+        )
+        combined_features[text_locations, :] = text_features
+        combined_features[image_locations, :] = image_features
+        return combined_features
+
     def forward_partial(
         self,
         input_ids: torch.Tensor,
         seqlens: List[int],
         cache: Optional[BufferCache] = None,
+        images: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Local forward pass.
 
@@ -241,7 +159,10 @@ class Transformer(ModelBase, LoRALoaderMixin):
 
         if self.pipeline_rank == 0:
             assert self.tok_embeddings is not None
-            h = self.tok_embeddings(input_ids)
+            if self.vision_encoder is not None and images:
+                h = self.embed_vision_language_features(input_ids, images)
+            else:
+                h = self.tok_embeddings(input_ids)
         else:
             h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype)
             torch.distributed.recv(h, src=self.pipeline_rank - 1)
@@ -261,7 +182,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
             cache.update_seqlens(seqlens)
         if self.pipeline_rank < self.num_pipeline_ranks - 1:
             torch.distributed.send(h, dst=self.pipeline_rank + 1)
-            return h  # type: ignore
+            return h
         else:
             # Last rank has a final normalization step.
             assert self.norm is not None
@@ -272,8 +193,9 @@ class Transformer(ModelBase, LoRALoaderMixin):
         input_ids: torch.Tensor,
         seqlens: List[int],
         cache: Optional[BufferCache] = None,
+        images: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
-        h = self.forward_partial(input_ids, seqlens, cache=cache)
+        h = self.forward_partial(input_ids, seqlens, cache=cache, images=images)
         if self.pipeline_rank < self.num_pipeline_ranks - 1:
             # ignore the intermediate activations as we'll get the final output from
             # the last stage
@@ -320,6 +242,9 @@ class Transformer(ModelBase, LoRALoaderMixin):
                         self.pipeline_rank,
                     )
                     skipped.add(k)
+            elif k.startswith("vision_encoder") or k.startswith("vision_language_adapter"):
+                assert not self.pipeline_rank
+                state_to_load[k] = v
             else:
                 raise ValueError(f"Unexpected key {k}")
         assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys()))
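
embed_vision_language_features above scatters text embeddings and projected image-patch features back into a single sequence using the image_token_id mask. A standalone illustration with dummy features (ids, dimension and random values are made up):

    import torch

    IMAGE_TOKEN_ID = 10
    input_ids = torch.tensor([1, 50, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 51, 2])
    text_locations = input_ids != IMAGE_TOKEN_ID
    image_locations = input_ids == IMAGE_TOKEN_ID

    D = 8
    text_features = torch.randn(int(text_locations.sum()), D)    # stand-in for tok_embeddings(...)
    image_features = torch.randn(int(image_locations.sum()), D)  # stand-in for adapter(vision_encoder(...))

    combined = torch.empty(input_ids.shape[0], D)
    combined[text_locations] = text_features
    combined[image_locations] = image_features
    assert combined.shape == (6, D)
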
diff --git a/src/mistral_inference/transformer_layers.py b/src/mistral_inference/transformer_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ee23f5669f5a2566703bb190344981cb5745aca
--- /dev/null
+++ b/src/mistral_inference/transformer_layers.py
@@ -0,0 +1,169 @@
+from functools import partial
+from typing import Optional, Tuple, Type, Union
+
+import torch
+from torch import nn
+from xformers.ops.fmha import memory_efficient_attention  # type: ignore
+from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+
+from mistral_inference.args import LoraArgs
+from mistral_inference.cache import CacheView
+from mistral_inference.lora import LoRALinear
+from mistral_inference.moe import MoeArgs, MoeLayer
+from mistral_inference.rope import apply_rotary_emb
+
+
+def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
+    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
+    return keys, values
+
+
+def maybe_lora(
+    lora_args: Optional[LoraArgs],
+) -> Union[Type[nn.Linear], partial[LoRALinear]]:
+    if lora_args is None:
+        return nn.Linear
+    else:
+        return partial(LoRALinear, rank=lora_args.rank, scaling=lora_args.scaling)
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        n_heads: int,
+        head_dim: int,
+        n_kv_heads: int,
+        lora: Optional[LoraArgs] = None,
+    ):
+        super().__init__()
+
+        self.n_heads: int = n_heads
+        self.head_dim: int = head_dim
+        self.n_kv_heads: int = n_kv_heads
+
+        self.repeats = self.n_heads // self.n_kv_heads
+
+        self.scale = self.head_dim**-0.5
+
+        MaybeLora = maybe_lora(lora)
+        self.wq = MaybeLora(dim, n_heads * head_dim, bias=False)
+        self.wk = MaybeLora(dim, n_kv_heads * head_dim, bias=False)
+        self.wv = MaybeLora(dim, n_kv_heads * head_dim, bias=False)
+        self.wo = MaybeLora(n_heads * head_dim, dim, bias=False)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        cache: Optional[CacheView] = None,
+        mask: Optional[BlockDiagonalMask] = None,
+    ) -> torch.Tensor:
+        assert mask is None or cache is None
+        seqlen_sum, _ = x.shape
+
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
+        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
+        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        if cache is None:
+            key, val = xk, xv
+        elif cache.prefill:
+            key, val = cache.interleave_kv(xk, xv)
+            cache.update(xk, xv)
+        else:
+            cache.update(xk, xv)
+            key, val = cache.key, cache.value
+            key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
+            val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
+
+        # Repeat keys and values to match number of query heads
+        key, val = repeat_kv(key, val, self.repeats, dim=1)
+
+        # xformers requires (B=1, S, H, D)
+        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
+        output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
+        output = output.view(seqlen_sum, self.n_heads * self.head_dim)
+
+        assert isinstance(output, torch.Tensor)
+
+        return self.wo(output)  # type: ignore
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, lora: Optional[LoraArgs] = None):
+        super().__init__()
+
+        MaybeLora = maybe_lora(lora)
+        self.w1 = MaybeLora(dim, hidden_dim, bias=False)
+        self.w2 = MaybeLora(hidden_dim, dim, bias=False)
+        self.w3 = MaybeLora(dim, hidden_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))  # type: ignore
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+class TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        n_heads: int,
+        n_kv_heads: int,
+        head_dim: int,
+        norm_eps: float,
+        lora: Optional[LoraArgs] = None,
+        moe: Optional[MoeArgs] = None,
+    ):
+        super().__init__()
+        self.n_heads = n_heads
+        self.dim = dim
+        self.attention = Attention(
+            dim=dim,
+            n_heads=n_heads,
+            head_dim=head_dim,
+            n_kv_heads=n_kv_heads,
+            lora=lora,
+        )
+        self.attention_norm = RMSNorm(dim, eps=norm_eps)
+        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
+
+        self.feed_forward: nn.Module
+        if moe is not None:
+            self.feed_forward = MoeLayer(
+                experts=[FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora) for _ in range(moe.num_experts)],
+                gate=nn.Linear(dim, moe.num_experts, bias=False),
+                moe_args=moe,
+            )
+        else:
+            self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        cache: Optional[CacheView] = None,
+        mask: Optional[BlockDiagonalMask] = None,
+    ) -> torch.Tensor:
+        r = self.attention.forward(self.attention_norm(x), freqs_cis, cache=cache, mask=mask)
+        h = x + r
+        r = self.feed_forward.forward(self.ffn_norm(h))
+        out = h + r
+        return out
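
TransformerBlock is now constructed from plain dimensions so the vision encoder can reuse it with a BlockDiagonalMask instead of a KV cache. A quick check of the grouped-query helper it relies on (head counts are illustrative):

    import torch
    from mistral_inference.transformer_layers import repeat_kv

    seqlen, n_kv_heads, head_dim, repeats = 5, 2, 16, 4  # repeats = n_heads // n_kv_heads
    keys = torch.randn(seqlen, n_kv_heads, head_dim)
    values = torch.randn(seqlen, n_kv_heads, head_dim)

    k, v = repeat_kv(keys, values, repeats=repeats, dim=1)
    assert k.shape == v.shape == (seqlen, n_kv_heads * repeats, head_dim)
    assert torch.equal(k[:, 0], k[:, 1])  # copies of KV head 0 sit next to each other
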
diff --git a/src/mistral_inference/vision_encoder.py b/src/mistral_inference/vision_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..833cbb68f9ecd6800cab0c0d5c206c0067c6cee0
--- /dev/null
+++ b/src/mistral_inference/vision_encoder.py
@@ -0,0 +1,144 @@
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+
+from mistral_inference.args import VisionEncoderArgs
+from mistral_inference.rope import precompute_freqs_cis_2d
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
+
+
+def position_meshgrid(
+    patch_embeds_list: List[torch.Tensor],
+) -> torch.Tensor:
+    positions = torch.cat(
+        [
+            torch.stack(
+                torch.meshgrid(
+                    torch.arange(p.shape[-2]),
+                    torch.arange(p.shape[-1]),
+                    indexing="ij",
+                ),
+                dim=-1,
+            ).reshape(-1, 2)
+            for p in patch_embeds_list
+        ]
+    )
+    return positions
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, args: VisionEncoderArgs):
+        super().__init__()
+        self.args = args
+        self.patch_conv = nn.Conv2d(
+            in_channels=args.num_channels,
+            out_channels=args.hidden_size,
+            kernel_size=args.patch_size,
+            stride=args.patch_size,
+            bias=False,
+        )
+        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
+        self.transformer = VisionTransformerBlocks(args)
+
+        head_dim = self.args.hidden_size // self.args.num_attention_heads
+        assert head_dim % 2 == 0, "ROPE requires even head_dim"
+        self._freqs_cis: Optional[torch.Tensor] = None
+
+    @property
+    def max_patches_per_side(self) -> int:
+        return self.args.image_size // self.args.patch_size
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    @property
+    def freqs_cis(self) -> torch.Tensor:
+        if self._freqs_cis is None:
+            self._freqs_cis = precompute_freqs_cis_2d(
+                dim=self.args.hidden_size // self.args.num_attention_heads,
+                height=self.max_patches_per_side,
+                width=self.max_patches_per_side,
+                theta=self.args.rope_theta,
+            )
+
+        if self._freqs_cis.device != self.device:
+            self._freqs_cis = self._freqs_cis.to(device=self.device)
+
+        return self._freqs_cis
+
+    def forward(
+        self,
+        images: List[torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Args:
+            images: list of N_img images of variable sizes, each of shape (C, H, W)
+
+        Returns:
+            image_features: tensor of token features for all tokens of all images of
+                shape (N_toks, D)
+        """
+        # pass images through initial convolution independently
+        patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
+
+        # flatten to a single sequence
+        patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
+        patch_embeds = self.ln_pre(patch_embeds)
+
+        # positional embeddings
+        positions = position_meshgrid(patch_embeds_list).to(self.device)
+        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
+
+        # pass through Transformer with a block diagonal mask delimiting images
+        mask = BlockDiagonalMask.from_seqlens(
+            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
+        )
+        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
+
+        # out already has shape (N_toks, hidden_size): token features for all patches of all images
+        return out  # type: ignore[no-any-return]
+
+
+class VisionLanguageAdapter(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int):
+        super().__init__()
+        self.w_in = nn.Linear(
+            in_dim,
+            out_dim,
+            bias=True,
+        )
+        self.gelu = nn.GELU()
+        self.w_out = nn.Linear(out_dim, out_dim, bias=True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w_out(self.gelu(self.w_in(x)))  # type: ignore[no-any-return]
+
+
+class VisionTransformerBlocks(nn.Module):
+    def __init__(self, args: VisionEncoderArgs):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+        for _ in range(args.num_hidden_layers):
+            self.layers.append(
+                TransformerBlock(
+                    dim=args.hidden_size,
+                    hidden_dim=args.intermediate_size,
+                    n_heads=args.num_attention_heads,
+                    n_kv_heads=args.num_attention_heads,
+                    head_dim=args.hidden_size // args.num_attention_heads,
+                    norm_eps=1e-5,
+                )
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: BlockDiagonalMask,
+        freqs_cis: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x, mask=mask, freqs_cis=freqs_cis)
+        return x
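
Patch bookkeeping for the encoder above: a (C, H, W) image becomes (H // patch_size) * (W // patch_size) tokens after patch_conv, and position_meshgrid recovers each token's (row, col) for the 2D rope lookup. A small check with illustrative sizes:

    import torch
    from mistral_inference.vision_encoder import position_meshgrid

    hidden = 1024
    # post-patch_conv feature maps for two images of different sizes: (hidden, H/p, W/p)
    patch_embeds_list = [torch.randn(hidden, 4, 6), torch.randn(hidden, 2, 2)]

    positions = position_meshgrid(patch_embeds_list)
    assert positions.shape == (4 * 6 + 2 * 2, 2)       # one (row, col) per patch token
    assert positions[:2].tolist() == [[0, 0], [0, 1]]  # row-major within each image
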
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 00bfad4efd209d33a0dea0938236f245ab37a7ae..e380d3e5cd22497a68afffb8ed4c98818799c62c 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -1,6 +1,8 @@
 from typing import List
 
+import numpy as np
 import torch
+from mistral_inference.args import VisionEncoderArgs
 from mistral_inference.generate import generate_mamba
 from mistral_inference.main import generate
 from mistral_inference.mamba import Mamba, MambaArgs
@@ -55,14 +57,73 @@ def test_generation_transformer():
     # concat generated and prompt
     encoded = [e + t for e, t in zip(encoded, toks)]
 
-    generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0)
+    generated, all_logprobs_new = generate(
+        encoded, model, temperature=0.0, max_tokens=0
+    )
+
+    assert generated == []
+
+    # Verify that logprobs are the same
+    assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new)
+    for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
+        assert all(
+            [abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]
+        ), f"\n{lp_old}\n{lp_new}"
+
+    print("All tests passed.")
+
+
+def test_generation_pixtral():
+    torch.manual_seed(42)
+    gen = np.random.default_rng(seed=42)
+
+    sequences = ["1 2 2 2 2 4 5 6 7", "12 13 14", "2 2 2 2 7 8 9"]
+    images = [[gen.normal(size=(3, 4, 4))], [], [gen.normal(size=(3, 4, 4))]]
+    args = TransformerArgs(
+        dim=512,
+        n_layers=1,
+        head_dim=128,
+        hidden_dim=2048,
+        n_heads=4,
+        n_kv_heads=2,
+        norm_eps=1e-5,
+        vocab_size=32_000,
+        max_batch_size=len(sequences),
+        vision_encoder=VisionEncoderArgs(
+            hidden_size=128,
+            num_channels=3,
+            image_size=4,
+            patch_size=2,
+            intermediate_size=256,
+            num_hidden_layers=1,
+            num_attention_heads=2,
+            rope_theta=10000,
+            image_token_id=2,
+        ),
+    )
+    model = Transformer(args).to("cuda", dtype=torch.float32)
+    tokenizer = DebugTokenizer()
+
+    encoded = [tokenizer.encode(s, bos=True) for s in sequences]
+    toks, all_logprobs_old = generate(
+        encoded, model, images=images, temperature=0.0, max_tokens=7
+    )
+
+    # concat generated and prompt
+    encoded = [e + t for e, t in zip(encoded, toks)]
+
+    generated, all_logprobs_new = generate(
+        encoded, model, images=images, temperature=0.0, max_tokens=0
+    )
 
     assert generated == []
 
     # Verify that logprobs are the same
     assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new)
     for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
-        assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
+        assert all(
+            [abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]
+        ), f"\n{lp_old}\n{lp_new}"
 
     print("All tests passed.")
 
@@ -86,7 +147,9 @@ def test_generation_mamba():
     tokenizer = DebugTokenizer()
 
     encoded = [tokenizer.encode(s, bos=True) for s in sequences]
-    toks, all_logprobs_old = generate_mamba(encoded, model, temperature=0.0, max_tokens=7)
+    toks, all_logprobs_old = generate_mamba(
+        encoded, model, temperature=0.0, max_tokens=7
+    )
 
     assert len(toks[0]) == 7
     assert toks == [[25574, 14821, 11843, 23698, 12735, 23522, 27542]]
@@ -119,8 +182,12 @@ def test_chunks_transformer():
     # concat generated and prompt
     encoded = [e + t for e, t in zip(encoded, toks)]
 
-    generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0, chunk_size=5)
+    generated, all_logprobs_new = generate(
+        encoded, model, temperature=0.0, max_tokens=0, chunk_size=5
+    )
     assert len(generated) == 0
 
     for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
-        assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
+        assert all(
+            [abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]
+        ), f"\n{lp_old}\n{lp_new}"