Compare commits

...

1 Commits

Author SHA1 Message Date
Joshua Lochner 315dfc70e2 Add support for hubert models 2023-12-11 17:59:22 +02:00
6 changed files with 93 additions and 7 deletions

View File

@ -296,6 +296,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki.
1. **[GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode)** (from BigCode) released with the paper [SantaCoder: don't reach for the stars!](https://arxiv.org/abs/2301.03988) by Loubna Ben Allal, Raymond Li, Denis Kocetkov, Chenghao Mou, Christopher Akiki, Carlos Munoz Ferrandis, Niklas Muennighoff, Mayank Mishra, Alex Gu, Manan Dey, Logesh Kumar Umapathi, Carolyn Jane Anderson, Yangtian Zi, Joel Lamy Poirier, Hailey Schoelkopf, Sergey Troshin, Dmitry Abulkhanov, Manuel Romero, Michael Lappert, Francesco De Toni, Bernardo García del Río, Qian Liu, Shamik Bose, Urvashi Bhattacharyya, Terry Yue Zhuo, Ian Yu, Paulo Villegas, Marco Zocca, Sourab Mangrulkar, David Lansky, Huu Nguyen, Danish Contractor, Luis Villa, Jia Li, Dzmitry Bahdanau, Yacine Jernite, Sean Hughes, Daniel Fried, Arjun Guha, Harm de Vries, Leandro von Werra.
1. **[HerBERT](https://huggingface.co/docs/transformers/model_doc/herbert)** (from Allegro.pl, AGH University of Science and Technology) released with the paper [KLEJ: Comprehensive Benchmark for Polish Language Understanding](https://www.aclweb.org/anthology/2020.acl-main.111.pdf) by Piotr Rybak, Robert Mroczkowski, Janusz Tracz, Ireneusz Gawlik.
1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom.

View File

@ -32,6 +32,7 @@
1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki.
1. **[GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode)** (from BigCode) released with the paper [SantaCoder: don't reach for the stars!](https://arxiv.org/abs/2301.03988) by Loubna Ben Allal, Raymond Li, Denis Kocetkov, Chenghao Mou, Christopher Akiki, Carlos Munoz Ferrandis, Niklas Muennighoff, Mayank Mishra, Alex Gu, Manan Dey, Logesh Kumar Umapathi, Carolyn Jane Anderson, Yangtian Zi, Joel Lamy Poirier, Hailey Schoelkopf, Sergey Troshin, Dmitry Abulkhanov, Manuel Romero, Michael Lappert, Francesco De Toni, Bernardo García del Río, Qian Liu, Shamik Bose, Urvashi Bhattacharyya, Terry Yue Zhuo, Ian Yu, Paulo Villegas, Marco Zocca, Sourab Mangrulkar, David Lansky, Huu Nguyen, Danish Contractor, Luis Villa, Jia Li, Dzmitry Bahdanau, Yacine Jernite, Sean Hughes, Daniel Fried, Arjun Guha, Harm de Vries, Leandro von Werra.
1. **[HerBERT](https://huggingface.co/docs/transformers/model_doc/herbert)** (from Allegro.pl, AGH University of Science and Technology) released with the paper [KLEJ: Comprehensive Benchmark for Polish Language Understanding](https://www.aclweb.org/anthology/2020.acl-main.111.pdf) by Piotr Rybak, Robert Mroczkowski, Janusz Tracz, Ireneusz Gawlik.
1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom.

View File

@ -90,6 +90,7 @@ MODEL_SPECIFIC_QUANTIZE_PARAMS = {
MODELS_WITHOUT_TOKENIZERS = [
'wav2vec2',
'wavlm',
'hubert',
]
@ -307,7 +308,7 @@ def main():
**get_main_export_kwargs(config, "automatic-speech-recognition")
)
elif config.model_type == 'wav2vec2':
elif config.model_type in ('wav2vec2', 'hubert'):
if tokenizer is not None:
from .extra.wav2vec2 import generate_tokenizer_json
tokenizer_json = generate_tokenizer_json(tokenizer)

View File

@ -421,6 +421,22 @@ SUPPORTED_MODELS = {
'allegro/herbert-large-cased',
],
},
'hubert': {
# Feature extraction
'feature-extraction': [
'facebook/hubert-base-ls960',
],
# Audio classification
'audio-classification': [
'superb/hubert-base-superb-ks',
],
# Automatic speech recognition
'automatic-speech-recognition': [
'facebook/hubert-large-ls960-ft',
],
},
'llama': {
# Text generation
'text-generation': [

View File

@ -22,7 +22,7 @@
*
* We also provide other `AutoModel`s (listed below), which you can use in the same way as the Python library. For example:
*
* **Example:** Load and run a `AutoModelForSeq2SeqLM`.
* **Example:** Load and run an `AutoModelForSeq2SeqLM`.
* ```javascript
* import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@xenova/transformers';
*
@ -2477,7 +2477,7 @@ export class ASTModel extends ASTPreTrainedModel { }
* Audio Spectrogram Transformer model with an audio classification head on top
* (a linear layer on top of the pooled output) e.g. for datasets like AudioSet, Speech Commands v2.
*/
export class ASTForAudioClassification extends ASTPreTrainedModel {}
export class ASTForAudioClassification extends ASTPreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@ -3736,7 +3736,7 @@ export class Wav2Vec2PreTrainedModel extends PreTrainedModel { };
/**
* The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.
*
* **Example:** Load and run an `Wav2Vec2Model` for feature extraction.
* **Example:** Load and run a `Wav2Vec2Model` for feature extraction.
*
* ```javascript
* import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
@ -3782,6 +3782,69 @@ export class Wav2Vec2ForSequenceClassification extends Wav2Vec2PreTrainedModel {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// Hubert models
export class HubertPreTrainedModel extends PreTrainedModel { }
/**
* The bare Hubert Model transformer outputting raw hidden-states without any specific head on top.
*
* **Example:** Load and run a `HubertModel` for feature extraction.
*
* ```javascript
* import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/hubert-base-ls960');
* const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav', 16000);
* const inputs = await processor(audio);
*
* // Load and run model with inputs
* const model = await AutoModel.from_pretrained('Xenova/hubert-base-ls960');
* const output = await model(inputs);
* // {
* // last_hidden_state: Tensor {
* // dims: [ 1, 549, 768 ],
* // type: 'float32',
* // data: Float32Array(421632) [0.0682469978928566, 0.08104046434164047, -0.4975186586380005, ...],
* // size: 421632
* // }
* // }
* ```
*/
export class HubertModel extends Wav2Vec2PreTrainedModel { }
/**
* Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
*/
export class HubertForCTC extends Wav2Vec2PreTrainedModel {
/**
* @param {Object} model_inputs
* @param {Tensor} model_inputs.input_values Float values of input raw speech waveform.
* @param {Tensor} model_inputs.attention_mask Mask to avoid performing convolution and attention on padding token indices. Mask values selected in [0, 1]
*/
async _call(model_inputs) {
return new CausalLMOutput(await super._call(model_inputs));
}
}
/**
* Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB Keyword Spotting.
*/
export class HubertForSequenceClassification extends Wav2Vec2PreTrainedModel {
/**
* Calls the model on new inputs.
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
*/
async _call(model_inputs) {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// WavLM models
/**
@ -3792,7 +3855,7 @@ export class WavLMPreTrainedModel extends PreTrainedModel { };
/**
* The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.
*
* **Example:** Load and run an `WavLMModel` for feature extraction.
* **Example:** Load and run a `WavLMModel` for feature extraction.
*
* ```javascript
* import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
@ -4285,6 +4348,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['mobilebert', ['MobileBertModel', MobileBertModel]],
['squeezebert', ['SqueezeBertModel', SqueezeBertModel]],
['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]],
['hubert', ['HubertModel', HubertModel]],
['wavlm', ['WavLMModel', WavLMModel]],
['audio-spectrogram-transformer', ['ASTModel', ASTModel]],
@ -4474,11 +4538,13 @@ const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
const MODEL_FOR_CTC_MAPPING_NAMES = new Map([
['wav2vec2', ['Wav2Vec2ForCTC', Wav2Vec2ForCTC]],
['wavlm', ['WavLMForCTC', WavLMForCTC]],
['hubert', ['HubertForCTC', HubertForCTC]],
]);
const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
['wav2vec2', ['Wav2Vec2ForSequenceClassification', Wav2Vec2ForSequenceClassification]],
['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]],
['hubert', ['HubertForSequenceClassification', HubertForSequenceClassification]],
['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
]);

View File

@ -1276,6 +1276,7 @@ export class AutomaticSpeechRecognitionPipeline extends Pipeline {
case 'whisper':
return this._call_whisper(audio, kwargs)
case 'wav2vec2':
case 'hubert':
return this._call_wav2vec2(audio, kwargs)
default:
throw new Error(`AutomaticSpeechRecognitionPipeline does not support model type '${this.model.config.model_type}'.`)