Move misplaced line (#29117)

Move misplaced line, improve code comment
This commit is contained in:
Erich Schubert 2024-02-20 02:24:48 +01:00 committed by GitHub
parent 9094abe8dc
commit a7ff2f23a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 2 additions and 2 deletions

View File

@@ -1176,11 +1176,11 @@ class MistralForCausalLM(MistralPreTrainedModel):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
# Ensure tensors are on the same device
shift_labels = shift_labels.to(shift_logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
if not return_dict: