diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
index 2894c07be18450616da222c3a610e8f3adfff8a2..9c9aebec3ad86ff3c50945652f218bd6c5229cdd 100644
--- a/src/mistral_inference/transformer.py
+++ b/src/mistral_inference/transformer.py
@@ -14,7 +14,7 @@ from mistral_inference.cache import BufferCache, CacheInputMetadata
 from mistral_inference.lora import LoRALoaderMixin
 from mistral_inference.model import ModelBase
 from mistral_inference.rope import precompute_freqs_cis
-from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
 from mistral_inference.vision_encoder import VisionLanguageAdapter, VisionTransformer


diff --git a/src/mistral_inference/transformer_utils.py b/src/mistral_inference/transformer_layers.py
similarity index 100%
rename from src/mistral_inference/transformer_utils.py
rename to src/mistral_inference/transformer_layers.py
diff --git a/src/mistral_inference/vision_encoder.py b/src/mistral_inference/vision_encoder.py
index fcca5305993f65c3f592594da4a8a82af2130a12..833cbb68f9ecd6800cab0c0d5c206c0067c6cee0 100644
--- a/src/mistral_inference/vision_encoder.py
+++ b/src/mistral_inference/vision_encoder.py
@@ -6,34 +6,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask

 from mistral_inference.args import VisionEncoderArgs
 from mistral_inference.rope import precompute_freqs_cis_2d
-from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
-
-
-class Transformer(nn.Module):
-    def __init__(self, args: VisionEncoderArgs):
-        super().__init__()
-        self.layers = torch.nn.ModuleList()
-        for _ in range(args.num_hidden_layers):
-            self.layers.append(
-                TransformerBlock(
-                    dim=args.hidden_size,
-                    hidden_dim=args.intermediate_size,
-                    n_heads=args.num_attention_heads,
-                    n_kv_heads=args.num_attention_heads,
-                    head_dim=args.hidden_size // args.num_attention_heads,
-                    norm_eps=1e-5,
-                )
-            )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        mask: BlockDiagonalMask,
-        freqs_cis: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        for layer in self.layers:
-            x = layer(x, mask=mask, freqs_cis=freqs_cis)
-        return x
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock


 def position_meshgrid(
@@ -67,7 +40,7 @@ class VisionTransformer(nn.Module):
             bias=False,
         )
         self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
-        self.transformer = Transformer(args)
+        self.transformer = VisionTransformerBlocks(args)

         head_dim = self.args.hidden_size // self.args.num_attention_heads
         assert head_dim % 2 == 0, "ROPE requires even head_dim"
@@ -142,3 +115,32 @@ class VisionLanguageAdapter(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))  # type: ignore[no-any-return]
+
+
+class VisionTransformerBlocks(nn.Module):
+    def __init__(self, args: VisionEncoderArgs):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+        for _ in range(args.num_hidden_layers):
+            self.layers.append(
+                TransformerBlock(
+                    dim=args.hidden_size,
+                    hidden_dim=args.intermediate_size,
+                    n_heads=args.num_attention_heads,
+                    n_kv_heads=args.num_attention_heads,
+                    head_dim=args.hidden_size // args.num_attention_heads,
+                    norm_eps=1e-5,
+                )
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: BlockDiagonalMask,
+        freqs_cis: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x, mask=mask, freqs_cis=freqs_cis)
+        return x
+
+
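For code outside this patch, the visible effect is the import path and the class name: the helpers previously exposed by mistral_inference.transformer_utils now live in mistral_inference.transformer_layers (a 100% rename, contents unchanged), and the vision encoder's block stack is called VisionTransformerBlocks instead of Transformer. A minimal sketch of a hypothetical caller, assuming a package build that includes this rename:

    # Hypothetical caller code, not part of this patch.
    # Old import path (pre-rename):
    #     from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
    # New import path after the rename; the exported classes are unchanged:
    from mistral_inference.transformer_layers import RMSNorm, TransformerBlock

    # Inside vision_encoder.py, the per-layer stack is now built as
    # VisionTransformerBlocks(args) rather than Transformer(args); the
    # constructor arguments and forward signature are identical.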