diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
index 2894c07be18450616da222c3a610e8f3adfff8a2..9c9aebec3ad86ff3c50945652f218bd6c5229cdd 100644
--- a/src/mistral_inference/transformer.py
+++ b/src/mistral_inference/transformer.py
@@ -14,7 +14,7 @@ from mistral_inference.cache import BufferCache, CacheInputMetadata
 from mistral_inference.lora import LoRALoaderMixin
 from mistral_inference.model import ModelBase
 from mistral_inference.rope import precompute_freqs_cis
-from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
 from mistral_inference.vision_encoder import VisionLanguageAdapter, VisionTransformer
 
 
diff --git a/src/mistral_inference/transformer_utils.py b/src/mistral_inference/transformer_layers.py
similarity index 100%
rename from src/mistral_inference/transformer_utils.py
rename to src/mistral_inference/transformer_layers.py
diff --git a/src/mistral_inference/vision_encoder.py b/src/mistral_inference/vision_encoder.py
index fcca5305993f65c3f592594da4a8a82af2130a12..833cbb68f9ecd6800cab0c0d5c206c0067c6cee0 100644
--- a/src/mistral_inference/vision_encoder.py
+++ b/src/mistral_inference/vision_encoder.py
@@ -6,34 +6,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
 from mistral_inference.args import VisionEncoderArgs
 from mistral_inference.rope import precompute_freqs_cis_2d
-from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
-
-
-class Transformer(nn.Module):
-    def __init__(self, args: VisionEncoderArgs):
-        super().__init__()
-        self.layers = torch.nn.ModuleList()
-        for _ in range(args.num_hidden_layers):
-            self.layers.append(
-                TransformerBlock(
-                    dim=args.hidden_size,
-                    hidden_dim=args.intermediate_size,
-                    n_heads=args.num_attention_heads,
-                    n_kv_heads=args.num_attention_heads,
-                    head_dim=args.hidden_size // args.num_attention_heads,
-                    norm_eps=1e-5,
-                )
-            )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        mask: BlockDiagonalMask,
-        freqs_cis: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        for layer in self.layers:
-            x = layer(x, mask=mask, freqs_cis=freqs_cis)
-        return x
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
 
 
 def position_meshgrid(
@@ -67,7 +40,7 @@ class VisionTransformer(nn.Module):
             bias=False,
         )
         self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
-        self.transformer = Transformer(args)
+        self.transformer = VisionTransformerBlocks(args)
 
         head_dim = self.args.hidden_size // self.args.num_attention_heads
         assert head_dim % 2 == 0, "ROPE requires even head_dim"
@@ -142,3 +115,32 @@ class VisionLanguageAdapter(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))  # type: ignore[no-any-return]
+
+
+class VisionTransformerBlocks(nn.Module):
+    def __init__(self, args: VisionEncoderArgs):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+        for _ in range(args.num_hidden_layers):
+            self.layers.append(
+                TransformerBlock(
+                    dim=args.hidden_size,
+                    hidden_dim=args.intermediate_size,
+                    n_heads=args.num_attention_heads,
+                    n_kv_heads=args.num_attention_heads,
+                    head_dim=args.hidden_size // args.num_attention_heads,
+                    norm_eps=1e-5,
+                )
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: BlockDiagonalMask,
+        freqs_cis: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x, mask=mask, freqs_cis=freqs_cis)
+        return x
+
+
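Note for downstream code: the only externally visible changes are the module path (the file is a pure rename, similarity index 100%) and the class name inside vision_encoder.py. A minimal migration sketch, assuming no compatibility shim is left at the old transformer_utils path and that VisionTransformerBlocks stays an internal detail of the vision encoder:

    # Old import path (removed by this change):
    #   from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
    # New import path:
    from mistral_inference.transformer_layers import RMSNorm, TransformerBlock

    # The vision encoder's block stack was renamed Transformer -> VisionTransformerBlocks.
    # External code should not normally need it, but if it did (hypothetical usage):
    from mistral_inference.vision_encoder import VisionTransformerBlocks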