Arthur Zucker 2024-05-30 16:47:21 +02:00
parent 065cd1afcb
commit e1b0262a9e
3 changed files with 48 additions and 22 deletions

View File

@@ -20,7 +20,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from transformers import PretrainedConfig

View File

@@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn import CrossEntropyLoss
 from transformers import PretrainedConfig
 from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
@@ -422,18 +422,52 @@ class GemmaForCausalLM(LlamaForCausalLM):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "What is your favorite condiment?"
         ```"""
-        return super().forward(
-            input_ids,
-            attention_mask,
-            position_ids,
-            past_key_values,
-            inputs_embeds,
-            labels,
-            use_cache,
-            output_attentions,
-            output_hidden_states,
-            return_dict,
-            cache_position,
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
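For readers skimming the diff above: a minimal, self-contained sketch of the same shift-by-one cross-entropy pattern used in the new forward body. All sizes here (batch, sequence length, vocabulary) are invented for illustration and are not taken from the model.

```python
import torch
from torch.nn import CrossEntropyLoss

# Illustrative sizes only: 2 sequences of length 5 over a 16-token vocabulary.
batch, seq_len, vocab_size = 2, 5, 16
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Tokens at positions < n predict the token at position n: drop the last logit
# and the first label, then flatten so CrossEntropyLoss sees (N, vocab) vs (N,).
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())
```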

View File

@@ -726,7 +726,6 @@ class GemmaPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 _CONFIG_FOR_DOC = "GemmaConfig"
@@ -1126,14 +1125,8 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         )
         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         logits = logits.float()

         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
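As a sanity check on the hunk above, a small sketch (with invented dimensions) showing that the removed pretraining_tp path, which splits the lm_head weight and concatenates per-slice logits, matches the single projection that the diff keeps:

```python
import torch
import torch.nn.functional as F

hidden_size, vocab_size, pretraining_tp = 8, 32, 4  # illustrative sizes only
hidden_states = torch.randn(2, 5, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# Removed path: project against each weight slice, then concatenate the logits.
lm_head_slices = lm_head.weight.split(vocab_size // pretraining_tp, dim=0)
sliced = torch.cat([F.linear(hidden_states, w) for w in lm_head_slices], dim=-1)

# Retained path: a single projection through the full head.
logits = lm_head(hidden_states)
print(torch.allclose(sliced, logits, atol=1e-6))  # expected: True
```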