parent
9094abe8dc
commit
a7ff2f23a0
|
@ -1176,11 +1176,11 @@ class MistralForCausalLM(MistralPreTrainedModel):
|
|||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
# Ensure tensors are on the same device
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
|
|
Loading…
Reference in New Issue