From b9524508e98db42d155b98fdf6494ebc70e3934e Mon Sep 17 00:00:00 2001
From: Patrick von Platen <patrick.v.platen@gmail.com>
Date: Wed, 16 Oct 2024 13:58:03 +0000
Subject: [PATCH] Add softmax_fp32 flag to make the fp32 cast of output logits optional

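Add a `softmax_fp32` flag to `Transformer.__init__` and `Transformer.from_folder`.
When it is set to False, `forward` returns the output logits in the model's
compute dtype (e.g. bf16) instead of upcasting them to float32, which avoids
materializing an extra fp32 copy of the logits at the cost of lower-precision
downstream softmax/sampling.

A minimal usage sketch against this patch (the model folder path below is a
placeholder):

    import torch

    from mistral_inference.transformer import Transformer

    # Keep the returned logits in bf16 instead of casting them to fp32.
    model = Transformer.from_folder(
        "/path/to/mistral-model",  # placeholder path
        device="cuda",
        dtype=torch.bfloat16,
        softmax_fp32=False,
    )
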
---
 src/mistral_inference/transformer.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
index a53195f..cb782dd 100644
--- a/src/mistral_inference/transformer.py
+++ b/src/mistral_inference/transformer.py
@@ -36,6 +36,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
         args: TransformerArgs,
         pipeline_rank: int = 0,
         num_pipeline_ranks: int = 1,
+        softmax_fp32: bool = True,
     ):
         super().__init__()
         self.args = args
@@ -46,6 +47,8 @@ class Transformer(ModelBase, LoRALoaderMixin):
         assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
         self.pipeline_rank = pipeline_rank
         self.num_pipeline_ranks = num_pipeline_ranks
+        self.softmax_fp32 = softmax_fp32
+
         # Modules specific to some ranks:
         self.tok_embeddings: Optional[nn.Embedding] = None
         self.norm: Optional[RMSNorm] = None
@@ -207,7 +210,11 @@ class Transformer(ModelBase, LoRALoaderMixin):
             outs = self.output(h)
         if self.num_pipeline_ranks > 1:
             torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
-        return outs.float()
+
+        if self.softmax_fp32:
+            return outs.float()
+        else:
+            return outs
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
         state_to_load = {}
@@ -259,6 +266,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
         num_pipeline_ranks: int = 1,
         device: Union[torch.device, str] = "cuda",
         dtype: Optional[torch.dtype] = None,
+        softmax_fp32: bool = True,
     ) -> "Transformer":
         with open(Path(folder) / "params.json", "r") as f:
             model_args = TransformerArgs.from_dict(json.load(f))
@@ -272,6 +280,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
                 model_args,
                 pipeline_rank=pipeline_rank,
                 num_pipeline_ranks=num_pipeline_ranks,
+                softmax_fp32=softmax_fp32,
             )
 
         pt_model_file = Path(folder) / "consolidated.00.pth"
-- 
GitLab