Commit b9524508 authored by Patrick von Platen

WIP

parent 5d4cc685
@@ -36,6 +36,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
         args: TransformerArgs,
         pipeline_rank: int = 0,
         num_pipeline_ranks: int = 1,
+        softmax_fp32: bool = True,
     ):
         super().__init__()
         self.args = args
@@ -46,6 +47,8 @@ class Transformer(ModelBase, LoRALoaderMixin):
         assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
         self.pipeline_rank = pipeline_rank
         self.num_pipeline_ranks = num_pipeline_ranks
+        self.softmax_fp32 = softmax_fp32
+
         # Modules specific to some ranks:
         self.tok_embeddings: Optional[nn.Embedding] = None
         self.norm: Optional[RMSNorm] = None
@@ -207,7 +210,11 @@ class Transformer(ModelBase, LoRALoaderMixin):
         outs = self.output(h)
         if self.num_pipeline_ranks > 1:
             torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
-        return outs.float()
+
+        if self.softmax_fp32:
+            return outs.float()
+        else:
+            return outs

     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
         state_to_load = {}
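
(Annotation, not part of the commit.) The hunk above makes the fp32 upcast of the output logits optional: with softmax_fp32=True, the default and the previous hard-coded behaviour, forward still returns outs.float(); with False it returns the logits in the model's own dtype. A minimal sketch of why the upcast matters when the model runs in bf16, using only plain PyTorch (shapes are illustrative):

import torch

logits = torch.randn(4, 32000, dtype=torch.bfloat16)  # vocab-sized logits

probs_bf16 = torch.softmax(logits, dim=-1)           # softmax in bf16 (softmax_fp32=False path)
probs_fp32 = torch.softmax(logits.float(), dim=-1)   # upcast first (softmax_fp32=True path)

# bf16 keeps ~8 mantissa bits, so small probabilities are quantized coarsely;
# the max absolute gap shows the precision lost by staying in bf16.
print((probs_bf16.float() - probs_fp32).abs().max())

Skipping the upcast avoids materializing a vocab-sized fp32 copy of the logits, which is presumably why the flag is worth exposing.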
@@ -259,6 +266,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
         num_pipeline_ranks: int = 1,
         device: Union[torch.device, str] = "cuda",
         dtype: Optional[torch.dtype] = None,
+        softmax_fp32: bool = True,
     ) -> "Transformer":
         with open(Path(folder) / "params.json", "r") as f:
             model_args = TransformerArgs.from_dict(json.load(f))
@@ -272,6 +280,7 @@ class Transformer(ModelBase, LoRALoaderMixin):
             model_args,
             pipeline_rank=pipeline_rank,
             num_pipeline_ranks=num_pipeline_ranks,
+            softmax_fp32=softmax_fp32,
         )
         pt_model_file = Path(folder) / "consolidated.00.pth"
...
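
(Annotation, not part of the commit.) A hypothetical usage sketch of the new flag; the import path and model folder are assumptions, not taken from this diff:

import torch
from mistral_inference.transformer import Transformer  # import path is an assumption

model = Transformer.from_folder(
    "/path/to/model",        # folder holding params.json and consolidated.00.pth
    dtype=torch.bfloat16,
    softmax_fp32=False,      # new in this commit; True keeps the old always-fp32 behaviour
)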