From 5dfd407b37ac683dc91637e9913b0ae9205d2acd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 2 Jun 2023 10:30:24 +0100 Subject: [PATCH] [MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 (#23813) * add fine-tuned with adapter layer * Add set_target_lang to tokenizer * Implement load adapter * add tests * make style * Apply suggestions from code review * Update src/transformers/models/wav2vec2/tokenization_wav2vec2.py * make fix-copies * Apply suggestions from code review * make fix-copies * make style again * mkae style again * fix doc string * Update tests/models/wav2vec2/test_tokenization_wav2vec2.py * Apply suggestions from code review * fix * Correct wav2vec2 adapter * mkae style * Update src/transformers/models/wav2vec2/modeling_wav2vec2.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * add more nice docs * finish * finish * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Apply suggestions from code review * all finish --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- README.md | 1 + README_es.md | 1 + README_hd.md | 1 + README_ja.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/_toctree.yml | 2 + docs/source/en/index.mdx | 1 + docs/source/en/model_doc/mms.mdx | 118 +++++++++ .../models/auto/configuration_auto.py | 1 + .../models/hubert/modeling_hubert.py | 43 +++- src/transformers/models/sew/modeling_sew.py | 9 +- .../models/sew_d/modeling_sew_d.py | 9 +- .../models/unispeech/modeling_unispeech.py | 43 +++- .../unispeech_sat/modeling_unispeech_sat.py | 43 +++- .../models/wav2vec2/configuration_wav2vec2.py | 5 + ..._original_pytorch_checkpoint_to_pytorch.py | 99 ++++++-- .../models/wav2vec2/modeling_wav2vec2.py | 233 +++++++++++++++++- .../models/wav2vec2/tokenization_wav2vec2.py | 41 ++- .../modeling_wav2vec2_conformer.py | 9 +- .../models/wavlm/modeling_wavlm.py | 9 +- .../models/wav2vec2/test_modeling_wav2vec2.py | 139 +++++++++++ .../wav2vec2/test_tokenization_wav2vec2.py | 45 ++++ 24 files changed, 823 insertions(+), 33 deletions(-) create mode 100644 docs/source/en/model_doc/mms.mdx diff --git a/README.md b/README.md index e8e3f26d0a..c6371abd7b 100644 --- a/README.md +++ b/README.md @@ -401,6 +401,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. 1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (from Alibaba Research) released with the paper [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) by Peng Wang, Cheng Da, and Cong Yao. 1. 
**[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. +1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli. 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. 1. **[MobileNetV1](https://huggingface.co/docs/transformers/model_doc/mobilenet_v1)** (from Google Inc.) released with the paper [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) by Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam. 1. **[MobileNetV2](https://huggingface.co/docs/transformers/model_doc/mobilenet_v2)** (from Google Inc.) released with the paper [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) by Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. diff --git a/README_es.md b/README_es.md index 8d971e6f30..bd503bb717 100644 --- a/README_es.md +++ b/README_es.md @@ -376,6 +376,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. 1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (from Alibaba Research) released with the paper [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) by Peng Wang, Cheng Da, and Cong Yao. 1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. +1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli. 1. 
**[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. 1. **[MobileNetV1](https://huggingface.co/docs/transformers/model_doc/mobilenet_v1)** (from Google Inc.) released with the paper [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) by Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam. 1. **[MobileNetV2](https://huggingface.co/docs/transformers/model_doc/mobilenet_v2)** (from Google Inc.) released with the paper [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) by Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. diff --git a/README_hd.md b/README_hd.md index bb94c9961a..731e9975b5 100644 --- a/README_hd.md +++ b/README_hd.md @@ -348,6 +348,7 @@ conda install -c huggingface transformers 1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (NVIDIA से) साथ वाला पेपर [Megatron-LM: ट्रेनिंग मल्टी-बिलियन पैरामीटर लैंग्वेज मॉडल्स यूजिंग मॉडल पैरेललिज़्म] (https://arxiv.org/abs/1909.08053) मोहम्मद शोएबी, मोस्टोफा पटवारी, राउल पुरी, पैट्रिक लेग्रेस्ले, जेरेड कैस्पर और ब्रायन कैटानज़ारो द्वारा पोस्ट किया गया। 1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (Alibaba Research से) Peng Wang, Cheng Da, and Cong Yao. द्वाराअनुसंधान पत्र [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) के साथ जारी किया गया 1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (फ्रॉम Studio Ousia) साथ में पेपर [mLUKE: द पावर ऑफ एंटिटी रिप्रेजेंटेशन इन मल्टीलिंगुअल प्रीट्रेन्ड लैंग्वेज मॉडल्स](https://arxiv.org/abs/2110.08151) रयोकन री, इकुया यामाडा, और योशिमासा त्सुरोका द्वारा। +1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (Facebook से) Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli. द्वाराअनुसंधान पत्र [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) के साथ जारी किया गया 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (सीएमयू/गूगल ब्रेन से) साथ में कागज [मोबाइलबर्ट: संसाधन-सीमित उपकरणों के लिए एक कॉम्पैक्ट टास्क-अज्ञेय बीईआरटी] (https://arxiv.org/abs/2004.02984) Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, और Denny Zhou द्वारा पोस्ट किया गया। 1. **[MobileNetV1](https://huggingface.co/docs/transformers/model_doc/mobilenet_v1)** (from Google Inc.) released with the paper [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) by Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam. 1. **[MobileNetV2](https://huggingface.co/docs/transformers/model_doc/mobilenet_v2)** (from Google Inc.) released with the paper [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) by Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. 
diff --git a/README_ja.md b/README_ja.md index 25a2cdb1d5..9559f4d85e 100644 --- a/README_ja.md +++ b/README_ja.md @@ -410,6 +410,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (NVIDIA から) Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro から公開された研究論文: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (Alibaba Research から) Peng Wang, Cheng Da, and Cong Yao. から公開された研究論文 [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) 1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (Studio Ousia から) Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka から公開された研究論文: [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) +1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (Facebook から) Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli. から公開された研究論文 [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (CMU/Google Brain から) Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou から公開された研究論文: [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) 1. **[MobileNetV1](https://huggingface.co/docs/transformers/model_doc/mobilenet_v1)** (Google Inc. から) Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam から公開された研究論文: [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) 1. **[MobileNetV2](https://huggingface.co/docs/transformers/model_doc/mobilenet_v2)** (Google Inc. から) Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen から公開された研究論文: [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) diff --git a/README_ko.md b/README_ko.md index 7c160536f1..f4203e6711 100644 --- a/README_ko.md +++ b/README_ko.md @@ -325,6 +325,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (NVIDIA 에서) Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 의 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 논문과 함께 발표했습니다. 1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (Alibaba Research 에서 제공)은 Peng Wang, Cheng Da, and Cong Yao.의 [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592)논문과 함께 발표했습니다. 1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (Studio Ousia 에서) Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka 의 [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) 논문과 함께 발표했습니다. +1. 
**[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (Facebook 에서 제공)은 Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli.의 [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516)논문과 함께 발표했습니다. 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (CMU/Google Brain 에서) Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou 의 [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) 논문과 함께 발표했습니다. 1. **[MobileNetV1](https://huggingface.co/docs/transformers/model_doc/mobilenet_v1)** (Google Inc. 에서) Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam 의 [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) 논문과 함께 발표했습니다. 1. **[MobileNetV2](https://huggingface.co/docs/transformers/model_doc/mobilenet_v2)** (Google Inc. 에서) Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen 의 [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index 75353804d3..2601f280e7 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -349,6 +349,7 @@ conda install -c huggingface transformers 1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。 1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (来自 Alibaba Research) 伴随论文 [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) 由 Peng Wang, Cheng Da, and Cong Yao 发布。 1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (来自 Studio Ousia) 伴随论文 [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) 由 Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka 发布。 +1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (来自 Facebook) 伴随论文 [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) 由 Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli 发布。 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (来自 CMU/Google Brain) 伴随论文 [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) 由 Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou 发布。 1. **[MobileNetV1](https://huggingface.co/docs/transformers/model_doc/mobilenet_v1)** (来自 Google Inc.) 伴随论文 [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) 由 Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam 发布。 1. 
**[MobileNetV2](https://huggingface.co/docs/transformers/model_doc/mobilenet_v2)** (来自 Google Inc.) 伴随论文 [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 由 Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 066604362c..cd37ca9d14 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -361,6 +361,7 @@ conda install -c huggingface transformers 1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. 1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (from Alibaba Research) released with the paper [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) by Peng Wang, Cheng Da, and Cong Yao. 1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. +1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli. 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. 1. **[MobileNetV1](https://huggingface.co/docs/transformers/model_doc/mobilenet_v1)** (from Google Inc.) released with the paper [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) by Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam. 1. **[MobileNetV2](https://huggingface.co/docs/transformers/model_doc/mobilenet_v2)** (from Google Inc.) released with the paper [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) by Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index dfe899032e..8b43fb5177 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -543,6 +543,8 @@ title: Hubert - local: model_doc/mctct title: MCTCT + - local: model_doc/mms + title: MMS - local: model_doc/sew title: SEW - local: model_doc/sew-d diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index b7442050c2..edf5bc00b3 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -162,6 +162,7 @@ The documentation is organized into five sections: 1. 
**[Megatron-GPT2](model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
1. **[MGP-STR](model_doc/mgp-str)** (from Alibaba Research) released with the paper [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) by Peng Wang, Cheng Da, and Cong Yao.
1. **[mLUKE](model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
+1. **[MMS](model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli.
1. **[MobileBERT](model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou.
1. **[MobileNetV1](model_doc/mobilenet_v1)** (from Google Inc.) released with the paper [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) by Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam.
1. **[MobileNetV2](model_doc/mobilenet_v2)** (from Google Inc.) released with the paper [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) by Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen.
diff --git a/docs/source/en/model_doc/mms.mdx b/docs/source/en/model_doc/mms.mdx
new file mode 100644
index 0000000000..bd32617370
--- /dev/null
+++ b/docs/source/en/model_doc/mms.mdx
@@ -0,0 +1,118 @@
+
+
+# MMS
+
+## Overview
+
+The MMS model was proposed in [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516)
+by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli.
+
+The abstract from the paper is the following:
+
+*Expanding the language coverage of speech technology has the potential to improve access to information for many more people.
+However, current speech technology is restricted to about one hundred languages which is a small fraction of the over 7,000
+languages spoken around the world.
+The Massively Multilingual Speech (MMS) project increases the number of supported languages by 10-40x, depending on the task.
+The main ingredients are a new dataset based on readings of publicly available religious texts and effectively leveraging
+self-supervised learning. We built pre-trained wav2vec 2.0 models covering 1,406 languages,
+a single multilingual automatic speech recognition model for 1,107 languages, speech synthesis models
+for the same number of languages, as well as a language identification model for 4,017 languages.
+Experiments show that our multilingual speech recognition model more than halves the word error rate of
+Whisper on 54 languages of the FLEURS benchmark while being trained on a small fraction of the labeled data.*
+
+Tips:
+
+- MMS is a speech model that accepts a float array corresponding to the raw waveform of the speech signal. The raw waveform should be pre-processed with [`Wav2Vec2FeatureExtractor`].
+- The MMS model was trained using connectionist temporal classification (CTC), so the model output has to be decoded using
+  [`Wav2Vec2CTCTokenizer`].
+- MMS can load different language adapter weights for different languages via [`~Wav2Vec2PreTrainedModel.load_adapter`]. Language adapters consist of only roughly 2 million parameters
+  and can therefore be efficiently loaded on the fly when needed (a rough parameter count is sketched below).
+
+Relevant checkpoints can be found under https://huggingface.co/models?other=mms.
+
+MMS's architecture is based on the Wav2Vec2 model, so one can refer to [Wav2Vec2's documentation page](wav2vec2).
+
+The original code can be found [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).
+
+## Inference
+
+By default, MMS loads adapter weights for English, but these can easily be switched out for another language.
+Let's look at an example.
+
+First, we load audio data in different languages using the [Datasets](https://github.com/huggingface/datasets) library.
+
+```py
+from datasets import load_dataset, Audio
+
+# English
+stream_data = load_dataset("mozilla-foundation/common_voice_13_0", "en", split="test", streaming=True)
+stream_data = stream_data.cast_column("audio", Audio(sampling_rate=16000))
+en_sample = next(iter(stream_data))["audio"]["array"]
+
+# French
+stream_data = load_dataset("mozilla-foundation/common_voice_13_0", "fr", split="test", streaming=True)
+stream_data = stream_data.cast_column("audio", Audio(sampling_rate=16000))
+fr_sample = next(iter(stream_data))["audio"]["array"]
+```
+
+Next, we load the model and processor.
+
+```py
+from transformers import Wav2Vec2ForCTC, AutoProcessor
+import torch
+
+model_id = "facebook/mms-1b-all"
+
+processor = AutoProcessor.from_pretrained(model_id)
+model = Wav2Vec2ForCTC.from_pretrained(model_id)
+```
+
+Now we process the audio data, pass the processed audio data to the model, and transcribe the model output,
+just like we usually do for [`Wav2Vec2ForCTC`].
+
+```py
+inputs = processor(en_sample, sampling_rate=16_000, return_tensors="pt")
+
+with torch.no_grad():
+    outputs = model(**inputs).logits
+
+ids = torch.argmax(outputs, dim=-1)[0]
+transcription = processor.decode(ids)
+# 'joe keton disapproved of films and buster also had reservations about the media'
+```
+
+We can now keep the same model in memory and simply switch out the language adapters by
+calling the convenient [`~Wav2Vec2ForCTC.load_adapter`] function for the model and [`~Wav2Vec2CTCTokenizer.set_target_lang`] for the tokenizer.
+We pass the target language as an input: `"fra"` for French.
+
+```py
+processor.tokenizer.set_target_lang("fra")
+model.load_adapter("fra")
+
+inputs = processor(fr_sample, sampling_rate=16_000, return_tensors="pt")
+
+with torch.no_grad():
+    outputs = model(**inputs).logits
+
+ids = torch.argmax(outputs, dim=-1)[0]
+transcription = processor.decode(ids)
+# "ce dernier est volé tout au long de l'histoire romaine"
+```
+
+In the same way, the language can be switched out for all other supported languages. Please have a look at:
+
+```py
+processor.tokenizer.vocab.keys()
+```
+
+to see all supported languages.
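To get a feel for why adapter switching is cheap, here is a rough back-of-the-envelope sketch of what one language adapter contains. It mirrors the structure of the `Wav2Vec2AttnAdapterLayer` introduced in this patch (a LayerNorm followed by a down- and an up-projection in every encoder layer) plus the language-specific `lm_head` that [`~Wav2Vec2PreTrainedModel.load_adapter`] also swaps. The dimensions below are assumptions chosen to match a 1B-parameter checkpoint such as `facebook/mms-1b-all`, not values taken from this page, and the vocabulary size is purely illustrative.

```py
# Back-of-the-envelope parameter count for one MMS language adapter (sketch only;
# hidden_size, num_hidden_layers and adapter_attn_dim are assumed values, and
# vocab_size is purely illustrative).
hidden_size = 1280        # assumed encoder width of a 1B-parameter checkpoint
num_hidden_layers = 48    # assumed number of encoder layers
adapter_attn_dim = 16     # assumed adapter bottleneck dimension (config.adapter_attn_dim)
vocab_size = 150          # illustrative per-language CTC vocabulary size

# One Wav2Vec2AttnAdapterLayer = LayerNorm + Linear(hidden -> adapter) + Linear(adapter -> hidden)
layer_norm_params = 2 * hidden_size
down_proj_params = hidden_size * adapter_attn_dim + adapter_attn_dim
up_proj_params = adapter_attn_dim * hidden_size + hidden_size
params_per_layer = layer_norm_params + down_proj_params + up_proj_params

# `load_adapter` also swaps the language-specific CTC head (`lm_head`)
lm_head_params = hidden_size * vocab_size + vocab_size

total_params = params_per_layer * num_hidden_layers + lm_head_params
print(f"~{total_params / 1e6:.1f}M parameters per language adapter")
# roughly 2 million parameters, in line with the tip above
```

Compared to the roughly one billion parameters of the full model, downloading and loading a new adapter is a very small operation, which is why the adapters can be swapped on the fly without reloading the model.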
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index fbffe83822..2c05989d99 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -515,6 +515,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("megatron_gpt2", "Megatron-GPT2"), ("mgp-str", "MGP-STR"), ("mluke", "mLUKE"), + ("mms", "MMS"), ("mobilebert", "MobileBERT"), ("mobilenet_v1", "MobileNetV1"), ("mobilenet_v2", "MobileNetV2"), diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 1c24938256..a62b0180e9 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -603,6 +603,32 @@ class HubertEncoderLayer(nn.Module): return outputs +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->Hubert +class HubertAttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. + """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert class HubertEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): @@ -618,6 +644,11 @@ class HubertEncoderLayerStableLayerNorm(nn.Module): self.feed_forward = HubertFeedForward(config) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = HubertAttnAdapterLayer(config) + else: + self.adapter_layer = None + def forward( self, hidden_states: torch.Tensor, @@ -633,6 +664,9 @@ class HubertEncoderLayerStableLayerNorm(nn.Module): hidden_states = attn_residual + hidden_states hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + outputs = (hidden_states,) if output_attentions: @@ -1096,7 +1130,7 @@ class HubertModel(HubertPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT class HubertForCTC(HubertPreTrainedModel): - def __init__(self, config): + def __init__(self, config, target_lang=None): super().__init__(config) self.hubert = HubertModel(config) @@ -1114,6 +1148,13 @@ class HubertForCTC(HubertPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: + raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") + elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: + logger.info("By default 
`target_lang` is set to 'eng'.") + elif target_lang is not None: + self.load_adapter(target_lang) + # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 3e41496ea1..1d45facf22 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -969,7 +969,7 @@ class SEWModel(SEWPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW class SEWForCTC(SEWPreTrainedModel): - def __init__(self, config): + def __init__(self, config, target_lang=None): super().__init__(config) self.sew = SEWModel(config) @@ -987,6 +987,13 @@ class SEWForCTC(SEWPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: + raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") + elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: + logger.info("By default `target_lang` is set to 'eng'.") + elif target_lang is not None: + self.load_adapter(target_lang) + # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 417fd81c6e..f1642fab1a 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1509,7 +1509,7 @@ class SEWDModel(SEWDPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD class SEWDForCTC(SEWDPreTrainedModel): - def __init__(self, config): + def __init__(self, config, target_lang=None): super().__init__(config) self.sew_d = SEWDModel(config) @@ -1527,6 +1527,13 @@ class SEWDForCTC(SEWDPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: + raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") + elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: + logger.info("By default `target_lang` is set to 'eng'.") + elif target_lang is not None: + self.load_adapter(target_lang) + # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 71dcaad119..d7a10fbf59 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -639,6 +639,32 @@ class UniSpeechEncoderLayer(nn.Module): return outputs +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->UniSpeech +class UniSpeechAttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. 
+ """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->UniSpeech class UniSpeechEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): @@ -654,6 +680,11 @@ class UniSpeechEncoderLayerStableLayerNorm(nn.Module): self.feed_forward = UniSpeechFeedForward(config) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = UniSpeechAttnAdapterLayer(config) + else: + self.adapter_layer = None + def forward( self, hidden_states: torch.Tensor, @@ -669,6 +700,9 @@ class UniSpeechEncoderLayerStableLayerNorm(nn.Module): hidden_states = attn_residual + hidden_states hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + outputs = (hidden_states,) if output_attentions: @@ -1340,7 +1374,7 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH class UniSpeechForCTC(UniSpeechPreTrainedModel): - def __init__(self, config): + def __init__(self, config, target_lang=None): super().__init__(config) self.unispeech = UniSpeechModel(config) @@ -1358,6 +1392,13 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: + raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") + elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: + logger.info("By default `target_lang` is set to 'eng'.") + elif target_lang is not None: + self.load_adapter(target_lang) + # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index ce3f5b80ca..d2b2584530 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -653,6 +653,32 @@ class UniSpeechSatEncoderLayer(nn.Module): return outputs +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->UniSpeechSat +class UniSpeechSatAttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. 
+ """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->UniSpeechSat class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): @@ -668,6 +694,11 @@ class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): self.feed_forward = UniSpeechSatFeedForward(config) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = UniSpeechSatAttnAdapterLayer(config) + else: + self.adapter_layer = None + def forward( self, hidden_states: torch.Tensor, @@ -683,6 +714,9 @@ class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): hidden_states = attn_residual + hidden_states hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + outputs = (hidden_states,) if output_attentions: @@ -1347,7 +1381,7 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): ) # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel): - def __init__(self, config): + def __init__(self, config, target_lang=None): super().__init__(config) self.unispeech_sat = UniSpeechSatModel(config) @@ -1365,6 +1399,13 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: + raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") + elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: + logger.info("By default `target_lang` is set to 'eng'.") + elif target_lang is not None: + self.load_adapter(target_lang) + # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index 6f7709e535..3404930573 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -180,6 +180,9 @@ class Wav2Vec2Config(PretrainedConfig): num_adapter_layers (`int`, *optional*, defaults to 3): Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is True`. + adapter_attn_dim (`int`, *optional*): + Dimension of the attention adapter weights to be used in each attention block. An example of a model using + attention adapters is [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all). output_hidden_size (`int`, *optional*): Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. 
Only relevant if `add_adapter is True`. @@ -256,6 +259,7 @@ class Wav2Vec2Config(PretrainedConfig): adapter_stride=2, num_adapter_layers=3, output_hidden_size=None, + adapter_attn_dim=None, **kwargs, ): super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) @@ -326,6 +330,7 @@ class Wav2Vec2Config(PretrainedConfig): self.adapter_stride = adapter_stride self.num_adapter_layers = num_adapter_layers self.output_hidden_size = output_hidden_size or hidden_size + self.adapter_attn_dim = adapter_attn_dim # SequenceClassification-specific parameter. Feel free to ignore for other classes. self.classifier_proj_size = classifier_proj_size diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py index 4656f5b811..3e9cb3c030 100644 --- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py @@ -49,6 +49,7 @@ MAPPING = { "fc2": "encoder.layers.*.feed_forward.output_dense", "final_layer_norm": "encoder.layers.*.final_layer_norm", "encoder.layer_norm": "encoder.layer_norm", + "adapter_layer": "encoder.layers.*.adapter_layer", "w2v_model.layer_norm": "feature_projection.layer_norm", "quantizer.weight_proj": "quantizer.weight_proj", "quantizer.vars": "quantizer.codevectors", @@ -66,12 +67,26 @@ TOP_LEVEL_KEYS = [ ] -def set_recursively(hf_pointer, key, value, full_name, weight_type): +def set_recursively(key, value, full_name, weight_type, hf_pointer): for attribute in key.split("."): hf_pointer = getattr(hf_pointer, attribute) - if weight_type is not None: + hf_param_name = None + for param_key in PARAM_MAPPING.keys(): + if full_name.endswith(param_key): + hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]] + weight_type = "param" + + if weight_type is not None and weight_type != "param": hf_shape = getattr(hf_pointer, weight_type).shape + elif weight_type is not None and weight_type == "param": + shape_pointer = hf_pointer + for attribute in hf_param_name.split("."): + shape_pointer = getattr(shape_pointer, attribute) + hf_shape = shape_pointer.shape + + # let's reduce dimension + value = value[0] else: hf_shape = hf_pointer.shape @@ -89,12 +104,71 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type): hf_pointer.weight_v.data = value elif weight_type == "bias": hf_pointer.bias.data = value + elif weight_type == "param": + for attribute in hf_param_name.split("."): + hf_pointer = getattr(hf_pointer, attribute) + hf_pointer.data = value else: hf_pointer.data = value logger.info(f"{key + '.' 
+ weight_type if weight_type is not None else ''} was initialized from {full_name}.") +def rename_dict(key, value, full_name, weight_type, hf_dict): + hf_param_name = None + for param_key in PARAM_MAPPING.keys(): + if full_name.endswith(param_key): + hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]] + weight_type = "param" + + if weight_type is not None and weight_type != "param": + full_key = ".".join([key, weight_type]) + elif weight_type is not None and weight_type == "param": + full_key = ".".join([key, hf_param_name]) + else: + full_key = key + + hf_dict[full_key] = value if "lm_head" in full_key else value[0] + + +PARAM_MAPPING = { + "W_a": "linear_1.weight", + "W_b": "linear_2.weight", + "b_a": "linear_1.bias", + "b_b": "linear_2.bias", + "ln_W": "norm.weight", + "ln_b": "norm.bias", +} + + +def load_wav2vec2_layer(name, value, hf_model=None, hf_dict=None): + is_used = False + for key, mapped_key in MAPPING.items(): + mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "bias" in name: + weight_type = "bias" + elif "weight" in name: + # TODO: don't match quantizer.weight_proj + weight_type = "weight" + else: + weight_type = None + if hf_dict is not None: + rename_dict(mapped_key, value, name, weight_type, hf_dict) + else: + set_recursively(mapped_key, value, name, weight_type, hf_model) + return is_used + return is_used + + def recursively_load_weights(fairseq_model, hf_model, is_headless): unused_weights = [] fairseq_dict = fairseq_model.state_dict() @@ -113,26 +187,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_headless): ) is_used = True else: - for key, mapped_key in MAPPING.items(): - mapped_key = "wav2vec2." 
+ mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key - if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: - is_used = True - if "*" in mapped_key: - layer_index = name.split(key)[0].split(".")[-2] - mapped_key = mapped_key.replace("*", layer_index) - if "weight_g" in name: - weight_type = "weight_g" - elif "weight_v" in name: - weight_type = "weight_v" - elif "bias" in name: - weight_type = "bias" - elif "weight" in name: - # TODO: don't match quantizer.weight_proj - weight_type = "weight" - else: - weight_type = None - set_recursively(hf_model, mapped_key, value, name, weight_type) - continue + is_used = load_wav2vec2_layer(name, value, hf_model) if not is_used: unused_weights.append(name) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index bcf7b7a3ad..51811dfdd3 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -42,12 +42,21 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + cached_file, + is_safetensors_available, logging, replace_return_docstrings, ) from .configuration_wav2vec2 import Wav2Vec2Config +WAV2VEC2_ADAPTER_PT_FILE = "adapter.{}.bin" +WAV2VEC2_ADAPTER_SAFE_FILE = "adapter.{}.safetensors" + +if is_safetensors_available(): + from safetensors.torch import load_file as safe_load_file + + logger = logging.get_logger(__name__) @@ -708,6 +717,11 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): self.feed_forward = Wav2Vec2FeedForward(config) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = Wav2Vec2AttnAdapterLayer(config) + else: + self.adapter_layer = None + def forward( self, hidden_states: torch.Tensor, @@ -723,6 +737,9 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): hidden_states = attn_residual + hidden_states hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + outputs = (hidden_states,) if output_attentions: @@ -1034,6 +1051,31 @@ class Wav2Vec2AdapterLayer(nn.Module): return hidden_states +class Wav2Vec2AttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. 
+ """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + class Wav2Vec2PreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -1132,6 +1174,188 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): module.gradient_checkpointing = value + @property + def _adapters(self): + if self.config.adapter_attn_dim is None: + raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.") + + adapter_weights = {} + for name, module in self.named_modules(): + if isinstance(module, Wav2Vec2AttnAdapterLayer): + for param_name, param in module.named_parameters(): + adapter_weights[".".join([name, param_name])] = param + + if isinstance(self, Wav2Vec2ForCTC): + for name, param in self.lm_head.named_parameters(): + adapter_weights[".".join(["lm_head", name])] = param + + return adapter_weights + + def load_adapter(self, target_lang: str, **kwargs): + r""" + Load a language adapter model from a pre-trained adapter model. + + Parameters: + target_lang (`str`): + Has to be a language id of an existing adapter weight. Adapter weights are stored in the format + adapter..safetensors or adapter..bin + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/". + + + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. 
If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + Examples: + + ```python + >>> from transformers import Wav2Vec2ForCTC, AutoProcessor + + >>> ckpt = "facebook/mms-1b-all" + >>> processor = AutoProcessor.from_pretrained(ckpt) + >>> model = Wav2Vec2ForCTC.from_pretrained(ckpt, target_lang="eng") + >>> # set specific language + >>> processor.tokenizer.set_target_lang("spa") + >>> model.load_adapter("spa") + ``` + """ + if self.config.adapter_attn_dim is None: + raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.") + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + model_path_or_id = self.config._name_or_path + state_dict = None + + # 1. Let's first try loading a safetensors adapter weight + if use_safetensors is not False: + filepath = WAV2VEC2_ADAPTER_SAFE_FILE.format(target_lang) + + try: + weight_path = cached_file( + model_path_or_id, + filename=filepath, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + ) + + state_dict = safe_load_file(weight_path) + + except EnvironmentError: + if use_safetensors: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + + except Exception: + # For any other exception, we throw a generic error. + if use_safetensors: + raise EnvironmentError( + f"Can't load the model for '{model_path_or_id}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a" + f" directory containing a file named {filepath}." + ) + + # 2. If this didn't work let's try loading a PyTorch adapter weight + if state_dict is None: + filepath = WAV2VEC2_ADAPTER_PT_FILE.format(target_lang) + + try: + weight_path = cached_file( + model_path_or_id, + filename=filepath, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + ) + + state_dict = torch.load(weight_path, map_location="cpu") + + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + + except Exception: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{model_path_or_id}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. 
Otherwise, make sure '{model_path_or_id}' is the correct path to a" + f" directory containing a file named {filepath}." + ) + + adapter_weights = self._adapters + unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys()) + missing_keys = set(adapter_weights.keys()) - set(state_dict.keys()) + + if len(unexpected_keys) > 0: + raise ValueError(f"The adapter weights {weight_path} has unexpected keys: {', '.join(unexpected_keys)}.") + elif len(missing_keys) > 0: + raise ValueError(f"The adapter weights {weight_path} has missing keys: {', '.join(missing_keys)}.") + + # make sure now vocab size is correct + target_vocab_size = state_dict["lm_head.weight"].shape[0] + if target_vocab_size != self.config.vocab_size: + self.lm_head = nn.Linear( + self.config.output_hidden_size, target_vocab_size, device=self.device, dtype=self.dtype + ) + self.config.vocab_size = target_vocab_size + + # make sure that adapter weights are put in exactly the same precision and device placement and overwritten adapter weights + state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()} + self.load_state_dict(state_dict, strict=False) + WAV_2_VEC_2_START_DOCSTRING = r""" Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech @@ -1614,7 +1838,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): WAV_2_VEC_2_START_DOCSTRING, ) class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): - def __init__(self, config): + def __init__(self, config, target_lang=None): super().__init__(config) self.wav2vec2 = Wav2Vec2Model(config) @@ -1632,6 +1856,13 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): ) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: + raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") + elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: + logger.info("By default `target_lang` is set to 'eng'.") + elif target_lang is not None: + self.load_adapter(target_lang) + # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 15d3471da0..472fd2d649 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -148,6 +148,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): The token used for defining the end of a word. do_lower_case (`bool`, *optional*, defaults to `False`): Whether or not to accept lowercase input and lowercase the output when decoding. + target_lang (`str`, *optional*): + A target language the tokenizer should set by default. `target_lang` has to be defined for multi-lingual, + nested vocabulary such as [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all). 
         **kwargs
             Additional keyword arguments passed along to [`PreTrainedTokenizer`]
     """
@@ -168,6 +171,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
         word_delimiter_token="|",
         replace_word_delimiter_char=" ",
         do_lower_case=False,
+        target_lang=None,
         **kwargs,
     ):
         super().__init__(
@@ -178,6 +182,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
             do_lower_case=do_lower_case,
             word_delimiter_token=word_delimiter_token,
             replace_word_delimiter_char=replace_word_delimiter_char,
+            target_lang=target_lang,
             **kwargs,
         )
 
@@ -185,9 +190,18 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
 
         self.do_lower_case = do_lower_case
         self.replace_word_delimiter_char = replace_word_delimiter_char
+        self.target_lang = target_lang
 
         with open(vocab_file, encoding="utf-8") as vocab_handle:
-            self.encoder = json.load(vocab_handle)
+            self.vocab = json.load(vocab_handle)
+
+        # if target lang is defined vocab must be a nested dict
+        # with each target lang being one vocabulary
+        if target_lang is not None:
+            self.encoder = self.vocab[target_lang]
+        else:
+            self.encoder = self.vocab
+
         self.decoder = {v: k for k, v in self.encoder.items()}
 
         # make sure that tokens made of several
@@ -198,6 +212,27 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
 
         self._create_trie(self.unique_no_split_tokens)
 
+    def set_target_lang(self, target_lang: str):
+        """
+        Set the target language of a nested multi-lingual dictionary
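+
+        Example:
+
+        ```python
+        >>> # requires a checkpoint with a nested, multi-lingual vocabulary, e.g. facebook/mms-1b-all
+        >>> from transformers import Wav2Vec2CTCTokenizer
+
+        >>> tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/mms-1b-all", target_lang="eng")
+        >>> tokenizer.set_target_lang("spa")
+        ```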
+        """
+        if self.vocab == self.encoder:
+            raise ValueError(f"{self.vocab} is not a multi-lingual, nested tokenizer. Cannot set target language.")
+
+        if target_lang not in self.vocab:
+            raise ValueError(f"{target_lang} does not exist. Choose one of {', '.join(self.vocab.keys())}.")
+
+        self.target_lang = target_lang
+        self.init_kwargs["target_lang"] = target_lang
+        self.encoder = self.vocab[target_lang]
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        # make sure that tokens made of several
+        # characters are not split at tokenization
+        for token in self.encoder.keys():
+            if len(token) > 1:
+                self.unique_no_split_tokens.append(token)
+
     @property
     def word_delimiter_token(self) -> str:
         """
@@ -231,7 +266,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
         return len(self.decoder)
 
     def get_vocab(self) -> Dict:
-        return dict(self.encoder, **self.added_tokens_encoder)
+        return dict(self.vocab, **self.added_tokens_encoder)
 
     def _tokenize(self, text, **kwargs):
         """
@@ -606,7 +641,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
         )
 
         with open(vocab_file, "w", encoding="utf-8") as f:
-            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+            f.write(json.dumps(self.vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
 
         return (vocab_file,)
 
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index 94ce66e4f3..594ca15b69 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -1600,7 +1600,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
 )
 class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
-    def __init__(self, config):
+    def __init__(self, config, target_lang=None):
         super().__init__(config)
 
         self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
@@ -1618,6 +1618,13 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
         )
         self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
 
+        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
+            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
+        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
+            logger.info("By default `target_lang` is set to 'eng'.")
+        elif target_lang is not None:
+            self.load_adapter(target_lang)
+
         # Initialize weights and apply final processing
         self.post_init()
 
diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py
index 718c4d61dd..0dabcd234e 100755
--- a/src/transformers/models/wavlm/modeling_wavlm.py
+++ b/src/transformers/models/wavlm/modeling_wavlm.py
@@ -1268,7 +1268,7 @@ class WavLMModel(WavLMPreTrainedModel):
 )
 # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
 class WavLMForCTC(WavLMPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, target_lang=None):
         super().__init__(config)
 
         self.wavlm = WavLMModel(config)
@@ -1286,6 +1286,13 @@ class WavLMForCTC(WavLMPreTrainedModel):
         )
         self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
 
+        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
+            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
+        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
+            logger.info("By default `target_lang` is set to 'eng'.")
+        elif target_lang is not None:
+            self.load_adapter(target_lang)
+
         # Initialize weights and apply final processing
         self.post_init()
 
diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py
index f3e93b670c..8fc82eb96e 100644
--- a/tests/models/wav2vec2/test_modeling_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py
@@ -54,6 +54,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 
 if is_torch_available():
     import torch
+    from safetensors.torch import save_file as safe_save_file
 
     from transformers import (
         Wav2Vec2FeatureExtractor,
@@ -67,6 +68,8 @@ if is_torch_available():
         Wav2Vec2Processor,
     )
     from transformers.models.wav2vec2.modeling_wav2vec2 import (
+        WAV2VEC2_ADAPTER_PT_FILE,
+        WAV2VEC2_ADAPTER_SAFE_FILE,
         Wav2Vec2GumbelVectorQuantizer,
         _compute_mask_indices,
         _sample_negative_indices,
@@ -290,6 +293,17 @@ class Wav2Vec2ModelTester:
             (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
         )
 
+    def create_and_check_model_with_attn_adapter(self, config, input_values, attention_mask):
+        config.adapter_attn_dim = 16
+        model = Wav2Vec2ForCTC(config=config)
+
+        self.parent.assertIsNotNone(model._adapters)
+
+        model.to(torch_device)
+        model.eval()
+        result = model(input_values, attention_mask=attention_mask)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size))
+
     def create_and_check_batch_inference(self, config, input_values, *args):
         # test does not pass for models making use of `group_norm`
         # check: https://github.com/pytorch/fairseq/issues/3227
@@ -844,6 +858,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
 
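+    # forward-pass sanity check for a Wav2Vec2ForCTC model with attention adapters (`config.adapter_attn_dim`)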
+    def test_model_with_attn_adapter(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_with_attn_adapter(*config_and_inputs)
+
     def test_batched_inference(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_batch_inference(*config_and_inputs)
@@ -1098,6 +1116,85 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
     def test_feed_forward_chunking(self):
         pass
 
+    def test_load_attn_adapter(self):
+        processor = Wav2Vec2Processor.from_pretrained(
+            "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
+        )
+
+        def get_logits(model, input_features):
+            model = model.to(torch_device)
+            batch = processor(
+                input_features,
+                padding=True,
+                sampling_rate=processor.feature_extractor.sampling_rate,
+                return_tensors="pt",
+            )
+
+            with torch.no_grad():
+                logits = model(
+                    input_values=batch["input_values"].to(torch_device),
+                    attention_mask=batch["attention_mask"].to(torch_device),
+                ).logits
+            return logits
+
+        input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]
+
+        model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2", adapter_attn_dim=16)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            model.save_pretrained(tempdir)
+            model = Wav2Vec2ForCTC.from_pretrained(tempdir)
+
+            logits = get_logits(model, input_features)
+            adapter_weights = model._adapters
+
+            # save safetensors weights
+            safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng"))
+            safe_save_file(adapter_weights, safe_filepath, metadata={"format": "pt"})
+
+            model.load_adapter("eng")
+            model.load_adapter("eng", use_safetensors=True)
+
+            with self.assertRaises(OSError):
+                model.load_adapter("eng", use_safetensors=False)
+            with self.assertRaises(Exception):
+                model.load_adapter("ita", use_safetensors=True)
+            logits_2 = get_logits(model, input_features)
+
+            self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            model.save_pretrained(tempdir)
+            model = Wav2Vec2ForCTC.from_pretrained(tempdir)
+
+            logits = get_logits(model, input_features)
+            adapter_weights = model._adapters
+
+            # save pt weights
+            pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng"))
+            torch.save(adapter_weights, pt_filepath)
+
+            model.load_adapter("eng")
+            model.load_adapter("eng", use_safetensors=False)
+
+            with self.assertRaises(OSError):
+                model.load_adapter("eng", use_safetensors=True)
+
+            logits_2 = get_logits(model, input_features)
+
+            self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
+
+        model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
+        logits = get_logits(model, input_features)
+
+        model.load_adapter("eng")
+        model.load_adapter("eng", use_safetensors=False)
+        model.load_adapter("eng", use_safetensors=True)
+
+        logits_2 = get_logits(model, input_features)
+
+        self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
+
     @slow
     def test_model_from_pretrained(self):
         model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
@@ -1768,3 +1865,45 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
 
         # TODO: update the tolerance after the CI moves to torch 1.10
         self.assertAlmostEqual(outputs.loss.item(), 17.7963, 2)
+
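+    # transcribes one Common Voice sample per language, swapping the MMS adapter and tokenizer target language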
+    @require_torchaudio
+    def test_inference_mms_1b_all(self):
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all").to(torch_device)
+        processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")
+
+        LANG_MAP = {"it": "ita", "es": "spa", "fr": "fra", "en": "eng"}
+
+        def run_model(lang):
+            ds = load_dataset("common_voice", lang, split="test", streaming=True)
+            sample = next(iter(ds))
+
+            wav2vec2_lang = LANG_MAP[lang]
+
+            model.load_adapter(wav2vec2_lang)
+            processor.tokenizer.set_target_lang(wav2vec2_lang)
+
+            resampled_audio = torchaudio.functional.resample(
+                torch.tensor(sample["audio"]["array"]), 48_000, 16_000
+            ).numpy()
+
+            inputs = processor(resampled_audio, sampling_rate=16_000, return_tensors="pt")
+            input_values = inputs.input_values.to(torch_device)
+            attention_mask = inputs.attention_mask.to(torch_device)
+
+            with torch.no_grad():
+                outputs = model(input_values, attention_mask=attention_mask).logits
+
+            ids = torch.argmax(outputs, dim=-1)[0]
+
+            transcription = processor.decode(ids)
+            return transcription
+
+        TRANSCRIPTIONS = {
+            "it": "mi hanno fatto un'offerta che non potevo proprio rifiutare",
+            "es": "bien y qué regalo vas a abrir primero",
+            "fr": "un vrai travail intéressant va enfin être mené sur ce sujet",
+            "en": "twas the time of day and olof spen slept during the summer",
+        }
+
+        for lang in LANG_MAP.keys():
+            assert run_model(lang) == TRANSCRIPTIONS[lang]
diff --git a/tests/models/wav2vec2/test_tokenization_wav2vec2.py b/tests/models/wav2vec2/test_tokenization_wav2vec2.py
index 9715680e27..9bfae65f6c 100644
--- a/tests/models/wav2vec2/test_tokenization_wav2vec2.py
+++ b/tests/models/wav2vec2/test_tokenization_wav2vec2.py
@@ -772,3 +772,48 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         output = tokenizer.convert_tokens_to_string(tokens)
 
         self.assertIsInstance(output["text"], str)
+
+    def test_nested_vocab(self):
+        eng_vocab = {"a": 7, "b": 8}
+        spa_vocab = {"a": 23, "c": 88}
+        ita_vocab = {"a": 6, "d": 9}
+
+        nested_vocab = {"eng": eng_vocab, "spa": spa_vocab, "ita": ita_vocab}
+
+        def check_tokenizer(tokenizer, check_ita_first=False):
+            if check_ita_first:
+                self.assertEqual(tokenizer.decode([6, 9, 9]), "ad")
+                self.assertEqual(tokenizer.encoder, ita_vocab)
+                tokenizer.set_target_lang("eng")
+
+            self.assertEqual(tokenizer.encoder, eng_vocab)
+            self.assertEqual(tokenizer.decode([7, 8, 7]), "aba")
+
+            tokenizer.set_target_lang("spa")
+            self.assertEqual(tokenizer.decode([23, 88, 23]), "aca")
+            self.assertEqual(tokenizer.encoder, spa_vocab)
+
+            tokenizer.set_target_lang("eng")
+            self.assertEqual(tokenizer.encoder, eng_vocab)
+            self.assertEqual(tokenizer.decode([7, 7, 8]), "ab")
+
+            tokenizer.set_target_lang("ita")
+            self.assertEqual(tokenizer.decode([6, 9, 9]), "ad")
+            self.assertEqual(tokenizer.encoder, ita_vocab)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            tempfile_path = os.path.join(tempdir, "vocab.json")
+            with open(tempfile_path, "w") as temp_file:
+                json.dump(nested_vocab, temp_file)
+
+            tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tempdir, target_lang="eng")
+
+            check_tokenizer(tokenizer)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            # should have saved target lang as "ita" since it was last one
+            tokenizer.save_pretrained(tempdir)
+            tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tempdir)
+
+            self.assertEqual(tokenizer.target_lang, "ita")
+            check_tokenizer(tokenizer, check_ita_first=True)