From 7ea9bcd3dcc40af84f383bd64fbfeb520eecce11 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 28 May 2024 16:13:43 +0200 Subject: [PATCH] better merging strategy --- .../models/gemma/configuration_gemma.py | 100 +++++------------- src/transformers/models/gemma/diff_gemma.py | 89 +++++++--------- .../models/gemma/modeling_gemma.py | 9 +- utils/diff_model_converter.py | 16 ++- 4 files changed, 90 insertions(+), 124 deletions(-) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index bd6508cb52..540faa77a8 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -77,74 +77,6 @@ from transformers import PretrainedConfig class GemmaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma-7B. - e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 256000): - Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`GemmaModel`] - hidden_size (`int`, *optional*, defaults to 3072): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 24576): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 28): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 16): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - head_dim (`int`, *optional*, defaults to 256): - The attention head dimension. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The legacy activation function. It is overwritten by the `hidden_activation`. - hidden_activation (`str` or `function`, *optional*): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. - max_position_embeddings (`int`, *optional*, defaults to 8192): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - ```python - >>> from transformers import GemmaModel, GemmaConfig - >>> # Initializing a Gemma gemma-7b style configuration - >>> configuration = GemmaConfig() - >>> # Initializing a model from the gemma-7b style configuration - >>> model = GemmaModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "gemma" - keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, @@ -168,7 +100,6 @@ class GemmaConfig(PretrainedConfig): rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, - **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -176,16 +107,23 @@ class GemmaConfig(PretrainedConfig): self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads - self.head_dim = head_dim + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act - self.hidden_activation = hidden_activation self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp self.use_cache = use_cache self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() self.attention_bias = attention_bias self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias super().__init__( pad_token_id=pad_token_id, @@ -194,3 +132,23 @@ class GemmaConfig(PretrainedConfig): tie_word_embeddings=tie_word_embeddings, **kwargs, ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py index c0f6ac17fb..9402048452 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/diff_gemma.py @@ -21,7 +21,6 @@ from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from transformers import PretrainedConfig from transformers.models.llama.modeling_llama import ( @@ -32,7 +31,7 @@ from transformers.models.llama.modeling_llama import ( apply_rotary_pos_emb, repeat_kv, ) - +from transformers.models.llama.configuration_llama import LlamaConfig from ...activations import ACT2FN from ...cache_utils import Cache from ...modeling_outputs import CausalLMOutputWithPast @@ -162,6 +161,32 @@ class GemmaConfig(PretrainedConfig): **kwargs, ) +# Example where we only want to overwrite the defaults of an init? +class GemmaConfig(LlamaConfig): + def __init__( + self, + vocab_size=256000, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + hidden_activation=None, + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + ): + super().__init__(self) class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): @@ -382,7 +407,7 @@ class GemmaModel(LlamaModel): cache_position, ) - +# Example where we ony modify the docstring and call super class GemmaForCausalLM(LlamaForCausalLM): def forward( self, @@ -423,52 +448,18 @@ class GemmaForCausalLM(LlamaForCausalLM): >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "What is your favorite condiment?" ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + return super().forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + cache_position, ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index a4c6eba595..04380bd944 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -137,7 +137,6 @@ def _get_unpad_data(attention_mask): max_seqlen_in_batch, ) - class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -1180,8 +1179,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel): ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) logits = logits.float() + loss = None if labels is not None: # Shift so that tokens < n predict n diff --git a/utils/diff_model_converter.py b/utils/diff_model_converter.py index d5ffcd4dba..4cc943d93f 100644 --- a/utils/diff_model_converter.py +++ b/utils/diff_model_converter.py @@ -236,6 +236,7 @@ def find_classes_in_file(module, old_id="llama", new_id="gemma"): wrapper.visit(class_finder) return class_finder +DOCSTRING_NODE = m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString(value=m.MatchIfTrue(lambda value: re.search(r'\"\"\"[\s\S]*\"\"\"',value) is not None)))]) class SuperTransformer(cst.CSTTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider,) @@ -255,6 +256,9 @@ class SuperTransformer(cst.CSTTransformer): } for stmt in existing_body: if self.python_module.code_for_node(stmt).strip() not in existing_nodes: + if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring: + print("Oh docstring") + continue de_duplicated_new_body.append(stmt) existing_nodes.add(stmt) else: @@ -263,7 +267,11 @@ class SuperTransformer(cst.CSTTransformer): def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode: new_body = [] + self.has_docstring = False for expr in node.body: + if m.matches(node.body[0], DOCSTRING_NODE): + self.has_docstring = True + if m.matches( expr, m.SimpleStatementLine( @@ -295,7 +303,8 @@ class SuperTransformer(cst.CSTTransformer): if updated_node.name.value in self.updated_methods: name = updated_node.name.value new_body = self.replace_super_calls(updated_node.body, name) - return updated_node.with_changes(body=new_body) + # dont't change the current func's default params + return updated_node.with_changes(body=new_body, params=updated_node.params) return updated_node def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode: @@ -335,6 +344,9 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, | ``` """ original_node = class_finder.classes[class_name] + + # TODO here is where we merge stuff from super. We can choose to merge the docstring as well! + # We could also check the docstring here original_methods = {f.name.value: f for f in original_node.body.body if m.matches(f, m.FunctionDef())} # Copy methods from original node to replacement node, preserving decorators @@ -343,7 +355,7 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, for name, func in original_methods.items(): if name in updated_methods: # Replace the method in the replacement class, preserving decorators - func = func.with_changes(body=updated_methods[name].body) + func = func.with_changes(body=updated_methods[name].body, params = updated_methods[name].params ) end_meth.append(func) result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))