update

commit f0068b702c
parent df19157424
@@ -15,6 +15,7 @@
 # limitations under the License.


+
 from transformers import PretrainedConfig


@@ -28,6 +28,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
+    LlamaForTokenClassification,
     apply_rotary_pos_emb,
     repeat_kv,
 )
@@ -406,3 +407,6 @@ class GemmaForCausalLM(LlamaForCausalLM):

 class GemmaForSequenceClassification(LlamaForSequenceClassification):
     pass
+
+class GemmaForTokenClassification(LlamaForTokenClassification):
+    pass
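In the modular ("diff") file, a bare `pass` subclass is all that is required; the converter materializes the full `GemmaForTokenClassification` in the generated modeling file by renaming the Llama implementation. The head itself amounts to dropout plus a linear layer over per-token hidden states. A minimal sketch of that computation, with illustrative sizes (a simplification, not the generated code):

import torch
from torch import nn


class TokenClassificationHeadSketch(nn.Module):
    """Simplified sketch: per-token logits from decoder hidden states."""

    def __init__(self, hidden_size: int, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.score = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size), e.g. from GemmaModel
        return self.score(self.dropout(hidden_states))  # (batch, seq_len, num_labels)


head = TokenClassificationHeadSketch(hidden_size=2048, num_labels=5)
logits = head(torch.randn(1, 7, 2048))
assert logits.shape == (1, 7, 5)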
@@ -30,6 +30,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
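`TokenClassifierOutput` is the `ModelOutput` container the new head returns; its fields are roughly as follows (a simplified sketch, not the library's definition, which also inherits dict-like `ModelOutput` behavior):

from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class TokenClassifierOutputSketch:
    """Approximate fields of transformers' TokenClassifierOutput."""

    loss: Optional[torch.FloatTensor] = None    # set when labels are passed
    logits: Optional[torch.FloatTensor] = None  # (batch, seq_len, num_labels)
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None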
@@ -275,6 +276,7 @@ class GemmaAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@@ -286,8 +288,8 @@ class GemmaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
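For reference, `apply_rotary_pos_emb` in the Llama family rotates half of each head's dimensions by position-dependent angles; the trailing `None` in the new call site is the optional (deprecated) `position_ids` argument. A minimal sketch consistent with the call sites above, simplified from the library's implementation:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last dimension (x1, x2) -> (-x2, x1)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # cos/sin: (batch, seq_len, head_dim); unsqueeze to broadcast over heads
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


q = torch.randn(1, 8, 4, 64)    # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 4, 64)
angles = torch.randn(1, 4, 64)  # (batch, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, angles.cos(), angles.sin(), None)
assert q_rot.shape == q.shape and k_rot.shape == k.shape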
@@ -349,7 +351,6 @@ class GemmaFlashAttention2(GemmaAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
@@ -812,6 +813,7 @@ _CONFIG_FOR_DOC = "GemmaConfig"
     GEMMA_START_DOCSTRING,
 )
 class GemmaModel(GemmaPreTrainedModel):
+
     def __init__(self, config: GemmaConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
@@ -1031,6 +1033,7 @@ class GemmaModel(GemmaPreTrainedModel):


 class GemmaForCausalLM(GemmaPreTrainedModel):
+
     def __init__(self, config):
         super().__init__(config)
         self.model = GemmaModel(config)
@@ -1365,7 +1368,6 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
     """,
     GEMMA_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
 class GemmaForTokenClassification(GemmaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
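End to end, the new head is used like any other transformers token classifier. A short usage sketch, assuming the base google/gemma-2b checkpoint and an illustrative num_labels=5 (the classification head stays randomly initialized until fine-tuned):

import torch
from transformers import AutoTokenizer, GemmaForTokenClassification

# num_labels=5 is an illustrative assumption; the head is untrained here.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = GemmaForTokenClassification.from_pretrained("google/gemma-2b", num_labels=5)

inputs = tokenizer("Gemma is a family of open models", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch, seq_len, num_labels)
pred_ids = logits.argmax(dim=-1)     # one label id per token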