# Patch for src/mistral_inference/generate.py; one hunk inside `generate(`.
#
# What changes: the statement that updates `is_finished` after sampling
# `next_token` combines the old value with `(next_token == eos_id).cpu()`
# using `|` (OR) instead of `^` (XOR). The update only runs when `eos_id`
# is not None.
#
# Why: with XOR, an entry of `is_finished` that is already True is flipped
# back to False if the comparison is True again on a later step; with OR an
# entry stays True once it has been set. The following check
# `if is_finished.all(): break` therefore sees a flag that can only go from
# False to True.
#
# NOTE(review): assumes `is_finished` holds boolean flags, one per generated
# sequence -- its definition is outside this hunk; confirm in the full file.
diff --git a/src/mistral_inference/generate.py b/src/mistral_inference/generate.py index c9e35c5dbf6a1c0e14f00692880145f62c2f7744..4d2dbe88e130453b2f7503f204bbc0d431198ccc 100644 --- a/src/mistral_inference/generate.py +++ b/src/mistral_inference/generate.py @@ -112,7 +112,7 @@ def generate( next_token = sample(last_token_prelogits, temperature=temperature, top_p=0.8) if eos_id is not None: - is_finished = is_finished ^ (next_token == eos_id).cpu() + is_finished = is_finished | (next_token == eos_id).cpu() if is_finished.all(): break