Add recurrent gemma (#30143)

* Fork.

* RecurrentGemma initial commit.

* Updating __init__.py.

* Minor modification to how we initialize the cache.
Changing how the config specifies the architecture.

* Reformat code to 4 spaces.
Fixed a few typos.

* Fixed the forward pass.
Still unclear on the cache?

* Fixed the RecurrentGemmaForCausalLM

* Minor comment that we might not need attention_mask and output_attention arguments.

* Now cache should work as well.

* Adding a temporary example to check whether the model generation works.

* Adding the tests and updating imports.

* Adding the example file missing in the previous commit.

* First working example.

* Removing .gitignore and reverting parts of __init__.

* Re-add .gitignore.

* Addressing comments for configuration.

* Move mask creation to `_prepare_inputs_for_generation`.

* First try at integration tests:
1. AttributeError: 'GriffinCausalLMOutput' object has no attribute 'attentions'.
2. `cache_position` not passed

* Transferring between machines.

* Running normal tests.

* Minor fix.

* More fixes.

* Addressing more comments.

* Minor fixes.

* first stab at cleanup

* more refactoring

* fix copies and else

* renaming and get init to work

* fix causal mask creation

* update

* nit

* fix a hell of a lot of things

* updates

* update conversion script

* make all keys importable

* nits

* add auto mappings

* properly convert ffw_up and down

* add scaling

* fix generations

* for recurrent dtype

* update

* fix going beyond window

* fixup

* add missing files

* current updates to remove last einops

* finish modeling refactor

* TADA

* fix compile

* fix most failing tests?

* update tests

* refactor and update

* update

* nits, fixup and update tests

* more fixup

* nits

* fix imports

* test format

* fixups

* nits

* tuple typing

* fix code quality

* add model card

* fix doc

* skip most generation tests

* nits

* style

* doc fixes

* fix pr and check_copies?

* last nit

* oupsy

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <hi@lysand.re>

* update

* Update src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* update based on review

* doc nit

* fix quality

* quality

* fix slow test model path

* update default dtype

* ignore attributes that can be safely ignored in check config attributes

* 0lallalala come on

* save nit

* style

* remove to dict update

* make sure we can also run in float16

* style

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: Aleksandar Botev <botev@google.com>
Co-authored-by: Leonard Berrada <lberrada@users.noreply.github.com>
Co-authored-by: anushanf <anushanf@google.com>
Co-authored-by: botev <botevmg@gmail.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Arthur 2024-04-10 16:59:13 +02:00 committed by GitHub
parent 33bca5419c
commit 0fe44059ae
GPG Key ID: B5690EEEBB952194
32 changed files with 2001 additions and 1 deletions

@ -476,6 +476,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.

@ -472,6 +472,7 @@ Aktuelle Anzahl der Checkpoints: ![](https://img.shields.io/endpoint?url=https:/
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.

@ -449,6 +449,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.

@ -470,6 +470,7 @@ Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=h
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (de l'équipe Qwen, Alibaba Group) a été publié avec le rapport technique [blog post](https://qwenlm.github.io/blog/qwen-moe/) par Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (de Facebook) a été publié dans l'article [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) par Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (de Google Research) a été publié dans l'article [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) par Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat et Ming-Wei Chang.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (de Google) publié dans l'article [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) par the Griffin, RLHF and Gemma Teams.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (de Google Research) a été publié dans l'article [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) par Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (de META Platforms) a été publié dans l'article [Designing Network Design Space](https://arxiv.org/abs/2003.13678) par Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (de Google Research) a été publié dans l'article [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) par Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.

@ -423,6 +423,7 @@ conda install conda-forge::transformers
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (the Qwen team, Alibaba Group से) Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. द्वाराअनुसंधान पत्र [blog post](https://qwenlm.github.io/blog/qwen-moe/) के साथ जारी किया गया
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (फेसबुक से) साथ में कागज [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) पैट्रिक लुईस, एथन पेरेज़, अलेक्जेंड्रा पिक्टस, फैबियो पेट्रोनी, व्लादिमीर कारपुखिन, नमन गोयल, हेनरिक कुटलर, माइक लुईस, वेन-ताउ यिह, टिम रॉकटाशेल, सेबस्टियन रिडेल, डौवे कीला द्वारा।
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (Google अनुसंधान से) केल्विन गु, केंटन ली, ज़ोरा तुंग, पानुपोंग पसुपत और मिंग-वेई चांग द्वारा साथ में दिया गया पेपर [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909)।
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (Google से) the Griffin, RLHF and Gemma Teams. द्वाराअनुसंधान पत्र [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) के साथ जारी किया गया
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (META रिसर्च से) [Designing Network Design Space](https://arxiv.org/abs/2003.13678) पेपर के साथ जारी किया गया एब्स/2003.13678) इलिजा राडोसावोविक, राज प्रतीक कोसाराजू, रॉस गिर्शिक, कैमिंग ही, पिओटर डॉलर द्वारा।
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (गूगल रिसर्च से) साथ वाला पेपर [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) ह्युंग वोन चुंग, थिबॉल्ट फ़ेवरी, हेनरी त्साई, एम. जॉनसन, सेबेस्टियन रुडर द्वारा।

@ -483,6 +483,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (the Qwen team, Alibaba Group から) Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. から公開された研究論文 [blog post](https://qwenlm.github.io/blog/qwen-moe/)
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (Facebook から) Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela から公開された研究論文: [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401)
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (Google Research から) Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang から公開された研究論文: [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909)
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (Google から) the Griffin, RLHF and Gemma Teams. から公開された研究論文 [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf)
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (Google Research から) Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya から公開された研究論文: [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451)
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (META Platforms から) Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár から公開された研究論文: [Designing Network Design Space](https://arxiv.org/abs/2003.13678)
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (Google Research から) Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder から公開された研究論文: [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821)

@ -398,6 +398,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (the Qwen team, Alibaba Group 에서 제공)은 Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.의 [blog post](https://qwenlm.github.io/blog/qwen-moe/)논문과 함께 발표했습니다.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (Facebook 에서) Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela 의 [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) 논문과 함께 발표했습니다.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (Google Research 에서) Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang 의 [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) 논문과 함께 발표했습니다.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (Google 에서 제공)은 the Griffin, RLHF and Gemma Teams.의 [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf)논문과 함께 발표했습니다.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (Google Research 에서) Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 의 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 논문과 함께 발표했습니다.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (META Research 에서) Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár 의 [Designing Network Design Space](https://arxiv.org/abs/2003.13678) 논문과 함께 발표했습니다.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (Google Research 에서) Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 의 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) 논문과 함께 발표했습니다.

@ -481,6 +481,7 @@ Número atual de pontos de verificação: ![](https://img.shields.io/endpoint?ur
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.

@ -471,6 +471,7 @@ conda install conda-forge::transformers
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.

@ -473,6 +473,7 @@ Flax, PyTorch లేదా TensorFlow యొక్క ఇన్‌స్టా
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.

@ -472,6 +472,7 @@ Số lượng điểm kiểm tra hiện tại: ![](https://img.shields.io/endpoi
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (từ the Qwen team, Alibaba Group) được phát hành với bài báo [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (từ Facebook) được phát hành với bài báo [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (từ Google Research) được phát hành với bài báo [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (từ Google) được phát hành với bài báo [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (từ Google Research) được phát hành với bài báo [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (từ META Platforms) được phát hành với bài báo [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (từ Google Research) được phát hành với bài báo [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.

@ -422,6 +422,7 @@ conda install conda-forge::transformers
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (来自 the Qwen team, Alibaba Group) 伴随论文 [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou 发布.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (来自 Facebook) 伴随论文 [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) 由 Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela 发布。
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (来自 Google Research) 伴随论文 [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) 由 Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang 发布。
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (来自 Google) 伴随论文 [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) 由 the Griffin, RLHF and Gemma Teams 发布。
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (来自 Google Research) 伴随论文 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 由 Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 发布。
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Research) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (来自 Google Research) 伴随论文 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) 由 Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 发布。

@ -434,6 +434,7 @@ conda install conda-forge::transformers
1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Research) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.

@ -468,6 +468,8 @@
        title: RAG
      - local: model_doc/realm
        title: REALM
      - local: model_doc/recurrent_gemma
        title: RecurrentGemma
      - local: model_doc/reformer
        title: Reformer
      - local: model_doc/rembert

@ -243,6 +243,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Qwen2MoE](model_doc/qwen2_moe) | ✅ | ❌ | ❌ |
| [RAG](model_doc/rag) | ✅ | ✅ | ❌ |
| [REALM](model_doc/realm) | ✅ | ❌ | ❌ |
| [RecurrentGemma](model_doc/recurrent_gemma) | ✅ | ❌ | ❌ |
| [Reformer](model_doc/reformer) | ✅ | ❌ | ❌ |
| [RegNet](model_doc/regnet) | ✅ | ✅ | ✅ |
| [RemBERT](model_doc/rembert) | ✅ | ✅ | ❌ |

@ -0,0 +1,48 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# RecurrentGemma

## Overview

The Recurrent Gemma model was proposed in [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams of Google.

The abstract from the paper is the following:

*We introduce RecurrentGemma, an open language model which uses Google's novel Griffin architecture. Griffin combines linear recurrences with local attention to achieve excellent performance on language. It has a fixed-sized state, which reduces memory use and enables efficient inference on long sequences. We provide a pre-trained model with 2B non-embedding parameters, and an instruction tuned variant. Both models achieve comparable performance to Gemma-2B despite being trained on fewer tokens.*

Tips:

- The original checkpoints can be converted using the conversion script `src/transformers/models/recurrent_gemma/convert_recurrent_gemma_weights_to_hf.py`.

This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ).
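
The abstract's claim about a fixed-sized state can be illustrated with a toy sketch. The scalar recurrence and the helper names below are assumptions made for illustration only; they are not the RG-LRU layer or the library's cache. The point is that a linear recurrence carries one state forward, and a local-attention window keeps a bounded number of past tokens, so neither grows with sequence length.

```python
from collections import deque


def toy_linear_recurrence(inputs, decay=0.5):
    """Run h_t = decay * h_{t-1} + (1 - decay) * x_t and return the final state.

    This scalar recurrence is an illustrative assumption, not the RG-LRU itself.
    """
    state = 0.0  # the whole "memory" is one number, whatever the sequence length
    for x in inputs:
        state = decay * state + (1.0 - decay) * x
    return state


def toy_local_window(tokens, window=4):
    """Keep only the last `window` tokens, as a local-attention cache would."""
    cache = deque(maxlen=window)  # older entries are evicted automatically
    for token in tokens:
        cache.append(token)
    return list(cache)


final_state = toy_linear_recurrence([1.0, 1.0, 1.0])
visible = toy_local_window(range(10), window=4)
print(final_state, visible)
```

In the real configuration the window length is controlled by `attention_window_size`; here `window=4` is just a small stand-in value.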

## RecurrentGemmaConfig

[[autodoc]] RecurrentGemmaConfig

## RecurrentGemmaModel

[[autodoc]] RecurrentGemmaModel
    - forward

## RecurrentGemmaForCausalLM

[[autodoc]] RecurrentGemmaForCausalLM
    - forward

@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
Choose one of the following architectures:
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [Cohere](../model_doc/cohere), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MusicGen Melody](../model_doc/musicgen_melody), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), 
[RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [Cohere](../model_doc/cohere), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MusicGen Melody](../model_doc/musicgen_melody), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [RecurrentGemma](../model_doc/recurrent_gemma), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), 
[RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)

@ -743,6 +743,7 @@ _import_structure = {
        "RealmConfig",
        "RealmTokenizer",
    ],
    "models.recurrent_gemma": ["RecurrentGemmaConfig"],
    "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
    "models.regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"],
    "models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"],
@ -3115,6 +3116,13 @@ else:
            "load_tf_weights_in_realm",
        ]
    )
    _import_structure["models.recurrent_gemma"].extend(
        [
            "RecurrentGemmaForCausalLM",
            "RecurrentGemmaModel",
            "RecurrentGemmaPreTrainedModel",
        ]
    )
    _import_structure["models.reformer"].extend(
        [
            "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -5625,6 +5633,7 @@ if TYPE_CHECKING:
        RealmConfig,
        RealmTokenizer,
    )
    from .models.recurrent_gemma import RecurrentGemmaConfig
    from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
    from .models.regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
    from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
@ -7687,6 +7696,11 @@ if TYPE_CHECKING:
            RealmScorer,
            load_tf_weights_in_realm,
        )
        from .models.recurrent_gemma import (
            RecurrentGemmaForCausalLM,
            RecurrentGemmaModel,
            RecurrentGemmaPreTrainedModel,
        )
        from .models.reformer import (
            REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            ReformerAttention,

@ -187,6 +187,7 @@ from . import (
    qwen2_moe,
    rag,
    realm,
    recurrent_gemma,
    reformer,
    regnet,
    rembert,

@ -198,6 +198,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("qwen2_moe", "Qwen2MoeConfig"),
        ("rag", "RagConfig"),
        ("realm", "RealmConfig"),
        ("recurrent_gemma", "RecurrentGemmaConfig"),
        ("reformer", "ReformerConfig"),
        ("regnet", "RegNetConfig"),
        ("rembert", "RemBertConfig"),
@ -471,6 +472,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("qwen2_moe", "Qwen2MoE"),
        ("rag", "RAG"),
        ("realm", "REALM"),
        ("recurrent_gemma", "RecurrentGemma"),
        ("reformer", "Reformer"),
        ("regnet", "RegNet"),
        ("rembert", "RemBERT"),

@ -183,6 +183,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("qdqbert", "QDQBertModel"),
        ("qwen2", "Qwen2Model"),
        ("qwen2_moe", "Qwen2MoeModel"),
        ("recurrent_gemma", "RecurrentGemmaModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
@ -469,6 +470,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("qdqbert", "QDQBertLMHeadModel"),
        ("qwen2", "Qwen2ForCausalLM"),
        ("qwen2_moe", "Qwen2MoeForCausalLM"),
        ("recurrent_gemma", "RecurrentGemmaForCausalLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),

@ -363,6 +363,13 @@ else:
            ),
            ("rag", ("RagTokenizer", None)),
            ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
            (
                "recurrent_gemma",
                (
                    "GemmaTokenizer" if is_sentencepiece_available() else None,
                    "GemmaTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "reformer",
                (

@ -0,0 +1,59 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_recurrent_gemma": ["RecurrentGemmaConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_recurrent_gemma"] = [
        "RecurrentGemmaForCausalLM",
        "RecurrentGemmaModel",
        "RecurrentGemmaPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_recurrent_gemma import RecurrentGemmaConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_recurrent_gemma import (
            RecurrentGemmaForCausalLM,
            RecurrentGemmaModel,
            RecurrentGemmaPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
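
The `_import_structure` dictionary above maps submodules to the names they export so that heavy imports are deferred until a name is actually used. The following is a minimal sketch of that idea, assuming a hypothetical `LazyNamespace` class; it is not the library's `_LazyModule`, and standard-library modules stand in for the model submodules.

```python
import importlib
import types


class LazyNamespace(types.ModuleType):
    """Hypothetical stand-in for a lazy module: imports a name's owner on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {module: [names]} into {name: module} for quick lookup.
        self._owner = {attr: module for module, attrs in import_structure.items() for attr in attrs}
        self.loaded = []  # records which names triggered a real import

    def __getattr__(self, attr):
        # Called only when normal lookup fails, i.e. before the name is cached.
        if attr.startswith("_") or attr not in self._owner:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        value = getattr(importlib.import_module(self._owner[attr]), attr)
        setattr(self, attr, value)  # cache so later accesses skip __getattr__
        self.loaded.append(attr)
        return value


# Standard-library modules stand in for the model's submodules here.
demo = LazyNamespace("demo", {"math": ["sqrt"], "json": ["dumps"]})
print(demo.loaded)     # nothing has been resolved yet
print(demo.sqrt(9.0))  # first access resolves math.sqrt
print(demo.loaded)
```

The design choice this illustrates is that importing the package stays cheap: only the names a caller touches pull in their defining module.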

@ -0,0 +1,158 @@
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RecurrentGemma model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class RecurrentGemmaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RecurrentGemmaModel`]. It is used to instantiate a RecurrentGemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the RecurrentGemma-2B.

    e.g. [google/recurrentgemma-2b](https://huggingface.co/google/recurrentgemma-2b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        num_hidden_layers (`int`, *optional*, defaults to 26):
            The number of hidden layers in the model.
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the RecurrentGemma model. Defines the number of
            different tokens that can be represented by the
            `input_ids` passed when calling [`RecurrentGemmaModel`].
        hidden_size (`int`, *optional*, defaults to 2560):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 7680):
            Dimension of the MLP representations.
        num_attention_heads (`int`, *optional*, defaults to 10):
            The number of heads for the attention block and the number of
            heads/blocks for the block-diagonal layers used in the RG-LRU gates.
            This number must divide `hidden_size` and `lru_width`.
        lru_width (`int` or `None`, *optional*):
            Dimension of the hidden representations of the RG-LRU. If `None`
            this will be set to `hidden_size`.
        attention_window_size (`int`, *optional*, defaults to 2048):
            The size of the attention window used in the attention block.
        conv1d_width (`int`, *optional*, defaults to 4):
            The kernel size of conv1d layers used in the recurrent blocks.
        logits_soft_cap (`float`, *optional*, defaults to 30.0):
            The value at which the logits are soft-capped after the transformer and LM-head computation in the Causal LM architecture.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values
            attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The hidden activation used in the recurrent block as well as the MLP layer of the decoder layers.
        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
            The partial rotary factor used in the initialization of the rotary embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        block_types (`List[str]`, *optional*, defaults to `('recurrent', 'recurrent', 'attention')`):
            List of alternating blocks that will be repeated to initialize the `temporal_block` layer.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout value to use after the attention softmax.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads used for grouped-query attention (GQA). If `None`, this will be set to
            `num_attention_heads`. It must not exceed `num_attention_heads`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether or not the linear q, k, v projections of the attention layer should have a bias.
        w_init_variance_scale (`float`, *optional*, defaults to 0.01):
            Weight initialization variance.
    ```python
    >>> from transformers import RecurrentGemmaModel, RecurrentGemmaConfig

    >>> # Initializing a RecurrentGemma recurrentgemma-2b style configuration
    >>> configuration = RecurrentGemmaConfig()

    >>> # Initializing a model from the recurrentgemma-2b style configuration
    >>> model = RecurrentGemmaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "recurrent_gemma"

    def __init__(
        self,
        num_hidden_layers=26,
        vocab_size=256000,
        hidden_size=2560,
        intermediate_size=3 * 2560,
        num_attention_heads=10,
        lru_width=None,
        attention_window_size=2048,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        hidden_activation="gelu_pytorch_tanh",
        partial_rotary_factor=0.5,
        rope_theta=10000.0,
        block_types=("recurrent", "recurrent", "attention"),
        attention_dropout=0.0,
        num_key_value_heads=None,
        attention_bias=False,
        w_init_variance_scale=0.01,
        **kwargs,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width if lru_width is not None else hidden_size
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.partial_rotary_factor = partial_rotary_factor
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("`num_key_value_heads` must be smaller than or equal to `num_attention_heads`")
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        # Repeat the block pattern and truncate it to `num_hidden_layers` entries.
        return (self.block_types * 100)[: self.num_hidden_layers]
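
To make the derived values concrete, here is a small standalone sketch of how the block pattern is tiled across layers and how the attention sizes fall out of the defaults. The helper `tile_block_types` is hypothetical: the configuration above repeats the pattern a fixed 100 times before truncating, whereas this sketch uses a ceiling division, which is my substitution rather than the library's code.

```python
def tile_block_types(block_types, num_hidden_layers):
    """Repeat the block pattern and truncate it to the requested depth."""
    repeats = -(-num_hidden_layers // len(block_types))  # ceiling division
    return (list(block_types) * repeats)[:num_hidden_layers]


layers = tile_block_types(("recurrent", "recurrent", "attention"), 26)
print(layers.count("recurrent"), layers.count("attention"))
print(layers[-2:])

# Derived attention sizes, using the default values listed in the docstring.
hidden_size, num_attention_heads, num_key_value_heads = 2560, 10, None
head_dim = hidden_size // num_attention_heads
kv_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
print(head_dim, kv_heads)
```

With the defaults, 26 layers do not divide evenly by the three-block pattern, so the stack ends on two recurrent blocks.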

@ -0,0 +1,222 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import warnings

import torch
from accelerate import init_empty_weights

from transformers import GemmaTokenizer, RecurrentGemmaConfig, RecurrentGemmaForCausalLM


try:
    from transformers import GemmaTokenizerFast
except ImportError as e:
    # `warnings.warn` expects a message string (or a Warning instance), so convert the exception.
    warnings.warn(str(e))
    warnings.warn(
        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
    )
    GemmaTokenizerFast = None

import regex as re

"""
Sample usage:

```
python src/transformers/models/recurrent_gemma/convert_recurrent_gemma_weights_to_hf.py \
    --input_checkpoint /path/to/downloaded/recurrentgemma/checkpoint.pt --model_size 2B --output_dir /output/path \
    --tokenizer_checkpoint /path/to/tokenizer.model --convert_tokenizer
```

Thereafter, models can be loaded via:

```py
from transformers import GemmaTokenizerFast, RecurrentGemmaForCausalLM

model = RecurrentGemmaForCausalLM.from_pretrained("/output/path")
tokenizer = GemmaTokenizerFast.from_pretrained("/output/path")
```

Important note: you need to be able to host the whole model in RAM to execute this script, because the checkpoint is
loaded in full with `torch.load`.
"""

gemma_2b_config = RecurrentGemmaConfig(
    num_attention_heads=10,
    num_key_value_heads=1,
    hidden_size=2560,
    intermediate_size=15360,
    vocab_size=256000,
    num_hidden_layers=26,
)

gemma_7b_config = RecurrentGemmaConfig()
CONFIG_MAPPING = {"2B": gemma_2b_config, "7B": gemma_7b_config}
LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"}


def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False, dtype=torch.float32):
    print(f"Fetching all parameters from the checkpoint at '{input_base_path}'")
    model_state_dict = torch.load(input_base_path, map_location="cpu")

    # Substrings of the original parameter names and their Transformers equivalents.
    REPLACEMENT = {
        "blocks.": "layers.",
        ".ffw_down.b": ".down_proj.b",
        ".ffw_down.w": ".down_proj.w",
        ".ffw_up.b": ".up_proj.bias",
        ".ffw_up.w": ".up_proj.weight",
        "recurrent_block": "temporal_block",
        "attention_block": "temporal_block",
        "temporal_block.proj_final": "temporal_block.out_proj",
        "norm.scale": "norm.weight",
        ".proj_k": ".k_proj",
        ".proj_q": ".q_proj",
        ".proj_v": ".v_proj",
        ".proj_final": ".o_proj",
        "embedder.input_embedding": "embed_tokens.weight",
        "conv_1d.w": "conv_1d.weight",
        "conv_1d.b": "conv_1d.bias",
        "input_gate.w": "input_gate.weight",
        "input_gate.b": "input_gate.bias",
        "a_param": "recurrent_param",
        "a_gate.b": "recurrent_gate.bias",
        "a_gate.w": "recurrent_gate.weight",
    }
    # Compile the alternation of escaped substrings once instead of once per key.
    pattern = re.compile("|".join(map(re.escape, REPLACEMENT.keys())))

    state_dict = {}
    for k, v in model_state_dict.items():
        k = "model." + k
        key = pattern.sub(lambda match: REPLACEMENT[match.group(0)], k)
        if "conv_1d.weight" in key:
            v = v[:, None, :].transpose(0, 2)
        if "up_proj.weight" in key:
            # The original checkpoint stacks the gate and up projections in one tensor.
            state_dict[key.replace("up_proj", "gate_proj")] = v[0].T.contiguous()
            v = v[1].T.contiguous()
        if "up_proj.bias" in key:
            state_dict[key.replace("up_proj", "gate_proj")] = v[0, 0, 0].clone()
            v = v[1, 0, 0].contiguous()
        if "recurrent_gate.bias" in key:
            state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone()
        elif "recurrent_gate.weight" in key:
            state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone()
        elif "input_gate.b" in key:
            state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone()
        elif "input_gate.w" in key:
            state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone()
        elif "embed_tokens" in key:
            # The embedding matrix is reused for the LM head.
            state_dict[key] = v[: config.vocab_size, :].contiguous().clone()
            state_dict["lm_head.weight"] = v[: config.vocab_size, :].contiguous().clone()
        else:
            state_dict[key] = v.contiguous()

    torch.set_default_dtype(dtype)

    print("Loading the checkpoint in a RecurrentGemma model.")
    with init_empty_weights():
        model = RecurrentGemmaForCausalLM(config)
    model.load_state_dict(state_dict, assign=True, strict=True)

    model.config.torch_dtype = torch.float32
    del model.config._name_or_path
    print("Saving in the Transformers format.")

    if push_to_hub:
        print(f"pushing the model to {save_path}")
        model.push_to_hub(save_path, safe_serialization=safe_serialization)
    else:
        model.save_pretrained(save_path, safe_serialization=safe_serialization)


def write_tokenizer(input_tokenizer_path, save_path, push_to_hub=False):
    # Initialize the tokenizer based on the `spm` model
    tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast
    print(f"Saving a {tokenizer_class.__name__} to {save_path}.")
    tokenizer = tokenizer_class(input_tokenizer_path)
    if push_to_hub:
        tokenizer.push_to_hub(save_path)
    else:
        tokenizer.save_pretrained(save_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_checkpoint",
        help="Absolute path to the original RecurrentGemma checkpoint to convert.",
        default="/home/arthur/transformers_recurrentgemma/google/recurrent-gemma-2b-it/ToBeDeleted/2b-it.pt",
    )
    parser.add_argument(
        "--tokenizer_checkpoint",
        help="Location of Gemma tokenizer model",
    )
    parser.add_argument(
        "--model_size",
        default="2B",
        choices=["2B", "7B", "tokenizer_only"],
        help="Size of the model to convert. Pass `tokenizer_only` to skip the model weights and convert only the tokenizer.",
    )
    parser.add_argument(
        "--output_dir",
        default="google/recurrent-gemma-2b-it-hf",
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
        "--pickle_serialization",
        help="Save with pickle serialization instead of `safetensors`.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--convert_tokenizer",
        help="Whether or not to convert the tokenizer as well.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--push_to_hub",
        help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--dtype",
        default="float32",
        help="Target dtype of the converted model",
    )
    args = parser.parse_args()

    if args.convert_tokenizer:
        if args.tokenizer_checkpoint is None:
            raise ValueError("Path to the tokenizer is required when passing --convert_tokenizer")

        spm_path = os.path.join(args.tokenizer_checkpoint)
        write_tokenizer(spm_path, args.output_dir, args.push_to_hub)

    # `tokenizer_only` has no entry in CONFIG_MAPPING, so skip the weight conversion in that case.
    if args.model_size != "tokenizer_only":
        config = CONFIG_MAPPING[args.model_size]
        dtype = getattr(torch, args.dtype)
        write_model(
            config=config,
            input_base_path=args.input_checkpoint,
            save_path=args.output_dir,
            safe_serialization=not args.pickle_serialization,
            push_to_hub=args.push_to_hub,
            dtype=dtype,
        )


if __name__ == "__main__":
    main()
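
The key renaming in `write_model` relies on a single regular-expression pass over each parameter name. The sketch below isolates that technique with the standard-library `re` module (the script itself imports the third-party `regex` package under the same name). The `RENAMES` table is a small hypothetical subset and the sample keys are made up for illustration; they are not taken from a real checkpoint.

```python
import re

# Hypothetical subset of a rename table; the keys are literal substrings.
RENAMES = {
    "blocks.": "layers.",
    "attention_block": "temporal_block",
    ".proj_q": ".q_proj",
    "norm.scale": "norm.weight",
}

# One alternation of the escaped substrings, so "." is matched literally.
PATTERN = re.compile("|".join(map(re.escape, RENAMES)))


def rename_key(key: str) -> str:
    """Prefix the key with "model." and rewrite every matching substring in one pass."""
    return PATTERN.sub(lambda match: RENAMES[match.group(0)], "model." + key)


print(rename_key("blocks.3.attention_block.proj_q.w"))
print(rename_key("final_norm.scale"))
```

Because the substitution is a single left-to-right pass, text produced by one replacement is never rescanned, so an entry whose source substring only appears after another rename has been applied will not fire.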

@ -0,0 +1,938 @@
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch RecurrentGemma model."""
import math
from typing import Dict, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_recurrent_gemma import RecurrentGemmaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "RecurrentGemmaConfig"
_MAX_SQRT_GRADIENT = 1000.0
# Copied from transformers.models.gemma.modeling_gemma.GemmaRMSNorm with Gemma->RecurrentGemma
class RecurrentGemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst RecurrentGemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)


ALL_LAYERNORM_LAYERS.append(RecurrentGemmaRMSNorm)


class RecurrentGemmaRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        self.register_buffer("inv_freq", None, persistent=False)

    @torch.no_grad()
    # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.inv_freq is None:
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
            )
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
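A minimal, dependency-free sketch of the rotation performed by `rotate_half` and `apply_rotary_pos_emb`, with plain lists standing in for tensors. The helper name `apply_rope` and the default `base=10000.0` are assumptions for illustration; the model takes its base from `config.rope_theta`.

```python
import math


def rotate_half(x):
    """Plain-list analogue of `rotate_half`: (x1, x2) -> (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x, position, base=10000.0):
    """Rotate one head vector `x` (even length) to `position` (hypothetical helper)."""
    dim = len(x)
    # Same frequencies as `inv_freq`: base ** (-i / dim) for i = 0, 2, 4, ...
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    angles = [position * f for f in inv_freq]
    angles = angles + angles  # mirrors torch.cat((freqs, freqs), dim=-1)
    rotated = rotate_half(x)
    return [x[i] * math.cos(angles[i]) + rotated[i] * math.sin(angles[i]) for i in range(dim)]


vec = [1.0, 2.0, 3.0, 4.0]
out = apply_rope(vec, position=5)
```

Each pair `(x[i], x[i + dim // 2])` is rotated by the same angle, so the vector norm is unchanged and position 0 leaves the vector as it was.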
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
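To make the head layout produced by `repeat_kv` concrete, here is a small sketch in plain Python. It assumes a list of per-head placeholders instead of a 4D tensor; only the ordering of the heads is illustrated.

```python
def repeat_kv(heads, n_rep):
    """Repeat every key/value head `n_rep` times, keeping copies adjacent."""
    if n_rep == 1:
        return heads
    # Head index in the output is kv_index * n_rep + repeat_index,
    # which matches expand(...).reshape(...) in the tensor version.
    return [head for head in heads for _ in range(n_rep)]


kv_heads = ["kv0", "kv1"]
expanded = repeat_kv(kv_heads, n_rep=3)
```

The copies are interleaved per head (`kv0, kv0, kv0, kv1, ...`) rather than tiled (`kv0, kv1, kv0, ...`), which is what lets query head `i` read key/value head `i // n_rep`.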
class RecurrentGemmaSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: RecurrentGemmaConfig):
super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
self.partial_rotary_factor = config.partial_rotary_factor
self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
self.rotary_emb = RecurrentGemmaRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
base=config.rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
# Partial rotary embedding
query_rot, query_pass = torch.chunk(query_states, int(1 / self.partial_rotary_factor), dim=-1)
key_rot, key_pass = torch.chunk(key_states, int(1 / self.partial_rotary_factor), dim=-1)
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)
if use_cache and hasattr(self, "key_states"):
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states.contiguous(),
key_states.contiguous(),
value_states.contiguous(),
attn_mask=causal_mask, # pretty much a must for sliding window backend!
dropout_p=self.attention_dropout if self.training else 0.0,
scale=self.head_dim**-0.5,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
def _setup_cache(self, batch_size, device, dtype=None):
if dtype is None and self.config.torch_dtype is not None:
dtype = self.config.torch_dtype
dtype = dtype if dtype is not None else torch.float32
cache_shape = (batch_size, self.num_key_value_heads, self.config.attention_window_size, self.head_dim)
self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)
@torch.no_grad()
def _update_cache(self, key_states, value_states, **cache_kwargs):
"""
torch.compile compatible sliding window.
Computes the `indices` based on `cache_position >= self.config.attention_window_size - 1`.
The `to_shift` is only true once we are above attention_window_size. Thus with `attention_window_size==64`:
indices = (slicing + to_shift[-1].int()-1) % self.config.attention_window_size
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
We overwrite the cache using these, then we always write at cache_position (clamped to `attention_window_size`)
"""
cache_position = cache_kwargs.get("cache_position")
if cache_position.shape[0] > self.config.attention_window_size:
# int indexing -> device sync? in compile, use tensor
k_out = key_states[:, :, -self.config.attention_window_size :, :]
v_out = value_states[:, :, -self.config.attention_window_size :, :]
else:
slicing = torch.ones(
self.config.attention_window_size, dtype=torch.long, device=value_states.device
).cumsum(0)
cache_position = cache_position.clamp(0, self.config.attention_window_size - 1)
to_shift = cache_position >= self.config.attention_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size
k_out, v_out = self.key_states, self.value_states
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_states, self.value_states = k_out, v_out
return k_out, v_out
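The index arithmetic in the single-step branch of `_update_cache` can be sketched without tensors. The helper names `window_indices` and `update_window` are hypothetical, and a flat list stands in for the cache's sequence axis.

```python
def window_indices(cache_position, window_size):
    """Return (gather indices, write slot) for one decoding step."""
    slot = min(max(cache_position, 0), window_size - 1)  # clamp(0, window - 1)
    to_shift = int(slot >= window_size - 1)
    slicing = range(1, window_size + 1)  # torch.ones(window).cumsum(0)
    indices = [(s + to_shift - 1) % window_size for s in slicing]
    return indices, slot


def update_window(cache, new_value, cache_position):
    """Gather with the indices, then write the new entry (returns a new list)."""
    indices, slot = window_indices(cache_position, len(cache))
    out = [cache[i] for i in indices]
    out[slot] = new_value
    return out


full_cache = [1, 2, 3, 4]
rolled = update_window(full_cache, 5, cache_position=9)
```

Before the last slot is reached the indices are the identity and the write lands at `cache_position`; afterwards the cache is rolled left by one and the new entry is written to the last slot.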
class SqrtBoundDerivative(torch.autograd.Function):
"""Computes a square root with a gradient clipped at `_MAX_SQRT_GRADIENT`."""
@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
"""The forward pass, which is a normal `sqrt`."""
ctx.save_for_backward(x)
return torch.sqrt(x)
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
"""The backward pass, which clips the `sqrt` gradient."""
(x,) = ctx.saved_tensors
clipped_x_times_4 = torch.clip(4.0 * x, min=1 / (_MAX_SQRT_GRADIENT**2))
return grad_output / torch.sqrt(clipped_x_times_4)
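A scalar sketch of the backward rule in `SqrtBoundDerivative`. The default `max_sqrt_gradient=1000.0` is an assumption for illustration only; the value of `_MAX_SQRT_GRADIENT` is defined outside this section.

```python
import math


def bounded_sqrt_grad(x, grad_output=1.0, max_sqrt_gradient=1000.0):
    """Gradient of sqrt(x), capped at `max_sqrt_gradient` (assumed value)."""
    # d/dx sqrt(x) = 1 / (2 * sqrt(x)) = 1 / sqrt(4 * x).
    # Clipping 4 * x from below at 1 / max**2 caps the gradient at max.
    clipped_x_times_4 = max(4.0 * x, 1.0 / max_sqrt_gradient**2)
    return grad_output / math.sqrt(clipped_x_times_4)


grad_at_one = bounded_sqrt_grad(1.0)
grad_at_zero = bounded_sqrt_grad(0.0)
```

Away from zero the result equals the ordinary `sqrt` gradient; near zero, where the true gradient diverges, it saturates at the cap.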
class RecurrentGemmaRglru(nn.Module):
"""A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.block_width = config.lru_width // self.num_attention_heads
self.recurrent_param = nn.Parameter(torch.empty([config.lru_width]))
self.input_gate_weight = nn.Parameter(
torch.empty([self.num_attention_heads, self.block_width, self.block_width])
)
self.input_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width]))
self.recurrent_gate_weight = nn.Parameter(
torch.empty([self.num_attention_heads, self.block_width, self.block_width])
)
self.recurrent_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width]))
self.recurrent_states = None
def forward(
self,
activations: torch.Tensor,
position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, lru_width = activations.shape
reset = position_ids[:, :, None] == 0
reshape_act = activations.reshape(batch_size * seq_len, self.num_attention_heads, self.block_width)
reshape_act = reshape_act.permute(1, 0, 2)
res = torch.baddbmm(self.input_gate_bias[:, None, :], reshape_act, self.input_gate_weight)
input_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width))
res = torch.baddbmm(self.recurrent_gate_bias[:, None, :], reshape_act, self.recurrent_gate_weight)
recurrent_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width))
# Compute the parameter `A` of the recurrence.
log_recurrent_gate = -8.0 * recurrent_gate * nn.functional.softplus(self.recurrent_param)
recurrent_gate = torch.exp(log_recurrent_gate)
a_square = torch.exp(2 * log_recurrent_gate)
# Gate the input.
gated_inputs = activations * input_gate
# Apply gamma normalization to the input. We need to clip the derivatives of
# `sqrt` in order to prevent NaNs during training in bfloat16. TODO a bit annoying
multiplier = 1
tracing = isinstance(activations, torch.fx.Proxy) or (
hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()
)
if not torch.jit.is_tracing() and not tracing:
multiplier = SqrtBoundDerivative.apply(1 - a_square)
multiplier = reset + ~reset * multiplier
normalized_x = gated_inputs * multiplier.type(activations.dtype)
hidden_states, recurrent_states = self._rnn_scan(
hidden_states=normalized_x,
recurrent_gate=recurrent_gate,
reset=reset,
recurrent_states=self.recurrent_states,
)
self.recurrent_states = recurrent_states
return hidden_states
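As a scalar illustration of the gate computation above, the following sketch derives the recurrence coefficient `a` and the input multiplier `sqrt(1 - a**2)` from one gate value. The function name `rglru_coefficients` is hypothetical, and the reset handling and gradient clipping are left out.

```python
import math


def softplus(v):
    return math.log1p(math.exp(v))


def rglru_coefficients(recurrent_gate, recurrent_param):
    """Return (a, multiplier) for one channel (illustrative helper)."""
    log_a = -8.0 * recurrent_gate * softplus(recurrent_param)
    a = math.exp(log_a)
    a_square = math.exp(2.0 * log_a)
    multiplier = math.sqrt(1.0 - a_square)
    return a, multiplier


a, multiplier = rglru_coefficients(recurrent_gate=0.5, recurrent_param=0.0)
```

Because `a**2 + multiplier**2 == 1`, a larger decay `a` keeps more of the previous state and admits proportionally less of the gated input.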
# TODO refactor
def _rnn_scan(
self,
hidden_states: torch.Tensor,
recurrent_gate: torch.Tensor,
reset: torch.Tensor,
recurrent_states: Union[torch.Tensor, None],
acc_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Runs the recurrence of a linear RNN.
Args:
hidden_states: The input sequence.
recurrent_gate: The diagonal of the recurrence matrix `A`.
reset: Indicator of document boundaries, e.g. when to reset the hidden state
of the RNN.
recurrent_states: The initial hidden state.
acc_dtype: The data type for the accumulation.
Returns:
The output of the linear recurrence.
"""
# Multiply `a` by the reset.
recurrent_gate = recurrent_gate * ~reset
if hidden_states.shape[1] == 1:
# Using scan in sampling mode.
if recurrent_states is None: # same here, when decoding you always have cache
return hidden_states, hidden_states[:, 0].type(acc_dtype)
else:
contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None]
contextualized_states += hidden_states.type(acc_dtype)
return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]
else:
# Using scan in linear mode.
if recurrent_states is None:
recurrent_states = torch.zeros(hidden_states[:, 0].shape, dtype=acc_dtype, device=hidden_states.device)
contextualized_states = torch.zeros_like(hidden_states)
for t in range(hidden_states.shape[1]):
recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states
recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype)
contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)
return contextualized_states, recurrent_states
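The recurrence in `_rnn_scan` is `h_t = a_t * h_{t-1} + x_t`, with `a_t` forced to zero at document boundaries. A scalar sketch, assuming plain Python floats in place of tensors and a hypothetical helper name `linear_scan`:

```python
def linear_scan(inputs, gates, resets, initial_state=0.0):
    """Run h_t = a_t * h_{t-1} + x_t over a sequence of scalars."""
    state = initial_state
    outputs = []
    for x_t, a_t, reset_t in zip(inputs, gates, resets):
        if reset_t:
            a_t = 0.0  # same effect as `recurrent_gate * ~reset`
        state = a_t * state + x_t
        outputs.append(state)
    return outputs, state


outputs, final_state = linear_scan([1, 1, 1], [0.5, 0.5, 0.5], [False, False, False])
outputs_reset, _ = linear_scan([1, 1, 1], [0.5, 0.5, 0.5], [False, False, True])
```

With a reset at the last step the carried state is discarded, so the output at that step equals its input.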
class RecurrentGemmaRecurrentBlock(nn.Module):
"""Griffin and Hawk's recurrent block."""
def __init__(self, config):
super().__init__()
self.lru_width = config.lru_width
self.hidden_size = config.hidden_size
self.linear_y = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width)
self.linear_x = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width)
self.linear_out = nn.Linear(in_features=config.lru_width, out_features=config.hidden_size)
self.conv1d_width = config.conv1d_width
self.conv_1d = nn.Conv1d(
config.lru_width,
config.lru_width,
kernel_size=config.conv1d_width,
groups=config.lru_width,
padding=config.conv1d_width - 1,
)
self.rg_lru = RecurrentGemmaRglru(config)
self.act_fn = ACT2FN[config.hidden_activation]
self.conv1d_state = None
def forward(
self,
input_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
cache_position: torch.Tensor,
use_cache: bool = True,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
_, seq_len, _ = input_states.shape
y_branch = self.linear_y(input_states)
y_branch = self.act_fn(y_branch)
x_branch = self.linear_x(input_states)
x_branch = x_branch.transpose(1, 2)
if use_cache:
if cache_position.shape[0] != 1: # prefill
self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0))
x_branch = self.conv_1d(x_branch)[..., :seq_len]
else: # decoding
conv_state = torch.cat((self.conv1d_state, x_branch), -1)
x_branch = torch.sum(conv_state * self.conv_1d.weight[:, 0, :], dim=-1) + self.conv_1d.bias
x_branch = x_branch.unsqueeze(-1)
self.conv1d_state = conv_state[:, :, 1:]
else:
x_branch = self.conv_1d(x_branch)[..., :seq_len]
x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids)
hidden_states = x_branch * y_branch
hidden_states = self.linear_out(hidden_states)
return hidden_states
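The prefill and decoding branches above compute the same causal depthwise convolution. The single-channel sketch below checks that equivalence with plain lists; the function names and the sample weights are assumptions for illustration.

```python
def causal_conv_prefill(x, weight, bias):
    """Left-pad with len(weight) - 1 zeros, then slide the kernel over x."""
    width = len(weight)
    padded = [0.0] * (width - 1) + list(x)
    return [
        sum(w * v for w, v in zip(weight, padded[t : t + width])) + bias
        for t in range(len(x))
    ]


def causal_conv_step(state, x_t, weight, bias):
    """One decoding step: `state` holds the previous len(weight) - 1 inputs."""
    window = list(state) + [x_t]
    out = sum(w * v for w, v in zip(weight, window)) + bias
    return out, window[1:]  # drop the oldest input, as `conv_state[:, :, 1:]` does


weight, bias = [0.5, -1.0, 2.0], 0.1
x = [1.0, 2.0, 3.0, 4.0]
prefill = causal_conv_prefill(x, weight, bias)

state, stepped = [0.0, 0.0], []
for value in x:
    out, state = causal_conv_step(state, value, weight, bias)
    stepped.append(out)
```

After the loop, `state` holds the last `len(weight) - 1` inputs, which is the information `conv1d_state` carries between decoding steps.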
def _setup_cache(self, batch, device, dtype):
# recurrent_states always computed in full precision
self.rg_lru.recurrent_states = torch.zeros((batch, self.lru_width), device=device, dtype=torch.float32)
# The convolution runs over `lru_width` channels, so the state must match that width.
self.conv1d_state = torch.zeros((batch, self.lru_width, self.conv1d_width - 1), device=device, dtype=dtype)
TEMPORAL_BLOCK_CLASSES = {"recurrent": RecurrentGemmaRecurrentBlock, "attention": RecurrentGemmaSdpaAttention}
class RecurrentGemmaMlp(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
self.act_fn = ACT2FN[config.hidden_activation]
def forward(self, hidden_states):
gate = self.act_fn(self.gate_proj(hidden_states))
return self.down_proj(gate * self.up_proj(hidden_states))
class RecurrentGemmaDecoderLayer(nn.Module):
"""Griffin and Hawk's residual block."""
def __init__(self, config, layer_idx):
super().__init__()
self.temporal_pre_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.temporal_block = TEMPORAL_BLOCK_CLASSES[config.layers_block_type[layer_idx]](config)
self.channel_pre_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp_block = RecurrentGemmaMlp(config)
def forward(
self,
activations: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
cache_position: torch.Tensor = None,
use_cache: bool = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
raw_activations = activations
inputs_normalized = self.temporal_pre_norm(raw_activations) # RMSNorm introduces slight differences
hidden_states = self.temporal_block(
inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache
)
residual = hidden_states + raw_activations
hidden_states = self.channel_pre_norm(residual)
hidden_states = self.mlp_block(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
RECURRENTGEMMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`RecurrentGemmaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare RecurrentGemma Model outputting raw hidden-states without any specific head on top.",
RECURRENTGEMMA_START_DOCSTRING,
)
class RecurrentGemmaPreTrainedModel(PreTrainedModel):
config_class = RecurrentGemmaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["RecurrentGemmaDecoderLayer"]
_skip_keys_device_placement = ["cache"]
_supports_flash_attn_2 = False
_supports_sdpa = False # we can't compare with eager for now
_supports_cache_class = True
def _init_weights(self, module):
std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
if isinstance(module, nn.Conv1d):
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
torch.nn.init.zeros_(module.bias)
elif isinstance(module, RecurrentGemmaSdpaAttention):
torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
std = math.sqrt(self.config.final_w_init_variance_scale / self.config.hidden_size)
torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=std)
elif isinstance(module, RecurrentGemmaRecurrentBlock):
torch.nn.init.zeros_(module.linear_x.bias)
torch.nn.init.normal_(module.linear_x.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
torch.nn.init.zeros_(module.linear_y.bias)
torch.nn.init.normal_(module.linear_y.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
std = math.sqrt(self.config.final_w_init_variance_scale / self.config.lru_width)
torch.nn.init.normal_(module.linear_out.weight, mean=0.0, std=std)
torch.nn.init.zeros_(module.linear_out.bias)
elif isinstance(module, RecurrentGemmaRglru):
std = math.sqrt(
self.config.w_init_variance_scale / (self.config.lru_width // self.config.num_attention_heads)
)
torch.nn.init.normal_(module.input_gate_weight, mean=0.0, std=std)
torch.nn.init.normal_(module.recurrent_gate_weight, mean=0.0, std=std)
torch.nn.init.zeros_(module.input_gate_bias)
torch.nn.init.zeros_(module.recurrent_gate_bias)
module.recurrent_param.data.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8)
module.recurrent_param.data.log_().mul_(0.5)
module.recurrent_param.data.neg_().exp_().sub_(1.0).log_()
elif isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if getattr(module, "bias", None) is not None:
torch.nn.init.zeros_(module.bias)
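The three in-place operations on `recurrent_param` amount to an inverse softplus. A scalar sketch, assuming `u` is one draw from the uniform range used above (the `1e-8` offset is omitted) and using hypothetical helper names:

```python
import math


def init_recurrent_param(u):
    """Map a uniform sample `u` to the stored parameter value."""
    log_a = 0.5 * math.log(u)  # .log_().mul_(0.5)
    return math.log(math.exp(-log_a) - 1.0)  # .neg_().exp_().sub_(1.0).log_()


def softplus(v):
    return math.log1p(math.exp(v))


u = 0.9**2
param = init_recurrent_param(u)
recovered = softplus(param)
```

Applying `softplus` to the stored value returns `-0.5 * log(u)`, which is the quantity the RG-LRU forward pass multiplies by `-8.0 * recurrent_gate`.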
def _setup_cache(self, config, batch, device, dtype):
layers = getattr(self, "model", self).layers
for layer in layers:
layer.temporal_block._setup_cache(batch, device, dtype)
def reset_cache(self, batch, device, dtype):
pass
RECURRENTGEMMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare RecurrentGemma Model outputting raw hidden-states without any specific head on top.",
RECURRENTGEMMA_START_DOCSTRING,
)
class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`RecurrentGemmaDecoderLayer`]
Args:
config: RecurrentGemmaConfig
"""
def __init__(self, config: RecurrentGemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[RecurrentGemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self.register_buffer("normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16))
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
def get_input_embeddings(self):
return self.embed_tokens
# Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(RECURRENTGEMMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
if use_cache and inputs_embeds.shape[1] != 1: # TODO let's maybe only call in the `generate`?
self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)
if cache_position is None:
cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
hidden_states = hidden_states * self.normalizer.type(hidden_states.dtype)
all_hidden_states = () if output_hidden_states else None
for i, residual_block in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
residual_block.__call__, hidden_states, position_ids, causal_mask, cache_position, use_cache
)
else:
hidden_states = residual_block(hidden_states, position_ids, causal_mask, cache_position, use_cache)
hidden_states = self.final_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
)
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
# Ignore copy
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
target_length = max(self.config.attention_window_size, sequence_length)
diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
causal_mask = diagonal
if sequence_length != 1:
causal_mask = torch.triu(diagonal, diagonal=-1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
if attention_mask is not None and attention_mask.device.type == "cuda":
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = RecurrentGemmaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
# Ignore copy
@add_start_docstrings_to_model_forward(RECURRENTGEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
**kwargs, # for now we need this for generation
) -> Union[Tuple, CausalLMOutput]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, RecurrentGemmaForCausalLM
>>> model = RecurrentGemmaForCausalLM.from_pretrained("google/recurrentgemma-2b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-2b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True
outputs = self.model(
input_ids=input_ids,
cache_position=cache_position,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
# Soft-cap the logits TODO remove if always done.
# if self.config.logits_soft_cap is not None:
cap = self.config.logits_soft_cap
logits = nn.functional.tanh(logits / cap) * cap
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
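The soft-capping applied to the logits in `forward` is `tanh(logits / cap) * cap`. A scalar sketch follows; the cap of `30.0` is an assumed value for illustration, whereas the model reads it from `config.logits_soft_cap`.

```python
import math


def soft_cap(logit, cap):
    """Squash `logit` smoothly into the open interval (-cap, cap)."""
    return math.tanh(logit / cap) * cap


small = soft_cap(0.1, cap=30.0)
large = soft_cap(1000.0, cap=30.0)
```

Small logits pass through almost unchanged, while large ones saturate at the cap instead of growing without bound.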
# Ignore copy
def prepare_inputs_for_generation(
self, input_ids, attention_mask=None, inputs_embeds=None, cache_position=None, use_cache=None, **kwargs
):
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
attention_mask = attention_mask[:, -self.config.attention_window_size :]
past_length = cache_position[0]
if past_length > 0:
position_ids = position_ids[:, past_length:]
if inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
else:
model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()}
if cache_position is not None:
cache_position = cache_position[-position_ids.shape[1] :]
model_inputs.update(
{
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_position": cache_position,
"use_cache": use_cache,
}
)
return model_inputs
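Review note: the `position_ids` computation in `prepare_inputs_for_generation` above is a cumulative sum over the attention mask, shifted by one, with padded slots overwritten by the placeholder value 1. A list-based sketch of the same arithmetic, assuming a single sequence and no tensors (`position_ids_from_mask` is a hypothetical helper, not part of this PR):

```python
from itertools import accumulate


def position_ids_from_mask(mask):
    """List analogue of `mask.long().cumsum(-1) - 1` followed by `masked_fill_(mask == 0, 1)`."""
    running = accumulate(mask)  # cumulative count of real (non-pad) tokens
    return [count - 1 if keep else 1 for count, keep in zip(running, mask)]


left_padded = [0, 0, 1, 1, 1]
right_padded = [1, 1, 1, 0, 0]

print(position_ids_from_mask(left_padded))   # [1, 1, 0, 1, 2]
print(position_ids_from_mask(right_padded))  # [0, 1, 2, 1, 1]
```

Real tokens always receive positions 0, 1, 2, ... regardless of where the padding sits; only the placeholder positions move.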
# Ignore copy
def _reorder_cache(self, past_key_values, beam_idx):
for layer in self.layers:
if hasattr(layer.temporal_block, "key_states"):
k_state = layer.temporal_block.key_states
v_state = layer.temporal_block.value_states
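# NOTE: the reordered tensors are only bound to local names here;
# the states stored on `layer.temporal_block` are not updated.
k_state = k_state.index_select(0, beam_idx.to(k_state.device))
v_state = v_state.index_select(0, beam_idx.to(v_state.device))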
return None


@ -7051,6 +7051,27 @@ def load_tf_weights_in_realm(*args, **kwargs):
requires_backends(load_tf_weights_in_realm, ["torch"])
class RecurrentGemmaForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RecurrentGemmaModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RecurrentGemmaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


@ -0,0 +1,508 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch RecurrentGemma model. """
import unittest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
require_read_token,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import RecurrentGemmaForCausalLM, RecurrentGemmaModel
class RecurrentGemmaModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=12,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
num_hidden_layers=3,
vocab_size=99,
hidden_size=32,
intermediate_size=3 * 32,
num_attention_heads=2,
lru_width=2 * 32,
embeddings_scale_by_sqrt_dim=True,
attention_window_size=16,
conv1d_width=4,
logits_soft_cap=30.0,
rms_norm_eps=1e-6,
use_cache=True,
rope_theta=10000.0,
type_vocab_size=16,
type_sequence_label_size=2,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.num_hidden_layers = num_hidden_layers
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.lru_width = lru_width if lru_width is not None else hidden_size
self.embeddings_scale_by_sqrt_dim = embeddings_scale_by_sqrt_dim
self.attention_window_size = attention_window_size
self.conv1d_width = conv1d_width
self.logits_soft_cap = logits_soft_cap
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return RecurrentGemmaConfig(
num_hidden_layers=self.num_hidden_layers,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
num_attention_heads=self.num_attention_heads,
lru_width=self.lru_width,
embeddings_scale_by_sqrt_dim=self.embeddings_scale_by_sqrt_dim,
attention_window_size=self.attention_window_size,
conv1d_width=self.conv1d_width,
logits_soft_cap=self.logits_soft_cap,
rms_norm_eps=self.rms_norm_eps,
use_cache=self.use_cache,
rope_theta=self.rope_theta,
pad_token_id=self.pad_token_id,
output_attentions=False,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->RecurrentGemma
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = RecurrentGemmaModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->RecurrentGemma
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = RecurrentGemmaModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->RecurrentGemma
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = RecurrentGemmaForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->RecurrentGemma
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = RecurrentGemmaForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next tokens and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
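Review note: the check above compares a full forward pass with a cached, incremental one and expects the last positions to agree. The same pattern can be sketched on a toy scalar recurrence; this is an illustration under stated assumptions (the function `toy_recurrence` and its `decay` parameter are hypothetical and unrelated to the actual recurrent block in this PR):

```python
import math


def toy_recurrence(inputs, state=0.0, decay=0.5):
    """Toy causal recurrence `state = decay * state + x` (illustrative only)."""
    outputs = []
    for x in inputs:
        state = decay * state + x
        outputs.append(state)
    return outputs, state


prefix = [1.0, 2.0, 3.0]
next_tokens = [4.0, 5.0, 6.0]

# "No past": process the whole sequence from scratch.
output_from_no_past, _ = toy_recurrence(prefix + next_tokens)

# "Past": process the prefix once, keep the state, then feed only the new tokens.
_, cached_state = toy_recurrence(prefix)
output_from_past, _ = toy_recurrence(next_tokens, state=cached_state)

# Only the last len(next_tokens) positions are comparable, as in the test above.
assert len(output_from_past) == len(next_tokens)
assert all(
    math.isclose(a, b, abs_tol=1e-3)
    for a, b in zip(output_from_past, output_from_no_past[-len(next_tokens):])
)
```

The comparison only holds when the cached state fully summarises the prefix, which is the property the model-level test is meant to verify.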
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->RecurrentGemma
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (RecurrentGemmaForCausalLM,) if is_torch_available() else ()
# all_generative_model_classes = (RecurrentGemmaForCausalLM,) if is_torch_available() else () #TODO @gante not fully supported
pipeline_model_mapping = (
{
"feature-extraction": RecurrentGemmaModel,
"text-generation": RecurrentGemmaForCausalLM,
}
if is_torch_available()
else {}
)
fx_compatible = False # FIXME let's try to support this @ArthurZucker
test_torchscript = False # FIXME let's try to support this @ArthurZucker
test_missing_keys = False
test_model_parallel = False
test_pruning = False
test_head_masking = False # RecurrentGemma does not have attention heads
# Need to remove 0.9 in `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.6]
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
):
return True
def setUp(self):
# We don't output attentions
self.has_attentions = False
self.model_tester = RecurrentGemmaModelTester(self)
self.config_tester = ConfigTester(self, config_class=RecurrentGemmaConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip("Recurrent gemma does not use legacy cache")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip("RecurrentGemma does not return pkv")
def test_past_key_values_format(self):
pass
@unittest.skip("RecurrentGemma only supports sdpa")
def test_eager_matches_sdpa_generate(self):
pass
@unittest.skip("RecurrentGemma only supports sdpa")
def test_eager_matches_sdpa_inference(self):
pass
@unittest.skip("RecurrentGemma does not return the cache")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("RecurrentGemma does not return the cache")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("RecurrentGemma does not return the cache")
def test_contrastive_generate(self):
pass
@unittest.skip("SQRBound is known to have issues with gc")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
def _check_attentions_for_generate(self, *args, **kwargs):
return True # Model does not return attention
@unittest.skip("Past key values are not returned")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip("Past key values are not returned")
def test_model_parallelism(self):
pass
@unittest.skip("Past key values are not returned")
def test_model_parallel_beam_search(self):
pass
def _check_past_key_values_for_generate(self, *args, **kwargs):
return True
@unittest.skip("Rely on `past_key_values` to crop the assistant pkv. Not supported")
def test_assisted_decoding_matches_greedy_search(self):
pass
@unittest.skip("RecurrentGemma's outputs differ depending on whether you pad left or right. This is expected")
def test_left_padding_compatibility(self):
pass
@unittest.skip("Relies on `past_key_values` returned by the model. Not supported with recurrent gemma")
def test_assisted_decoding_sample(self):
pass
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length + idx if not use_cache else 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
@unittest.skip("TODO @arthurzucker not super important and failing.")
def test_initialization(self):
pass
@require_torch_gpu
@slow
class RecurrentGemmaIntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
model_id = "google/recurrentgemma-2b"
@require_read_token
def test_2b_generate(self):
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a very good day for you. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do.'] # fmt: skip
model = AutoModelForCausalLM.from_pretrained(self.model_id, low_cpu_mem_usage=True).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
tokenizer.padding_side = "right"
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
tokenizer.padding_side = "left"
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today I am going to share with you the best <strong><em>free online video editing software</em></strong>.\n\n<h2><strong>Best Free Online Video Editing Software</strong></h2>\n\n<strong>1.</strong> <strong>Wondershare Filmora</strong>\n\nWondershare Filmora is a free online video editing software that is used to edit videos.'] # fmt: skip
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
del model
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
model = AutoModelForCausalLM.from_pretrained(
self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
del model
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token
def test_2b_sample(self):
set_seed(0)
EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write <em>P</em> if the underlined word group is a phrase and <em>NP</em> if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip
model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
inputs = tokenizer("Where is Paris ?", return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=128, do_sample=True)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXT)
@require_bitsandbytes
@require_read_token
def test_model_2b_8bit(self):
EXPECTED_TEXTS = ['<bos>Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "<bos>Hi today<pad><pad> I'm going to show you how to make a simple and easy to use <strong><em><u>"] # fmt: skip
model = AutoModelForCausalLM.from_pretrained(
"gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token
def test_long_context(self):
input_text = [
'<bos><s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.'
]
EXPECTED_GENERATION = [
' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." "We are not aware of any video footage that could have been taken on board the plane," Delannoy said. "We are not aware of any video footage that could'
]
model = AutoModelForCausalLM.from_pretrained(
self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_GENERATION)


@ -34,6 +34,8 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
# used to compute the property `self.chunk_length`
"EncodecConfig": ["overlap"],
# used to compute the property `self.layers_block_type`
"RecurrentGemmaConfig": ["block_types"],
# used as in the config to define `intermediate_size`
"MambaConfig": ["expand"],
# used as `self.bert_model = BertModel(config, ...)`


@ -86,6 +86,7 @@ PRIVATE_MODELS = [
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"RecurrentGemmaModel", # Building part of bigger (tested) model.
"FuyuForCausalLM", # Not tested fort now
"InstructBlipQFormerModel", # Building part of bigger (tested) model.
"UMT5EncoderModel", # Building part of bigger (tested) model.


@ -768,6 +768,7 @@ src/transformers/models/rag/modeling_tf_rag.py
src/transformers/models/rag/retrieval_rag.py
src/transformers/models/realm/modeling_realm.py
src/transformers/models/realm/retrieval_realm.py
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
src/transformers/models/regnet/configuration_regnet.py
src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py