fix
This commit is contained in:
parent
065cd1afcb
commit
e1b0262a9e
|
@ -20,7 +20,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
|
@ -422,18 +422,52 @@ class GemmaForCausalLM(LlamaForCausalLM):
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
"What is your favorite condiment?"
|
"What is your favorite condiment?"
|
||||||
```"""
|
```"""
|
||||||
return super().forward(
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
input_ids,
|
output_hidden_states = (
|
||||||
attention_mask,
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
position_ids,
|
)
|
||||||
past_key_values,
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
inputs_embeds,
|
|
||||||
labels,
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
use_cache,
|
outputs = self.model(
|
||||||
output_attentions,
|
input_ids=input_ids,
|
||||||
output_hidden_states,
|
attention_mask=attention_mask,
|
||||||
return_dict,
|
position_ids=position_ids,
|
||||||
cache_position,
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
logits = logits.float()
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -726,7 +726,6 @@ class GemmaPreTrainedModel(PreTrainedModel):
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "GemmaConfig"
|
_CONFIG_FOR_DOC = "GemmaConfig"
|
||||||
|
|
||||||
|
|
||||||
|
@ -1126,14 +1125,8 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if self.config.pretraining_tp > 1:
|
|
||||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
|
||||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
logits = torch.cat(logits, dim=-1)
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
logits = logits.float()
|
logits = logits.float()
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
|
|
Loading…
Reference in New Issue