diff --git a/README.md b/README.md index 8518fc09dc..bb06321f12 100644 --- a/README.md +++ b/README.md @@ -476,6 +476,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. diff --git a/README_de.md b/README_de.md index 66fae98706..3eb2c63b03 100644 --- a/README_de.md +++ b/README_de.md @@ -472,6 +472,7 @@ Aktuelle Anzahl der Checkpoints: ![](https://img.shields.io/endpoint?url=https:/ 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. diff --git a/README_es.md b/README_es.md index e4f4dedc3e..202411bbf8 100644 --- a/README_es.md +++ b/README_es.md @@ -449,6 +449,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. diff --git a/README_fr.md b/README_fr.md index c2da27829e..7d6d2e765a 100644 --- a/README_fr.md +++ b/README_fr.md @@ -470,6 +470,7 @@ Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=h 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (de l'équipe Qwen, Alibaba Group) a été publié avec le rapport technique [blog post](https://qwenlm.github.io/blog/qwen-moe/) par Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (de Facebook) a été publié dans l'article [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) par Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (de Google Research) a été publié dans l'article [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) par Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat et Ming-Wei Chang. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (de Google) publié dans l'article [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) parthe Griffin, RLHF and Gemma Teams. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (de Google Research) a été publié dans l'article [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) par Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (de META Platforms) a été publié dans l'article [Designing Network Design Space](https://arxiv.org/abs/2003.13678) par Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (de Google Research) a été publié dans l'article [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) par Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. diff --git a/README_hd.md b/README_hd.md index a5bd56ee1c..33ae62baf0 100644 --- a/README_hd.md +++ b/README_hd.md @@ -423,6 +423,7 @@ conda install conda-forge::transformers 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (the Qwen team, Alibaba Group से) Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. द्वाराअनुसंधान पत्र [blog post](https://qwenlm.github.io/blog/qwen-moe/) के साथ जारी किया गया 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (फेसबुक से) साथ में कागज [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) पैट्रिक लुईस, एथन पेरेज़, अलेक्जेंड्रा पिक्टस, फैबियो पेट्रोनी, व्लादिमीर कारपुखिन, नमन गोयल, हेनरिक कुटलर, माइक लुईस, वेन-ताउ यिह, टिम रॉकटाशेल, सेबस्टियन रिडेल, डौवे कीला द्वारा। 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (Google अनुसंधान से) केल्विन गु, केंटन ली, ज़ोरा तुंग, पानुपोंग पसुपत और मिंग-वेई चांग द्वारा साथ में दिया गया पेपर [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909)। +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (Google से) the Griffin, RLHF and Gemma Teams. द्वाराअनुसंधान पत्र [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) के साथ जारी किया गया 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (META रिसर्च से) [Designing Network Design Space](https://arxiv.org/abs/2003.13678) पेपर के साथ जारी किया गया एब्स/2003.13678) इलिजा राडोसावोविक, राज प्रतीक कोसाराजू, रॉस गिर्शिक, कैमिंग ही, पिओटर डॉलर द्वारा। 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (गूगल रिसर्च से) साथ वाला पेपर [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) ह्युंग वोन चुंग, थिबॉल्ट फ़ेवरी, हेनरी त्साई, एम. जॉनसन, सेबेस्टियन रुडर द्वारा। diff --git a/README_ja.md b/README_ja.md index e42a5680a7..dbe2cecdbd 100644 --- a/README_ja.md +++ b/README_ja.md @@ -483,6 +483,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (the Qwen team, Alibaba Group から) Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. から公開された研究論文 [blog post](https://qwenlm.github.io/blog/qwen-moe/) 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (Facebook から) Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela から公開された研究論文: [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (Google Research から) Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang から公開された研究論文: [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (Google から) the Griffin, RLHF and Gemma Teams. から公開された研究論文 [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (Google Research から) Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya から公開された研究論文: [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (META Platforms から) Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár から公開された研究論文: [Designing Network Design Space](https://arxiv.org/abs/2003.13678) 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (Google Research から) Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder から公開された研究論文: [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) diff --git a/README_ko.md b/README_ko.md index 95cb1b0b79..547572c4f2 100644 --- a/README_ko.md +++ b/README_ko.md @@ -398,6 +398,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (the Qwen team, Alibaba Group 에서 제공)은 Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.의 [blog post](https://qwenlm.github.io/blog/qwen-moe/)논문과 함께 발표했습니다. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (Facebook 에서) Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela 의 [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) 논문과 함께 발표했습니다. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (Google Research 에서) Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang 의 [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) 논문과 함께 발표했습니다. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (Google 에서 제공)은 the Griffin, RLHF and Gemma Teams.의 [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf)논문과 함께 발표했습니다. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (Google Research 에서) Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 의 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 논문과 함께 발표했습니다. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (META Research 에서) Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár 의 [Designing Network Design Space](https://arxiv.org/abs/2003.13678) 논문과 함께 발표했습니다. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (Google Research 에서) Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 의 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) 논문과 함께 발표했습니다. diff --git a/README_pt-br.md b/README_pt-br.md index 7d10ce5e8c..ad3e6bb90d 100644 --- a/README_pt-br.md +++ b/README_pt-br.md @@ -481,6 +481,7 @@ Número atual de pontos de verificação: ![](https://img.shields.io/endpoint?ur 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. diff --git a/README_ru.md b/README_ru.md index 03cc1b919e..7c5cb5f179 100644 --- a/README_ru.md +++ b/README_ru.md @@ -471,6 +471,7 @@ conda install conda-forge::transformers 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. diff --git a/README_te.md b/README_te.md index fa762d5659..90a443e434 100644 --- a/README_te.md +++ b/README_te.md @@ -473,6 +473,7 @@ Flax, PyTorch లేదా TensorFlow యొక్క ఇన్‌స్టా 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. diff --git a/README_vi.md b/README_vi.md index e7cad3b364..98114bf0a8 100644 --- a/README_vi.md +++ b/README_vi.md @@ -472,6 +472,7 @@ Số lượng điểm kiểm tra hiện tại: ![](https://img.shields.io/endpoi 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (từ the Qwen team, Alibaba Group) được phát hành với bài báo [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (từ Facebook) được phát hành với bài báo [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (từ Google Research) được phát hành với bài báo [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (từ Google) được phát hành với bài báo [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (từ Google Research) được phát hành với bài báo [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (từ META Platforms) được phát hành với bài báo [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (từ Google Research) được phát hành với bài báo [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. diff --git a/README_zh-hans.md b/README_zh-hans.md index 8ac4d5c388..bd7457b25f 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -422,6 +422,7 @@ conda install conda-forge::transformers 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (来自 the Qwen team, Alibaba Group) 伴随论文 [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou 发布. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (来自 Facebook) 伴随论文 [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) 由 Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela 发布。 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (来自 Google Research) 伴随论文 [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) 由 Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang 发布。 +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (来自 Google) 伴随论文 [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) 由 the Griffin, RLHF and Gemma Teams 发布。 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (来自 Google Research) 伴随论文 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 由 Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 发布。 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Research) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (来自 Google Research) 伴随论文 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) 由 Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 0e31d8e6b5..1eb46fe700 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -434,6 +434,7 @@ conda install conda-forge::transformers 1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen-moe/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou. 1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. +1. **[RecurrentGemma](https://huggingface.co/docs/transformers/main/model_doc/recurrent-gemma)** (from Google) released with the paper [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RegNet](https://huggingface.co/docs/transformers/model_doc/regnet)** (from META Research) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index af44de4d10..7daf91c99d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -468,6 +468,8 @@ title: RAG - local: model_doc/realm title: REALM + - local: model_doc/recurrent_gemma + title: RecurrentGemma - local: model_doc/reformer title: Reformer - local: model_doc/rembert diff --git a/docs/source/en/index.md b/docs/source/en/index.md index ffa9ae3f4b..9c5c87d00f 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -243,6 +243,7 @@ Flax), PyTorch, and/or TensorFlow. | [Qwen2MoE](model_doc/qwen2_moe) | ✅ | ❌ | ❌ | | [RAG](model_doc/rag) | ✅ | ✅ | ❌ | | [REALM](model_doc/realm) | ✅ | ❌ | ❌ | +| [RecurrentGemma](model_doc/recurrent_gemma) | ✅ | ❌ | ❌ | | [Reformer](model_doc/reformer) | ✅ | ❌ | ❌ | | [RegNet](model_doc/regnet) | ✅ | ✅ | ✅ | | [RemBERT](model_doc/rembert) | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/recurrent_gemma.md b/docs/source/en/model_doc/recurrent_gemma.md new file mode 100644 index 0000000000..35a8ce9e3a --- /dev/null +++ b/docs/source/en/model_doc/recurrent_gemma.md @@ -0,0 +1,48 @@ + + +# RecurrentGemma + +## Overview + +The Recurrent Gemma model was proposed in [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://storage.googleapis.com/deepmind-media/gemma/recurrentgemma-report.pdf) by the Griffin, RLHF and Gemma Teams of Google. + +The abstract from the paper is the following: + +*We introduce RecurrentGemma, an open language model which uses Google’s novel Griffin architecture. Griffin combines linear recurrences with local attention to achieve excellent performance on language. It has a fixed-sized state, which reduces memory use and enables efficient inference on long sequences. We provide a pre-trained model with 2B non-embedding parameters, and an instruction tuned variant. Both models achieve comparable performance to Gemma-2B despite being trained on fewer tokens.* + +Tips: + +- The original checkpoints can be converted using the conversion script `src/transformers/models/recurrent_gemma/convert_recurrent_gemma_weights_to_hf.py` + +This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). + + +## RecurrentGemmaConfig + +[[autodoc]] RecurrentGemmaConfig + + +## RecurrentGemmaModel + +[[autodoc]] RecurrentGemmaModel + - forward + +## RecurrentGemmaForCausalLM + +[[autodoc]] RecurrentGemmaForCausalLM + - forward + diff --git a/docs/source/en/tasks/language_modeling.md b/docs/source/en/tasks/language_modeling.md index e1858ef248..88a4519271 100644 --- a/docs/source/en/tasks/language_modeling.md +++ b/docs/source/en/tasks/language_modeling.md @@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the Choose one of the following architectures: -[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [Cohere](../model_doc/cohere), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MusicGen Melody](../model_doc/musicgen_melody), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) +[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [Cohere](../model_doc/cohere), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MusicGen Melody](../model_doc/musicgen_melody), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [RecurrentGemma](../model_doc/recurrent_gemma), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index da29d77972..06a6a0859b 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -743,6 +743,7 @@ _import_structure = { "RealmConfig", "RealmTokenizer", ], + "models.recurrent_gemma": ["RecurrentGemmaConfig"], "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"], "models.regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"], "models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"], @@ -3115,6 +3116,13 @@ else: "load_tf_weights_in_realm", ] ) + _import_structure["models.recurrent_gemma"].extend( + [ + "RecurrentGemmaForCausalLM", + "RecurrentGemmaModel", + "RecurrentGemmaPreTrainedModel", + ] + ) _import_structure["models.reformer"].extend( [ "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -5625,6 +5633,7 @@ if TYPE_CHECKING: RealmConfig, RealmTokenizer, ) + from .models.recurrent_gemma import RecurrentGemmaConfig from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig from .models.regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig @@ -7687,6 +7696,11 @@ if TYPE_CHECKING: RealmScorer, load_tf_weights_in_realm, ) + from .models.recurrent_gemma import ( + RecurrentGemmaForCausalLM, + RecurrentGemmaModel, + RecurrentGemmaPreTrainedModel, + ) from .models.reformer import ( REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, ReformerAttention, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 0599d3b876..4a5cd01add 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -187,6 +187,7 @@ from . import ( qwen2_moe, rag, realm, + recurrent_gemma, reformer, regnet, rembert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index bf46066002..0d5d9ae62b 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -198,6 +198,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("qwen2_moe", "Qwen2MoeConfig"), ("rag", "RagConfig"), ("realm", "RealmConfig"), + ("recurrent_gemma", "RecurrentGemmaConfig"), ("reformer", "ReformerConfig"), ("regnet", "RegNetConfig"), ("rembert", "RemBertConfig"), @@ -471,6 +472,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("qwen2_moe", "Qwen2MoE"), ("rag", "RAG"), ("realm", "REALM"), + ("recurrent_gemma", "RecurrentGemma"), ("reformer", "Reformer"), ("regnet", "RegNet"), ("rembert", "RemBERT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 150dea04f3..6f3d9d17a3 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -183,6 +183,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("qdqbert", "QDQBertModel"), ("qwen2", "Qwen2Model"), ("qwen2_moe", "Qwen2MoeModel"), + ("recurrent_gemma", "RecurrentGemmaModel"), ("reformer", "ReformerModel"), ("regnet", "RegNetModel"), ("rembert", "RemBertModel"), @@ -469,6 +470,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("qdqbert", "QDQBertLMHeadModel"), ("qwen2", "Qwen2ForCausalLM"), ("qwen2_moe", "Qwen2MoeForCausalLM"), + ("recurrent_gemma", "RecurrentGemmaForCausalLM"), ("reformer", "ReformerModelWithLMHead"), ("rembert", "RemBertForCausalLM"), ("roberta", "RobertaForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 4bc5d81053..af30469f9c 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -363,6 +363,13 @@ else: ), ("rag", ("RagTokenizer", None)), ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)), + ( + "recurrent_gemma", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "reformer", ( diff --git a/src/transformers/models/recurrent_gemma/__init__.py b/src/transformers/models/recurrent_gemma/__init__.py new file mode 100644 index 0000000000..3ac7ff1c99 --- /dev/null +++ b/src/transformers/models/recurrent_gemma/__init__.py @@ -0,0 +1,59 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_recurrent_gemma": ["RecurrentGemmaConfig"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_recurrent_gemma"] = [ + "RecurrentGemmaForCausalLM", + "RecurrentGemmaModel", + "RecurrentGemmaPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_recurrent_gemma import RecurrentGemmaConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_recurrent_gemma import ( + RecurrentGemmaForCausalLM, + RecurrentGemmaModel, + RecurrentGemmaPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py new file mode 100644 index 0000000000..f5a3f9673a --- /dev/null +++ b/src/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" RecurrentGemma model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class RecurrentGemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RecurrentGemmaModel`]. It is used to instantiate a RecurrentGemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the RecurrentGemma-7B. + + e.g. [google/recurrentgemma-2b](https://huggingface.co/google/recurrentgemma-2b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + num_hidden_layers (`int`, *optional*, defaults to 26): + The number of hidden layers in the model. + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the RecurrentGemma model. Defines the number of + different tokens that can be represented by the + `inputs_ids` passed when calling [`RecurrentGemmaModel`] + hidden_size (`int`, *optional*, defaults to 2560): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 7680): + Dimension of the MLP representations. + num_attention_heads (`int`, *optional*, defaults to 10): + The number of heads for the attention block and the number of + heads/blocks for the block-diagonal layers used in the RG-LRU gates. + This number must divide `hidden_size` and `lru_width`. + lru_width (`int` or `None`, *optional*): + Dimension of the hidden representations of the RG-LRU. If `None` + this will be set to `hidden_size`. + Whether to scale the output of the embeddings by `sqrt(hidden_size)`. + attention_window_size (`int`, *optional*, defaults to 2048): + The size of the attention window used in the attention block. + conv1d_width (`int`, *optional*, defaults to 4): + The kernel size of conv1d layers used in the recurrent blocks. + logits_soft_cap (`float`, *optional*, defaults to 30.0): + The value at which the logits should be soft-capped to after the transformer and LM-head computation in the Causal LM architecture. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether the model should return the last key/values + attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + hidden_activation (``str` or `function``, *optional*, defaults to `"gelu_pytorch_tanh"`): + The hidden activation used in the recurrent block as well as the MLP layer of the decoder layers. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + The partial rotary factor used in the initialization of the rotary embeddings. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + block_types (`List[str]`, *optional*, defaults to `('recurrent', 'recurrent', 'attention')`): + List of aleternating blocks that will be repeated to initialize the `temporal_block` layer. + attention_dropout (`float`, *optional*, defaults to 0.0): dropout value to use after the attention softmax. + num_key_value_heads (`16`, *optional*, defaults to 16): Number of key value heads to use GQA. + attention_bias (`bool`, *optional*, defaults to `False`): whether or not the linear q,k,v of the Attention layer should have bias + w_init_variance_scale (`float`, *optional*, defaults to 0.01): weight initialization variance. + ```python + >>> from transformers import RecurrentGemmaModel, RecurrentGemmaConfig + + >>> # Initializing a RecurrentGemma recurrentgemma-2b style configuration + >>> configuration = RecurrentGemmaConfig() + + >>> # Initializing a model from the recurrentgemma-2b style configuration + >>> model = RecurrentGemmaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "recurrent_gemma" + + def __init__( + self, + num_hidden_layers=26, + vocab_size=256000, + hidden_size=2560, + intermediate_size=3 * 2560, + num_attention_heads=10, + lru_width=None, + attention_window_size=2048, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + hidden_activation="gelu_pytorch_tanh", + partial_rotary_factor=0.5, + rope_theta=10000.0, + block_types=("recurrent", "recurrent", "attention"), + attention_dropout=0.0, + num_key_value_heads=None, + attention_bias=False, + w_init_variance_scale=0.01, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.attention_window_size = attention_window_size + self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.partial_rotary_factor = partial_rotary_factor + self.block_types = list(block_types) + self.hidden_activation = hidden_activation + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.w_init_variance_scale = w_init_variance_scale + self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + @property + def layers_block_type(self): + return (self.block_types * 100)[: self.num_hidden_layers] diff --git a/src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py b/src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py new file mode 100644 index 0000000000..dc6619e217 --- /dev/null +++ b/src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py @@ -0,0 +1,222 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +import warnings + +import torch +from accelerate import init_empty_weights + +from transformers import GemmaTokenizer, RecurrentGemmaConfig, RecurrentGemmaForCausalLM + + +try: + from transformers import GemmaTokenizerFast +except ImportError as e: + warnings.warn(e) + warnings.warn( + "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" + ) + GemmaTokenizerFast = None + +import regex as re + + +""" +Sample usage: + +``` +python src/transformers/models/gemma/convert_gemma_weights_to_hf.py \ + --input_dir /path/to/downloaded/gemma/weights --model_size 7B --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import GemmaForCausalLM, GemmaTokenizerFast + +model = GemmaForCausalLM.from_pretrained("/output/path") +tokenizer = GemmaTokenizerFast.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). +""" + +gemma_2b_config = RecurrentGemmaConfig( + num_attention_heads=10, + num_key_value_heads=1, + hidden_size=2560, + intermediate_size=15360, + vocab_size=256000, + num_hidden_layers=26, +) + +gemma_7b_config = RecurrentGemmaConfig() + +CONFIG_MAPPING = {"2B": gemma_2b_config, "7B": gemma_7b_config} +LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"} + + +def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False, dtype=torch.float32): + print(f"Fetching all parameters from the checkpoint at '{input_base_path}'") + model_state_dict = torch.load(input_base_path, map_location="cpu") + + REPLACEMENT = { + "blocks.": "layers.", + ".ffw_down.b": ".down_proj.b", + ".ffw_down.w": ".down_proj.w", + ".ffw_up.b": ".up_proj.bias", + ".ffw_up.w": ".up_proj.weight", + "recurrent_block": "temporal_block", + "attention_block": "temporal_block", + "temporal_block.proj_final": "temporal_block.out_proj", + "norm.scale": "norm.weight", + ".proj_k": ".k_proj", + ".proj_q": ".q_proj", + ".proj_v": ".v_proj", + ".proj_final": ".o_proj", + "embedder.input_embedding": "embed_tokens.weight", + "conv_1d.w": "conv_1d.weight", + "conv_1d.b": "conv_1d.bias", + "input_gate.w": "input_gate.weight", + "input_gate.b": "input_gate.bias", + "a_param": "recurrent_param", + "a_gate.b": "recurrent_gate.bias", + "a_gate.w": "recurrent_gate.weight", + } + + state_dict = {} + for k, v in model_state_dict.items(): + k = "model." + k + pattern = re.compile("|".join(map(re.escape, REPLACEMENT.keys()))) + key = pattern.sub(lambda match: REPLACEMENT[match.group(0)], k) + if "conv_1d.weight" in key: + v = v[:, None, :].transpose(0, 2) + if "up_proj.weight" in key: + state_dict[key.replace("up_proj", "gate_proj")] = v[0].T.contiguous() + v = v[1].T.contiguous() + if "up_proj.bias" in key: + state_dict[key.replace("up_proj", "gate_proj")] = v[0, 0, 0].clone() + v = v[1, 0, 0].contiguous() + if "recurrent_gate.bias" in key: + state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone() + elif "recurrent_gate.weight" in key: + state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone() + elif "input_gate.b" in key: + state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone() + elif "input_gate.w" in key: + state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone() + elif "embed_tokens" in key: + state_dict[key] = v[: config.vocab_size, :].contiguous().clone() + state_dict["lm_head.weight"] = v[: config.vocab_size, :].contiguous().clone() + else: + state_dict[key] = v.contiguous() + + torch.set_default_dtype(dtype) + + print("Loading the checkpoint in a Gemma model.") + with init_empty_weights(): + model = RecurrentGemmaForCausalLM(config) + model.load_state_dict(state_dict, assign=True, strict=True) + + model.config.torch_dtype = torch.float32 + del model.config._name_or_path + print("Saving in the Transformers format.") + + if push_to_hub: + print(f"pushing the model to {save_path}") + else: + model.save_pretrained(save_path, safe_serialization=safe_serialization) + + +def write_tokenizer(input_tokenizer_path, save_path, push_to_hub=False): + # Initialize the tokenizer based on the `spm` model + tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast + print(f"Saving a {tokenizer_class.__name__} to {save_path}.") + tokenizer = tokenizer_class(input_tokenizer_path) + if push_to_hub: + tokenizer.push_to_hub(save_path) + else: + tokenizer.save_pretrained(save_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_checkpoint", + help="Absolute path to the target Gemma weights.", + default="/home/arthur/transformers_recurrentgemma/google/recurrent-gemma-2b-it/ToBeDeleted/2b-it.pt", + ) + parser.add_argument( + "--tokenizer_checkpoint", + help="Location of Gemma tokenizer model", + ) + parser.add_argument( + "--model_size", + default="2B", + choices=["2B", "7B", "tokenizer_only"], + help="'f' models correspond to the finetuned versions, and are specific to the Gemma2 official release. For more details on Gemma2, checkout the original repo: https://huggingface.co/google/gemma-7b", + ) + parser.add_argument( + "--output_dir", + default="google/recurrent-gemma-2b-it-hf", + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--pickle_serialization", + help="Whether or not to save using `safetensors`.", + action="store_true", + default=False, + ) + parser.add_argument( + "--convert_tokenizer", + help="Whether or not to convert the tokenizer as well.", + action="store_true", + default=False, + ) + parser.add_argument( + "--push_to_hub", + help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.", + action="store_true", + default=False, + ) + parser.add_argument( + "--dtype", + default="float32", + help="Target dtype of the converted model", + ) + args = parser.parse_args() + + if args.convert_tokenizer: + if args.tokenizer_checkpoint is None: + raise ValueError("Path to the tokenizer is required when passing --convert_tokenizer") + + spm_path = os.path.join(args.tokenizer_checkpoint) + write_tokenizer(spm_path, args.output_dir, args.push_to_hub) + + config = CONFIG_MAPPING[args.model_size] + dtype = getattr(torch, args.dtype) + write_model( + config=config, + input_base_path=args.input_checkpoint, + save_path=args.output_dir, + safe_serialization=not args.pickle_serialization, + push_to_hub=args.push_to_hub, + dtype=dtype, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py new file mode 100644 index 0000000000..26cdc437d0 --- /dev/null +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -0,0 +1,938 @@ +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch RecurrentGemma model.""" + +import math +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_recurrent_gemma import RecurrentGemmaConfig + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "RecurrentGemmaConfig" +_MAX_SQRT_GRADIENT = 1000.0 + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRMSNorm with Gemma->RecurrentGemma +class RecurrentGemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst RecurrentGemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + +ALL_LAYERNORM_LAYERS.append(RecurrentGemmaRMSNorm) + + +class RecurrentGemmaRotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000, device=None): + super().__init__() + self.dim = dim + self.base = base + self.register_buffer("inv_freq", None, persistent=False) + + @torch.no_grad() + # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class RecurrentGemmaSdpaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: RecurrentGemmaConfig): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + self.partial_rotary_factor = config.partial_rotary_factor + + self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.rotary_emb = RecurrentGemmaRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + base=config.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + + # Partial rotary embedding + query_rot, query_pass = torch.chunk(query_states, int(1 / self.partial_rotary_factor), dim=-1) + key_rot, key_pass = torch.chunk(key_states, int(1 / self.partial_rotary_factor), dim=-1) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if use_cache and hasattr(self, "key_states"): + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + attn_mask=causal_mask, # pretty much a must for sliding window backend! + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.head_dim**-0.5, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output + + def _setup_cache(self, batch_size, device, dtype=None): + if dtype is None and self.config.torch_dtype is not None: + dtype = self.config.torch_dtype + dtype = dtype if dtype is not None else torch.float32 + cache_shape = (batch_size, self.num_key_value_heads, self.config.attention_window_size, self.head_dim) + self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device) + self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device) + + @torch.no_grad() + def _update_cache(self, key_states, value_states, **cache_kwargs): + """ + torch.compile compatible sliding window. + Computes the `indices` based on `cache_position >= self.config.attention_window_size - 1`. + The `to_shift` is only true once we are above attention_window_size. Thus with `attention_window_size==64`: + + indices = (slicing + to_shift[-1].int()-1) % self.config.attention_window_size + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `attention_window_size`) + """ + cache_position = cache_kwargs.get("cache_position") + if cache_position.shape[0] > self.config.attention_window_size: + # int indexing -> device sync? in compile, use tensor + k_out = key_states[:, :, -self.config.attention_window_size :, :] + v_out = value_states[:, :, -self.config.attention_window_size :, :] + else: + slicing = torch.ones( + self.config.attention_window_size, dtype=torch.long, device=value_states.device + ).cumsum(0) + cache_position = cache_position.clamp(0, self.config.attention_window_size - 1) + to_shift = cache_position >= self.config.attention_window_size - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size + + k_out, v_out = self.key_states, self.value_states + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_states, self.value_states = k_out, v_out + return k_out, v_out + + +class SqrtBoundDerivative(torch.autograd.Function): + """Computes a square root with a gradient clipped at `_MAX_SQRT_GRADIENT`.""" + + @staticmethod + def forward(ctx, x: torch.Tensor) -> torch.Tensor: + """The forward pass, which is a normal `sqrt`.""" + ctx.save_for_backward(x) + return torch.sqrt(x) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """The backward pass, which clips the `sqrt` gradient.""" + (x,) = ctx.saved_tensors + clipped_x_times_4 = torch.clip(4.0 * x, min=1 / (_MAX_SQRT_GRADIENT**2)) + return grad_output / torch.sqrt(clipped_x_times_4) + + +class RecurrentGemmaRglru(nn.Module): + """A Real-Gated Linear Recurrent Unit (RG-LRU) layer.""" + + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.block_width = config.lru_width // self.num_attention_heads + + self.recurrent_param = nn.Parameter(torch.empty([config.lru_width])) + self.input_gate_weight = nn.Parameter( + torch.empty([self.num_attention_heads, self.block_width, self.block_width]) + ) + self.input_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width])) + + self.recurrent_gate_weight = nn.Parameter( + torch.empty([self.num_attention_heads, self.block_width, self.block_width]) + ) + self.recurrent_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width])) + self.recurrent_states = None + + def forward( + self, + activations: torch.Tensor, + position_ids: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_len, lru_width = activations.shape + reset = position_ids[:, :, None] == 0 + + reshape_act = activations.reshape(batch_size * seq_len, self.num_attention_heads, self.block_width) + reshape_act = reshape_act.permute(1, 0, 2) + + res = torch.baddbmm(self.input_gate_bias[:, None, :], reshape_act, self.input_gate_weight) + input_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width)) + + res = torch.baddbmm(self.recurrent_gate_bias[:, None, :], reshape_act, self.recurrent_gate_weight) + recurrent_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width)) + + # Compute the parameter `A` of the recurrence. + log_recurrent_gate = -8.0 * recurrent_gate * nn.functional.softplus(self.recurrent_param) + recurrent_gate = torch.exp(log_recurrent_gate) + a_square = torch.exp(2 * log_recurrent_gate) + + # Gate the input. + gated_inputs = activations * input_gate + + # Apply gamma normalization to the input. We need to clip the derivatives of + # `sqrt` in order to prevent NaNs during training in bfloat16. TODO a bit annoying + multiplier = 1 + tracing = isinstance(activations, torch.fx.Proxy) or ( + hasattr(torch, "_dynamo") and torch._dynamo.is_compiling() + ) + if not torch.jit.is_tracing() and not tracing: + multiplier = SqrtBoundDerivative.apply(1 - a_square) + multiplier = reset + ~reset * multiplier + normalized_x = gated_inputs * multiplier.type(activations.dtype) + + hidden_states, recurrent_states = self._rnn_scan( + hidden_states=normalized_x, + recurrent_gate=recurrent_gate, + reset=reset, + recurrent_states=self.recurrent_states, + ) + self.recurrent_states = recurrent_states + return hidden_states + + # TODO refactor + def _rnn_scan( + self, + hidden_states: torch.Tensor, + recurrent_gate: torch.Tensor, + reset: torch.Tensor, + recurrent_states: Union[torch.Tensor, None], + acc_dtype: torch.dtype = torch.float32, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Runs the recurrence of a linear RNN. + + Args: + hidden_states: The input sequence. + recurrent_gate: The diagonal of the recurrence matrix `A`. + reset: Indicator of document boundaries, e.g. when to reset the hidden state + of the RNN. + recurrent_states: The initial hidden state. + acc_dtype: The data type for the accumulation. + + Returns: + The output of the linear recurrence. + """ + # Multiply `a` by the reset. + recurrent_gate = recurrent_gate * ~reset + + if hidden_states.shape[1] == 1: + # Using scan in sampling mode. + if recurrent_states is None: # same here, when decoding you always have cache + return hidden_states, hidden_states[:, 0].type(acc_dtype) + + else: + contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None] + contextualized_states += hidden_states.type(acc_dtype) + return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1] + + else: + # Using scan in linear mode. + if recurrent_states is None: + recurrent_states = torch.zeros(hidden_states[:, 0].shape, dtype=acc_dtype, device=hidden_states.device) + + contextualized_states = torch.zeros_like(hidden_states) + for t in range(hidden_states.shape[1]): + recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states + recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype) + contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype) + + return contextualized_states, recurrent_states + + +class RecurrentGemmaRecurrentBlock(nn.Module): + """Griffin and Hawk's recurrent block.""" + + def __init__(self, config): + super().__init__() + self.lru_width = config.lru_width + self.hidden_size = config.hidden_size + self.linear_y = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width) + self.linear_x = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width) + self.linear_out = nn.Linear(in_features=config.lru_width, out_features=config.hidden_size) + self.conv1d_width = config.conv1d_width + self.conv_1d = nn.Conv1d( + config.lru_width, + config.lru_width, + kernel_size=config.conv1d_width, + groups=config.lru_width, + padding=config.conv1d_width - 1, + ) + self.rg_lru = RecurrentGemmaRglru(config) + self.act_fn = ACT2FN[config.hidden_activation] + + self.conv1d_state = None + + def forward( + self, + input_states: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + cache_position: torch.Tensor, + use_cache: bool = True, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + _, seq_len, _ = input_states.shape + + y_branch = self.linear_y(input_states) + y_branch = self.act_fn(y_branch) + + x_branch = self.linear_x(input_states) + x_branch = x_branch.transpose(1, 2) + + if use_cache: + if cache_position.shape[0] != 1: # prefill + self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0)) + x_branch = self.conv_1d(x_branch)[..., :seq_len] + else: # decoding + conv_state = torch.cat((self.conv1d_state, x_branch), -1) + x_branch = torch.sum(conv_state * self.conv_1d.weight[:, 0, :], dim=-1) + self.conv_1d.bias + x_branch = x_branch.unsqueeze(-1) + self.conv1d_state = conv_state[:, :, 1:] + else: + x_branch = self.conv_1d(x_branch)[..., :seq_len] + + x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids) + + hidden_states = x_branch * y_branch + hidden_states = self.linear_out(hidden_states) + return hidden_states + + def _setup_cache(self, batch, device, dtype): + # recurrent_states always computed in full precision + self.rg_lru.recurrent_states = torch.zeros((batch, self.lru_width), device=device, dtype=torch.float32) + self.conv1d_state = torch.zeros((batch, self.hidden_size, self.conv1d_width - 1), device=device, dtype=dtype) + + +TEMPORAL_BLOCK_CLASSES = {"recurrent": RecurrentGemmaRecurrentBlock, "attention": RecurrentGemmaSdpaAttention} + + +class RecurrentGemmaMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // 2 + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, hidden_states): + gate = self.act_fn(self.gate_proj(hidden_states)) + return self.down_proj(gate * self.up_proj(hidden_states)) + + +class RecurrentGemmaDecoderLayer(nn.Module): + """Griffin and Hawk's residual block.""" + + def __init__(self, config, layer_idx): + super().__init__() + self.temporal_pre_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.temporal_block = TEMPORAL_BLOCK_CLASSES[config.layers_block_type[layer_idx]](config) + self.channel_pre_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp_block = RecurrentGemmaMlp(config) + + def forward( + self, + activations: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + cache_position: torch.Tensor = None, + use_cache: bool = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raw_activations = activations + inputs_normalized = self.temporal_pre_norm(raw_activations) # RMSNorm introduces slight slight differences + + hidden_states = self.temporal_block( + inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache + ) + + residual = hidden_states + raw_activations + + hidden_states = self.channel_pre_norm(residual) + hidden_states = self.mlp_block(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states + + +RECURRENTGEMMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`RecurrentGemmaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare RecurrentGemma Model outputting raw hidden-states without any specific head on top.", + RECURRENTGEMMA_START_DOCSTRING, +) +class RecurrentGemmaPreTrainedModel(PreTrainedModel): + config_class = RecurrentGemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["RecurrentGemmaDecoderLayer"] + _skip_keys_device_placement = ["cache"] + _supports_flash_attn_2 = False + _supports_sdpa = False # we can't compare with eager for now + _supports_cache_class = True + + def _init_weights(self, module): + std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width) + if isinstance(module, nn.Conv1d): + torch.nn.init.normal_(module.weight, mean=0.0, std=std) + torch.nn.init.zeros_(module.bias) + elif isinstance(module, RecurrentGemmaSdpaAttention): + torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + + std = math.sqrt(self.config.final_w_init_variance_scale / self.config.hidden_size) + torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=std) + elif isinstance(module, RecurrentGemmaRecurrentBlock): + torch.nn.init.zeros_(module.linear_x.bias) + torch.nn.init.normal_(module.linear_x.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + + torch.nn.init.zeros_(module.linear_y.bias) + torch.nn.init.normal_(module.linear_y.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) + + std = math.sqrt(self.config.final_w_init_variance_scale / self.config.lru_width) + torch.nn.init.normal_(module.linear_out.weight, mean=0.0, std=std) + torch.nn.init.zeros_(module.linear_out.bias) + elif isinstance(module, RecurrentGemmaRglru): + std = math.sqrt( + self.config.w_init_variance_scale / (self.config.lru_width // self.config.num_attention_heads) + ) + torch.nn.init.normal_(module.input_gate_weight, mean=0.0, std=std) + torch.nn.init.normal_(module.recurrent_gate_weight, mean=0.0, std=std) + torch.nn.init.zeros_(module.input_gate_bias) + torch.nn.init.zeros_(module.recurrent_gate_bias) + + module.recurrent_param.data.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8) + module.recurrent_param.data.log_().mul_(0.5) + module.recurrent_param.data.neg_().exp_().sub_(1.0).log_() + elif isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=std) + if getattr(module, "bias", None) is not None: + torch.nn.init.zeros_(module.bias) + + def _setup_cache(self, config, batch, device, dtype): + layers = getattr(self, "model", self).layers + for layer in layers: + layer.temporal_block._setup_cache(batch, device, dtype) + + def reset_cache(self, batch, device, dtype): + pass + + +RECURRENTGEMMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare RecurrentGemma Model outputting raw hidden-states without any specific head on top.", + RECURRENTGEMMA_START_DOCSTRING, +) +class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`RecurrentGemmaDecoderLayer`] + + Args: + config: RecurrentGemmaConfig + """ + + def __init__(self, config: RecurrentGemmaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [RecurrentGemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.register_buffer("normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16)) + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings + def get_input_embeddings(self): + return self.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(RECURRENTGEMMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + if use_cache and inputs_embeds.shape[1] != 1: # TODO let's maybe only call in the `generate`? + self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + hidden_states = hidden_states * self.normalizer.type(hidden_states.dtype) + + all_hidden_states = () if output_hidden_states else None + for i, residual_block in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + residual_block.__call__, hidden_states, position_ids, causal_mask, cache_position, use_cache + ) + else: + hidden_states = residual_block(hidden_states, position_ids, causal_mask, cache_position, use_cache) + + hidden_states = self.final_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + # Ignore copy + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = max(self.config.attention_window_size, sequence_length) + + diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + causal_mask = diagonal + if sequence_length != 1: + causal_mask = torch.triu(diagonal, diagonal=-1) + + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if attention_mask is not None and attention_mask.device.type == "cuda": + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma +class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RecurrentGemmaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(RECURRENTGEMMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, CausalLMOutput]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, RecurrentGemmaForCausalLM + + >>> model = RecurrentGemmaForCausalLM.from_pretrained("google/recurrentgemma-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-2b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True + outputs = self.model( + input_ids=input_ids, + cache_position=cache_position, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + # Soft-cap the logits TODO remove if always done. + # if self.config.logits_soft_cap is not None: + cap = self.config.logits_soft_cap + logits = nn.functional.tanh(logits / cap) * cap + + logits = logits.float() + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + # Ignore copy + def prepare_inputs_for_generation( + self, input_ids, attention_mask=None, inputs_embeds=None, cache_position=None, use_cache=None, **kwargs + ): + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + + attention_mask = attention_mask[:, -self.config.attention_window_size :] + + past_length = cache_position[0] + if past_length > 0: + position_ids = position_ids[:, past_length:] + + if inputs_embeds is not None: + model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]} + else: + model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()} + + if cache_position is not None: + cache_position = cache_position[-position_ids.shape[1] :] + + model_inputs.update( + { + "position_ids": position_ids, + "attention_mask": attention_mask, + "cache_position": cache_position, + "use_cache": use_cache, + } + ) + return model_inputs + + # Ignore copy + def _reorder_cache(self, past_key_values, beam_idx): + for layer in self.layers: + if hasattr(layer.temporal_block, "key_states"): + k_state = layer.temporal_block.key_states + v_state = layer.temporal_block.value_states + k_state = k_state.index_select(0, beam_idx.to(k_state.device)) + v_state = v_state.index_select(0, beam_idx.to(v_state.device)) + return None diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1bdab80a13..1c04fb9082 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -7051,6 +7051,27 @@ def load_tf_weights_in_realm(*args, **kwargs): requires_backends(load_tf_weights_in_realm, ["torch"]) +class RecurrentGemmaForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RecurrentGemmaModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RecurrentGemmaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/recurrent_gemma/__init__.py b/tests/models/recurrent_gemma/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py new file mode 100644 index 0000000000..ae1d9e7079 --- /dev/null +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -0,0 +1,508 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch RecurrentGemma model. """ +import unittest + +from parameterized import parameterized + +from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed +from transformers.testing_utils import ( + require_bitsandbytes, + require_read_token, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import RecurrentGemmaForCausalLM, RecurrentGemmaModel + + +class RecurrentGemmaModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=12, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + num_hidden_layers=3, + vocab_size=99, + hidden_size=32, + intermediate_size=3 * 32, + num_attention_heads=2, + lru_width=2 * 32, + embeddings_scale_by_sqrt_dim=True, + attention_window_size=16, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + rope_theta=10000.0, + type_vocab_size=16, + type_sequence_label_size=2, + num_labels=3, + num_choices=4, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.embeddings_scale_by_sqrt_dim = embeddings_scale_by_sqrt_dim + self.attention_window_size = attention_window_size + self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + + # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return RecurrentGemmaConfig( + num_hidden_layers=self.num_hidden_layers, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_attention_heads=self.num_attention_heads, + lru_width=self.lru_width, + embeddings_scale_by_sqrt_dim=self.embeddings_scale_by_sqrt_dim, + attention_window_size=self.attention_window_size, + conv1d_width=self.conv1d_width, + logits_soft_cap=self.logits_soft_cap, + rms_norm_eps=self.rms_norm_eps, + use_cache=self.use_cache, + rope_theta=self.rope_theta, + pad_token_id=self.pad_token_id, + output_attentions=False, + ) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->RecurrentGemma + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = RecurrentGemmaModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->RecurrentGemma + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = RecurrentGemmaModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->RecurrentGemma + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = RecurrentGemmaForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->RecurrentGemma + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = RecurrentGemmaForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->RecurrentGemma + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (RecurrentGemmaForCausalLM,) if is_torch_available() else () + # all_generative_model_classes = (RecurrentGemmaForCausalLM,) if is_torch_available() else () #TODO @gante not fully supported + pipeline_model_mapping = ( + { + "feature-extraction": RecurrentGemmaModel, + "text-generation": RecurrentGemmaForCausalLM, + } + if is_torch_available() + else {} + ) + fx_compatible = False # FIXME let's try to support this @ArthurZucker + test_torchscript = False # FIXME let's try to support this @ArthurZucker + test_missing_keys = False + test_model_parallel = False + test_pruning = False + test_head_masking = False # RecurrentGemma does not have attention heads + test_model_parallel = False + + # Need to remove 0.9 in `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.6] + + # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + def is_pipeline_test_to_skip( + self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name + ): + return True + + def setUp(self): + # We don't output attentions + self.has_attentions = False + self.model_tester = RecurrentGemmaModelTester(self) + self.config_tester = ConfigTester(self, config_class=RecurrentGemmaConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip("Recurrent gemma does not use legacy cache") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip("RecurrentGemma does not return pkv") + def test_past_key_values_format(self): + pass + + @unittest.skip("RecurrentGemma only supports sdpa") + def test_eager_matches_sdpa_generate(self): + pass + + @unittest.skip("RecurrentGemma only supports sdpa") + def test_eager_matches_sdpa_inference(self): + pass + + @unittest.skip("RecurrentGemma does not return the cache") + def test_contrastive_generate_low_memory(self): + pass + + @unittest.skip("RecurrentGemma does not return the cache") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("RecurrentGemma does not return the cache") + def test_contrastive_generate(self): + pass + + @unittest.skip("SQRBound is known to have issues with gc") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + def _check_attentions_for_generate(self, *args, **kwargs): + return True # Model does not return attention + + @unittest.skip("Past key values are not returned") + def test_prompt_lookup_decoding_matches_greedy_search(self): + pass + + @unittest.skip("Past key values are not returned") + def test_model_parallelism(self): + pass + + @unittest.skip("Past key values are not returned") + def test_model_parallel_beam_search(self): + pass + + def _check_past_key_values_for_generate(self, *args, **kwargs): + return True + + @unittest.skip("Rely on `past_key_values` to crop the assistant pkv. Not supported") + def test_assisted_decoding_matches_greedy_search(self): + pass + + @unittest.skip("RecurrentGemma's output different if you pad left or right. This is expected") + def test_left_padding_compatibility(self): + pass + + @unittest.skip("Relies on `past_key_values` returned by the model. Not supported with recurrent gemma") + def test_assisted_decoding_sample(self): + pass + + def _check_hidden_states_for_generate( + self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], + [True] * len(hidden_states), + ) + self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + + for idx, iter_hidden_states in enumerate(hidden_states): + seq_len = min_length + idx if not use_cache else 1 + expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + # check hidden size + self.assertListEqual( + [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], + [expected_shape] * len(iter_hidden_states), + ) + + @unittest.skip("TODO @arthurzucker not super important and failing.") + def test_initialization(self): + pass + + +@require_torch_gpu +@slow +class RecurrentGemmaIntegrationTest(unittest.TestCase): + input_text = ["Hello I am doing", "Hi today"] + model_id = "google/recurrentgemma-2b" + + @require_read_token + def test_2b_generate(self): + EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a very good day for you. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do.'] # fmt: skip + model = AutoModelForCausalLM.from_pretrained(self.model_id, low_cpu_mem_usage=True).to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + tokenizer.padding_side = "right" + + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=64, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + tokenizer.padding_side = "left" + EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today I am going to share with you the best free online video editing software.\n\n

Best Free Online Video Editing Software

\n\n1. Wondershare Filmora\n\nWondershare Filmora is a free online video editing software that is used to edit videos.'] # fmt: skip + + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + output = model.generate(**inputs, max_new_tokens=64, do_sample=False) + del model + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + model = AutoModelForCausalLM.from_pretrained( + self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16 + ).to(torch_device) + output = model.generate(**inputs, max_new_tokens=64, do_sample=False) + del model + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + self.assertEqual(output_text, EXPECTED_TEXTS) + + @require_read_token + def test_2b_sample(self): + set_seed(0) + EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write P if the underlined word group is a phrase and NP if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip + model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + inputs = tokenizer("Where is Paris ?", return_tensors="pt", padding=True).to(torch_device) + output = model.generate(**inputs, max_new_tokens=128, do_sample=True) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXT) + + @require_bitsandbytes + @require_read_token + def test_model_2b_8bit(self): + EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to use "] # fmt: skip + + model = AutoModelForCausalLM.from_pretrained( + "gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16 + ) + + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + @require_read_token + def test_long_context(self): + input_text = [ + 'Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.' + ] + EXPECTED_GENERATION = [ + ' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." "We are not aware of any video footage that could have been taken on board the plane," Delannoy said. "We are not aware of any video footage that could' + ] + + model = AutoModelForCausalLM.from_pretrained( + self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16 + ).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + inputs = tokenizer(input_text, return_tensors="pt").to(torch_device) + output = model.generate(**inputs, max_new_tokens=64, do_sample=False) + output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) + self.assertEqual(output_text, EXPECTED_GENERATION) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 140cd560e0..736ee681c3 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -34,6 +34,8 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING SPECIAL_CASES_TO_ALLOW = { # used to compute the property `self.chunk_length` "EncodecConfig": ["overlap"], + # used to compute the property `self.layers_block_type` + "RecurrentGemmaConfig": ["block_types"], # used as in the config to define `intermediate_size` "MambaConfig": ["expand"], # used as `self.bert_model = BertModel(config, ...)` diff --git a/utils/check_repo.py b/utils/check_repo.py index f577bf1507..13dcd6ad97 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -86,6 +86,7 @@ PRIVATE_MODELS = [ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested + "RecurrentGemmaModel", # Building part of bigger (tested) model. "FuyuForCausalLM", # Not tested fort now "InstructBlipQFormerModel", # Building part of bigger (tested) model. "UMT5EncoderModel", # Building part of bigger (tested) model. diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index b57ed3b791..4ac104ee2e 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -768,6 +768,7 @@ src/transformers/models/rag/modeling_tf_rag.py src/transformers/models/rag/retrieval_rag.py src/transformers/models/realm/modeling_realm.py src/transformers/models/realm/retrieval_realm.py +src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py src/transformers/models/regnet/configuration_regnet.py src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py