Compare commits

...

25 Commits

Author SHA1 Message Date
Joshua Lochner a8ab8e8eb6 Add WIP conversion scripts
Will be updated once https://github.com/huggingface/optimum/pull/1552 is merged
2023-12-04 21:40:42 +02:00
Joshua Lochner 4cabee6ed8 Optimizations 2023-12-04 20:54:30 +02:00
Joshua Lochner 1136fc0045 Fix jsdoc 2023-12-03 22:08:35 +02:00
Joshua Lochner 14ad6a42f7 Optimizations 2023-12-03 22:08:29 +02:00
Joshua Lochner 956ac0ef32 Cleanup 2023-12-03 20:43:00 +02:00
Joshua Lochner 14959112e1 Update mel filters unit test 2023-12-03 20:18:09 +02:00
Joshua Lochner bb0d55bdcb Optimize `mel_filter_bank` computation
-30ms
2023-12-03 20:17:50 +02:00
Joshua Lochner e7e5a46740 Move audio validation to helper function 2023-12-02 17:16:49 +02:00
Joshua Lochner 9fa43597b0 Add `ClapAudioModelWithProjection` and `ClapTextModelWithProjection` 2023-12-01 22:20:36 +02:00
Joshua Lochner 3a4c71fee1 Add `'Xenova/tiny-random-ClapModel'` 2023-12-01 20:56:37 +02:00
Joshua Lochner c6342bef38 Update `mel_filter_bank` unit test 2023-12-01 20:56:22 +02:00
Joshua Lochner 934a91027c `let` -> `const` 2023-12-01 20:41:42 +02:00
Joshua Lochner 1ac561dcf2 Cleanup 2023-12-01 20:38:42 +02:00
Joshua Lochner 21d0329a85 Add listed support for `zero-shot-audio-classification` pipeline tag 2023-12-01 17:59:55 +02:00
Joshua Lochner 17e264e4cb Add `ZeroShotAudioClassificationPipeline` 2023-12-01 17:54:14 +02:00
Joshua Lochner e18f41ea04 Add support for `CLAP` 2023-11-30 22:54:48 +02:00
Joshua Lochner beafc3bcb7 Implement `ClapFeatureExtractor` unit tests 2023-11-30 22:51:33 +02:00
Joshua Lochner 2052e45129 Add `ClapFeatureExtractor` 2023-11-30 22:51:18 +02:00
Joshua Lochner e8b8c7378a Implement `log_mel='dB'` in `spectrogram` function 2023-11-30 22:50:57 +02:00
Joshua Lochner bd600538e1 Add audio processing unit tests 2023-11-30 15:52:10 +02:00
Joshua Lochner feff6dd28f Add another audio-classification example 2023-11-30 15:51:45 +02:00
Joshua Lochner 9ff76040fa Add support for AST models 2023-11-30 15:51:32 +02:00
Joshua Lochner b40fab5e99 Refactor audio processors 2023-11-30 15:20:00 +02:00
Joshua Lochner af84ec90c3 Refactor maths.js and audio.js 2023-11-30 15:09:25 +02:00
Joshua Lochner 2803cdd229 Add FFT unit tests 2023-11-30 11:37:39 +02:00
15 changed files with 1673 additions and 486 deletions

View File

@ -246,6 +246,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
| [Image-to-Text](https://huggingface.co/tasks/image-to-text) | `image-to-text` | Output text from a given image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToTextPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-to-text&library=transformers.js) |
| [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ |
| [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ |
| [Zero-Shot Audio Classification](https://huggingface.co/learn/audio-course/chapter4/classification_models#zero-shot-audio-classification) | `zero-shot-audio-classification` | Classifying audio clips into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotAudioClassificationPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-audio-classification&library=transformers.js) |
| [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |
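The new `zero-shot-audio-classification` row corresponds to the `ZeroShotAudioClassificationPipeline` introduced later in this diff. A minimal usage sketch, with the model id taken from the pipeline default registered below:

```javascript
import { pipeline } from '@xenova/transformers';

// Zero-shot audio classification with CLAP (default checkpoint registered in this PR)
const classifier = await pipeline('zero-shot-audio-classification', 'Xenova/clap-htsat-unfused');
const audio = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/dog_barking.wav';
const output = await classifier(audio, ['dog', 'vacuum cleaner']);
// e.g. [ { score: ~0.999, label: 'dog' }, { score: ~0.001, label: 'vacuum cleaner' } ]
```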
@ -261,6 +262,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
### Models
1. **[ALBERT](https://huggingface.co/docs/transformers/model_doc/albert)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
1. **[Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer)** (from MIT) released with the paper [AST: Audio Spectrogram Transformer](https://arxiv.org/abs/2104.01778) by Yuan Gong, Yu-An Chung, James Glass.
1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei.
1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
@ -268,6 +270,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
1. **[CLAP](https://huggingface.co/docs/transformers/model_doc/clap)** (from LAION-AI) released with the paper [Large-scale Contrastive Language-Audio Pretraining with Feature Fusion and Keyword-to-Caption Augmentation](https://arxiv.org/abs/2211.06687) by Yusong Wu, Ke Chen, Tianyu Zhang, Yuchen Hui, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
1. **[CodeLlama](https://huggingface.co/docs/transformers/model_doc/llama_code)** (from MetaAI) released with the paper [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve.

View File

@ -58,6 +58,7 @@
| [Image-to-Text](https://huggingface.co/tasks/image-to-text) | `image-to-text` | Output text from a given image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToTextPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-to-text&library=transformers.js) |
| [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ |
| [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ |
| [Zero-Shot Audio Classification](https://huggingface.co/learn/audio-course/chapter4/classification_models#zero-shot-audio-classification) | `zero-shot-audio-classification` | Classifying audio clips into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotAudioClassificationPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-audio-classification&library=transformers.js) |
| [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |

View File

@ -2,6 +2,7 @@
### Models
1. **[ALBERT](https://huggingface.co/docs/transformers/model_doc/albert)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
1. **[Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer)** (from MIT) released with the paper [AST: Audio Spectrogram Transformer](https://arxiv.org/abs/2104.01778) by Yuan Gong, Yu-An Chung, James Glass.
1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei.
1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
@ -9,6 +10,7 @@
1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
1. **[CLAP](https://huggingface.co/docs/transformers/model_doc/clap)** (from LAION-AI) released with the paper [Large-scale Contrastive Language-Audio Pretraining with Feature Fusion and Keyword-to-Caption Augmentation](https://arxiv.org/abs/2211.06687) by Yusong Wu, Ke Chen, Tianyu Zhang, Yuchen Hui, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
1. **[CodeLlama](https://huggingface.co/docs/transformers/model_doc/llama_code)** (from MetaAI) released with the paper [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve.

View File

@ -353,6 +353,25 @@ def main():
device=conv_args.device,
)
# TODO: Enable once https://github.com/huggingface/optimum/pull/1552 is merged
# elif config.model_type == 'clap' and conv_args.split_modalities:
# # Handle special case for exporting text and audio models separately
# from .extra.clap import ClapTextModelWithProjectionOnnxConfig, ClapAudioModelWithProjectionOnnxConfig
# from transformers.models.clap import ClapTextModelWithProjection, ClapAudioModelWithProjection
# text_model = ClapTextModelWithProjection.from_pretrained(model_id)
# audio_model = ClapAudioModelWithProjection.from_pretrained(model_id)
# export_models(
# models_and_onnx_configs={
# "text_model": (text_model, ClapTextModelWithProjectionOnnxConfig(text_model.config)),
# "audio_model": (audio_model, ClapAudioModelWithProjectionOnnxConfig(audio_model.config)),
# },
# output_dir=output_model_folder,
# opset=conv_args.opset,
# device=conv_args.device,
# )
else:
main_export(**export_kwargs)

40
scripts/extra/clap.py Normal file
View File

@ -0,0 +1,40 @@
# TODO: Enable once https://github.com/huggingface/optimum/pull/1552 is merged
# # Support exporting text and audio models separately:
# # Adapted from https://github.com/huggingface/optimum/issues/1186#issuecomment-1637641760
# from optimum.exporters.onnx.model_configs import CLAPTextWithProjectionOnnxConfig, AudioOnnxConfig
# from optimum.utils.normalized_config import NormalizedAudioConfig
# from optimum.utils.input_generators import DummyAudioInputGenerator
# from typing import Dict
# class ClapAudioModelWithProjectionOnnxConfig(AudioOnnxConfig):
# NORMALIZED_CONFIG_CLASS = NormalizedAudioConfig
# DUMMY_INPUT_GENERATOR_CLASSES = (DummyAudioInputGenerator, )
# @property
# def inputs(self) -> Dict[str, Dict[int, str]]:
# return {
# "input_features": {0: "audio_batch_size", 1: "num_channels", 2: "height", 3: "width"}, # As described in modeling_clap.py
# }
# @property
# def outputs(self) -> Dict[str, Dict[int, str]]:
# return {
# "audio_embeds": {0: "batch_size"},
# }
# class ClapTextModelWithProjectionOnnxConfig(CLAPTextWithProjectionOnnxConfig):
# @property
# def outputs(self) -> Dict[str, Dict[int, str]]:
# return {
# "text_embeds": {0: "batch_size"},
# }
# def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
# dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
# if framework == "pt":
# import torch
# dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int64)
# return dummy_inputs

View File

@ -3,6 +3,13 @@ from .extra.marian import SUPPORTED_HELSINKI_NLP_MODELS
SUPPORTED_MODELS = {
# NOTE: keys of `SUPPORTED_MODELS` are subsets of https://github.com/huggingface/optimum/blob/7f8e606689365931300ef5e6d3b20cb88771cb08/optimum/exporters/tasks.py#L281-L965
'audio-spectrogram-transformer': [
'MIT/ast-finetuned-audioset-10-10-0.4593',
'MIT/ast-finetuned-audioset-16-16-0.442',
'MIT/ast-finetuned-speech-commands-v2',
'mtg-upf/discogs-maest-30s-pw-73e-ts',
],
'albert': [
# Masked language modelling
'albert-base-v2',
@ -126,6 +133,14 @@ SUPPORTED_MODELS = {
'camembert-base',
'airesearch/wangchanberta-base-att-spm-uncased',
],
'clap': [
# Zero-shot audio classification and feature extraction
# (with and without `--split_modalities`)
'laion/clap-htsat-unfused',
# TODO add 'laion/clap-htsat-fused',
'Xenova/tiny-random-ClapModel',
],
'clip': [
# Zero-shot image classification and feature extraction
# (with and without `--split_modalities`)

View File

@ -2464,6 +2464,22 @@ export class XLMRobertaForQuestionAnswering extends XLMRobertaPreTrainedModel {
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// Audio Spectrogram Transformer (AST) models
export class ASTPreTrainedModel extends PreTrainedModel { };
/**
* The bare AST Model transformer outputting raw hidden-states without any specific head on top.
*/
export class ASTModel extends ASTPreTrainedModel { }
/**
* Audio Spectrogram Transformer model with an audio classification head on top
* (a linear layer on top of the pooled output) e.g. for datasets like AudioSet, Speech Commands v2.
*/
export class ASTForAudioClassification extends ASTPreTrainedModel {}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// Whisper models
export class WhisperPreTrainedModel extends PreTrainedModel { };
@ -4022,6 +4038,85 @@ export class FalconForCausalLM extends FalconPreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// CLAP models
export class ClapPreTrainedModel extends PreTrainedModel { }
export class ClapModel extends ClapPreTrainedModel { }
/**
* CLAP Text Model with a projection layer on top (a linear layer on top of the pooled output).
*
* **Example:** Compute text embeddings with `ClapTextModelWithProjection`.
*
* ```javascript
* import { AutoTokenizer, ClapTextModelWithProjection } from '@xenova/transformers';
*
* // Load tokenizer and text model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clap-htsat-unfused');
* const text_model = await ClapTextModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused');
*
* // Run tokenization
* const texts = ['a sound of a cat', 'a sound of a dog'];
* const text_inputs = tokenizer(texts, { padding: true, truncation: true });
*
* // Compute embeddings
* const { text_embeds } = await text_model(text_inputs);
* // Tensor {
* // dims: [ 2, 512 ],
* // type: 'float32',
* // data: Float32Array(1024) [ ... ],
* // size: 1024
* // }
* ```
*/
export class ClapTextModelWithProjection extends ClapPreTrainedModel {
/** @type {PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
}
}
/**
* CLAP Audio Model with a projection layer on top (a linear layer on top of the pooled output).
*
* **Example:** Compute audio embeddings with `ClapAudioModelWithProjection`.
*
* ```javascript
* import { AutoProcessor, ClapAudioModelWithProjection, read_audio } from '@xenova/transformers';
*
* // Load processor and audio model
* const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused');
* const audio_model = await ClapAudioModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused');
*
* // Read audio and run processor
* const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav');
* const audio_inputs = await processor(audio);
*
* // Compute embeddings
* const { audio_embeds } = await audio_model(audio_inputs);
* // Tensor {
* // dims: [ 1, 512 ],
* // type: 'float32',
* // data: Float32Array(512) [ ... ],
* // size: 512
* // }
* ```
*/
export class ClapAudioModelWithProjection extends ClapPreTrainedModel {
/** @type {PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'audio_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// AutoModels, used to simplify construction of PreTrainedModels
// (uses config to instantiate correct class)
@ -4102,11 +4197,13 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['roberta', ['RobertaModel', RobertaModel]],
['xlm', ['XLMModel', XLMModel]],
['xlm-roberta', ['XLMRobertaModel', XLMRobertaModel]],
['clap', ['ClapModel', ClapModel]],
['clip', ['CLIPModel', CLIPModel]],
['mobilebert', ['MobileBertModel', MobileBertModel]],
['squeezebert', ['SqueezeBertModel', SqueezeBertModel]],
['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]],
['wavlm', ['WavLMModel', WavLMModel]],
['audio-spectrogram-transformer', ['ASTModel', ASTModel]],
['detr', ['DetrModel', DetrModel]],
['vit', ['ViTModel', ViTModel]],
@ -4295,7 +4392,10 @@ const MODEL_FOR_CTC_MAPPING_NAMES = new Map([
const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
['wav2vec2', ['Wav2Vec2ForSequenceClassification', Wav2Vec2ForSequenceClassification]],
['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]],
]);
['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
]);
const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([
['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]],
@ -4343,6 +4443,9 @@ for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
const CUSTOM_MAPPING = [
['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly],
['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection, MODEL_TYPES.EncoderOnly],
['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly],
['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly],
]
for (const [name, model, type] of CUSTOM_MAPPING) {
MODEL_TYPE_MAPPING.set(name, type);

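Since the split `ClapTextModelWithProjection` and `ClapAudioModelWithProjection` classes above project text and audio into a shared embedding space, their outputs can be compared directly. A sketch combining the two docstring examples to score audio–text similarity (the cosine helper below is illustrative, not part of the library):

```javascript
import {
    AutoTokenizer, AutoProcessor, read_audio,
    ClapTextModelWithProjection, ClapAudioModelWithProjection,
} from '@xenova/transformers';

// Text embeddings (dims [2, 512] as in the example above)
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clap-htsat-unfused');
const text_model = await ClapTextModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused');
const texts = ['a sound of a cat', 'a sound of a dog'];
const { text_embeds } = await text_model(tokenizer(texts, { padding: true, truncation: true }));

// Audio embedding (dims [1, 512] as in the example above)
const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused');
const audio_model = await ClapAudioModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused');
const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav');
const { audio_embeds } = await audio_model(await processor(audio));

// Cosine similarity between the audio embedding and each text embedding (plain JS helper)
const dim = audio_embeds.dims.at(-1);
const a = audio_embeds.data;
const cosine = (t) => {
    let dot = 0, na = 0, nt = 0;
    for (let i = 0; i < dim; ++i) {
        dot += a[i] * t[i];
        na += a[i] * a[i];
        nt += t[i] * t[i];
    }
    return dot / Math.sqrt(na * nt);
};
texts.forEach((text, i) => {
    const t = text_embeds.data.subarray(i * dim, (i + 1) * dim);
    console.log(text, cosine(t).toFixed(3)); // the cat prompt should score higher for cat_meow.wav
});
```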
View File

@ -954,7 +954,7 @@ export class FeatureExtractionPipeline extends Pipeline {
* Audio classification pipeline using any `AutoModelForAudioClassification`.
* This pipeline predicts the class of a raw waveform or an audio file.
*
* **Example:** Perform audio classification.
* **Example:** Perform audio classification with `Xenova/wav2vec2-large-xlsr-53-gender-recognition-librispeech`.
* ```javascript
* let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
* let classifier = await pipeline('audio-classification', 'Xenova/wav2vec2-large-xlsr-53-gender-recognition-librispeech');
@ -964,6 +964,19 @@ export class FeatureExtractionPipeline extends Pipeline {
* // { label: 'female', score: 0.001845747814513743 }
* // ]
* ```
*
* **Example:** Perform audio classification with `Xenova/ast-finetuned-audioset-10-10-0.4593` and return top 4 results.
* ```javascript
* let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav';
* let classifier = await pipeline('audio-classification', 'Xenova/ast-finetuned-audioset-10-10-0.4593');
* let output = await classifier(url, { topk: 4 });
* // [
* // { label: 'Meow', score: 0.5617874264717102 },
* // { label: 'Cat', score: 0.22365376353263855 },
* // { label: 'Domestic animals, pets', score: 0.1141069084405899 },
* // { label: 'Animal', score: 0.08985692262649536 },
* // ]
* ```
*/
export class AudioClassificationPipeline extends Pipeline {
@ -1039,6 +1052,105 @@ export class AudioClassificationPipeline extends Pipeline {
}
}
/**
* Zero-shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio clip when you
* provide the audio and a set of `candidate_labels`.
*
* **Example**: Perform zero-shot audio classification with `Xenova/clap-htsat-unfused`.
* ```javascript
* let classifier = await pipeline('zero-shot-audio-classification', 'Xenova/clap-htsat-unfused');
* let audio = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/dog_barking.wav';
* let candidate_labels = ['dog', 'vacuum cleaner'];
* let scores = await classifier(audio, candidate_labels);
* // [
* // { score: 0.9993992447853088, label: 'dog' },
* // { score: 0.0006007603369653225, label: 'vacuum cleaner' }
* // ]
* ```
*/
export class ZeroShotAudioClassificationPipeline extends Pipeline {
/**
* Create a new ZeroShotAudioClassificationPipeline.
* @param {Object} options An object containing the following properties:
* @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks.
* @param {PreTrainedModel} [options.model] The model to use.
* @param {PreTrainedTokenizer} [options.tokenizer] The tokenizer to use.
* @param {Processor} [options.processor] The processor to use.
*/
constructor(options) {
super(options);
}
/**
* Preprocesses the input audio for the ZeroShotAudioClassificationPipeline.
* @param {any} audio The audio to be preprocessed.
* @param {number} sampling_rate The sampling rate of the audio.
* @returns {Promise<Float32Array>} A promise that resolves to the preprocessed audio data.
* @private
*/
async _preprocess(audio, sampling_rate) {
if (isString(audio)) {
audio = await read_audio(audio, sampling_rate);
}
return audio;
}
/**
* Assign labels to the audio(s) passed as inputs.
* @param {Array} audios The input audios.
* @param {string[]} candidate_labels The candidate labels for this audio
* @param {Object} options The options for the classification.
* @param {string} [options.hypothesis_template] The sentence used in conjunction with *candidate_labels* to attempt
* the audio classification by replacing the placeholder with the candidate_labels.
* The likelihood is then estimated using logits_per_audio.
* @returns {Promise<any>}
*/
async _call(audios, candidate_labels, {
hypothesis_template = "This is a sound of {}."
} = {}) {
const single = !Array.isArray(audios);
if (single) {
// @ts-ignore
audios = [audios];
}
// Insert label into hypothesis template
const texts = candidate_labels.map(
x => hypothesis_template.replace('{}', x)
);
// Run tokenization
const text_inputs = this.tokenizer(texts, {
padding: true,
truncation: true,
});
const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
const toReturn = [];
for (let audio of audios) {
audio = await this._preprocess(audio, sampling_rate)
const audio_inputs = await this.processor(audio);
// Run model with both text and audio inputs
const output = await this.model({ ...text_inputs, ...audio_inputs });
// Compute softmax per audio
const probs = softmax(output.logits_per_audio.data);
toReturn.push([...probs].map((x, i) => {
return {
score: x,
label: candidate_labels[i]
}
}));
}
return !single ? toReturn : toReturn[0];
}
}
/**
* Pipeline that aims at extracting spoken text contained within some audio.
@ -2272,6 +2384,18 @@ const SUPPORTED_TASKS = {
},
"type": "audio",
},
"zero-shot-audio-classification": {
"tokenizer": AutoTokenizer,
"pipeline": ZeroShotAudioClassificationPipeline,
"model": AutoModel,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "laion/clap-htsat-fused",
"model": "Xenova/clap-htsat-unfused",
},
"type": "multimodal",
},
"automatic-speech-recognition": {
"tokenizer": AutoTokenizer,
"pipeline": AutomaticSpeechRecognitionPipeline,
@ -2459,6 +2583,7 @@ const TASK_ALIASES = {
* - `"translation"`: will return a `TranslationPipeline`.
* - `"translation_xx_to_yy"`: will return a `TranslationPipeline`.
* - `"zero-shot-classification"`: will return a `ZeroShotClassificationPipeline`.
* - `"zero-shot-audio-classification"`: will return a `ZeroShotAudioClassificationPipeline`.
* - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`.
* - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`.
* @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used.

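As with the other zero-shot pipelines, the options accepted by `_call` above can be passed as a third argument when invoking the pipeline. A minimal sketch of overriding the default `hypothesis_template` (the template string itself is illustrative):

```javascript
import { pipeline } from '@xenova/transformers';

const classifier = await pipeline('zero-shot-audio-classification', 'Xenova/clap-htsat-unfused');
const audio = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/dog_barking.wav';

// '{}' is replaced by each candidate label before tokenization
const output = await classifier(audio, ['dog', 'vacuum cleaner'], {
    hypothesis_template: 'This is a recording of a {}.',
});
```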
View File

@ -33,14 +33,17 @@ import {
min,
max,
softmax,
FFT,
} from './utils/maths.js';
import { Tensor, transpose, cat, interpolate } from './utils/tensor.js';
import { RawImage } from './utils/image.js';
import { getMelFilters } from './utils/audio.js';
import {
window_function,
spectrogram,
mel_filter_bank,
} from './utils/audio.js';
// Helper functions
@ -146,6 +149,21 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
* @typedef {[height: number, width: number]} HeightWidth
*/
/**
* Helper function to validate audio inputs.
* @param {any} audio The audio data.
* @param {string} feature_extractor The name of the feature extractor.
* @private
*/
function validate_audio_inputs(audio, feature_extractor) {
if (!(audio instanceof Float32Array || audio instanceof Float64Array)) {
throw new Error(
`${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead. ` +
`If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
)
}
}
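The error message above points the user at `read_audio`; when calling a feature extractor outside of a pipeline, the waveform should first be decoded at the extractor's expected sampling rate. A minimal sketch, reusing the CLAP checkpoint from the examples in this diff:

```javascript
import { AutoProcessor, read_audio } from '@xenova/transformers';

const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused');
const sampling_rate = processor.feature_extractor.config.sampling_rate;

// Decode the file to a Float32Array at the expected sampling rate, then run the extractor
const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav', sampling_rate);
const { input_features } = await processor(audio);
```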
/**
* Base class for feature extractors.
*
@ -1126,232 +1144,24 @@ export class Swin2SRImageProcessor extends ImageFeatureExtractor {
}
}
export class WhisperFeatureExtractor extends FeatureExtractor {
constructor(config) {
super(config);
// Prefer given `mel_filters` from preprocessor_config.json, or calculate them if they don't exist.
this.config.mel_filters ??= getMelFilters(this.config.sampling_rate, this.config.n_fft, this.config.feature_size);
}
this.config.mel_filters ??= mel_filter_bank(
Math.floor(1 + this.config.n_fft / 2), // num_frequency_bins
this.config.feature_size, // num_mel_filters
0.0, // min_frequency
8000.0, // max_frequency
this.config.sampling_rate, // sampling_rate
"slaney", // norm
"slaney", // mel_scale
);
/**
* Pads an array with a reflected version of itself on both ends.
* @param {Float32Array} array The array to pad.
* @param {number} left The amount of padding to add to the left.
* @param {number} right The amount of padding to add to the right.
* @returns {Float32Array} The padded array.
*/
padReflect(array, left, right) {
const padded = new Float32Array(array.length + left + right);
const w = array.length - 1;
for (let i = 0; i < array.length; ++i) {
padded[left + i] = array[i];
}
for (let i = 1; i <= left; ++i) {
padded[left - i] = array[calculateReflectOffset(i, w)];
}
for (let i = 1; i <= right; ++i) {
padded[w + left + i] = array[calculateReflectOffset(w - i, w)];
}
return padded;
}
/**
* Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
*
* @param {number[][]} frames A 2D array representing the signal frames.
* @param {number[]} window A 1D array representing the window to be applied to the frames.
* @returns {Object} An object with the following properties:
* - data: A 1D array representing the complex STFT of the signal.
* - dims: An array representing the dimensions of the STFT data, i.e. [num_frames, num_fft_bins].
*/
stft(frames, window) {
// Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
//
// NOTE: Since the window width is not a power of 2, we must
// perform Fast Fourier Transform with chirp-z transform:
// https://math.stackexchange.com/questions/77118/non-power-of-2-ffts/77156#77156
// Helper variables
const fft_size = this.config.n_fft;
const a = 2 * (fft_size - 1);
const b = 2 * (2 * fft_size - 1);
const nextP2 = 2 ** (Math.ceil(Math.log2(b)))
const num_fft_bins = fft_size + 2;
// Preallocate array to store output
// double since we store complex numbers
const data = new Float32Array(num_fft_bins * frames.length);
// Define buffers
// Compute chirp for transform
const chirp = new Float32Array(b);
const ichirp = new Float32Array(nextP2);
const buffer1 = new Float32Array(nextP2);
const buffer2 = new Float32Array(nextP2);
const outBuffer = new Float32Array(nextP2);
const outBuffer2 = new Float32Array(nextP2);
const outBuffer3 = new Float32Array(nextP2);
// Compute complex exponentiation
const theta = -2 * Math.PI / fft_size;
const baseR = Math.cos(theta);
const baseI = Math.sin(theta);
// Precompute helper for chirp-z transform
for (let i = 0; i < b >> 1; ++i) {
// Compute complex power:
const e = (i + 1 - fft_size) ** 2 / 2.0;
// Compute the modulus and argument of the result
const result_mod = Math.sqrt(baseR ** 2 + baseI ** 2) ** e;
const result_arg = e * Math.atan2(baseI, baseR);
// Convert the result back to rectangular form
// and assign to chirp and ichirp
let i2 = 2 * i;
chirp[i2] = result_mod * Math.cos(result_arg);
chirp[i2 + 1] = result_mod * Math.sin(result_arg);
// conjugate
ichirp[i2] = chirp[i2];
ichirp[i2 + 1] = - chirp[i2 + 1];
}
const slicedChirp = chirp.subarray(a, b);
// create object to perform Fast Fourier Transforms
// with `nextP2` complex numbers
const f = new FFT(nextP2 >> 1);
// TODO: decide between Float32Array and Float64Array
f.transform(outBuffer, ichirp);
for (let i = 0; i < frames.length; ++i) {
const frame = frames[i];
for (let j = 0; j < slicedChirp.length; j += 2) {
const j2 = j + 1
const j3 = j >> 1;
const a_real = frame[j3] * window[j3];
buffer1[j] = a_real * slicedChirp[j];
buffer1[j2] = a_real * slicedChirp[j2];
}
// TODO: decide between Float32Array and Float64Array
f.transform(outBuffer2, buffer1);
for (let j = 0; j < outBuffer.length; j += 2) {
const j2 = j + 1;
buffer2[j] = outBuffer2[j] * outBuffer[j] - outBuffer2[j2] * outBuffer[j2]
buffer2[j2] = outBuffer2[j] * outBuffer[j2] + outBuffer2[j2] * outBuffer[j]
}
// TODO: decide between Float32Array and Float64Array
f.inverseTransform(outBuffer3, buffer2)
const offset = i * num_fft_bins;
for (let j = 0; j < num_fft_bins; j += 2) {
const a_real = outBuffer3[j + a];
const a_imag = outBuffer3[j + a + 1];
const b_real = slicedChirp[j];
const b_imag = slicedChirp[j + 1];
// TODO write as transpose
const o1 = offset + j;
data[o1] = a_real * b_real - a_imag * b_imag
data[o1 + 1] = a_real * b_imag + a_imag * b_real
}
}
return {
data: data,
dims: [frames.length, num_fft_bins] // [3001, 402]
};
}
/**
* Creates an array of frames from a given waveform.
*
* @param {Float32Array} waveform The waveform to create frames from.
* @param {boolean} [center=true] Whether to center the frames on their corresponding positions in the waveform. Defaults to true.
* @returns {Array} An array of frames.
*/
fram_wave(waveform, center = true) {
const frames = [];
const half_window = Math.floor((this.config.n_fft - 1) / 2) + 1;
const waveformLength = waveform.length;
for (let i = 0; i < waveformLength + 1; i += this.config.hop_length) {
let frame;
if (center) {
let frameStart = i > half_window ? i - half_window : 0;
let frameEnd =
i < waveformLength - half_window
? i + half_window
: waveformLength;
frame = waveform.subarray(frameStart, frameEnd)
if (frameStart === 0) {
frame = this.padReflect(
frame,
-i + half_window,
0
)
} else if (frameEnd === waveformLength) {
frame = this.padReflect(
frame,
0,
i - waveformLength + half_window
)
}
} else {
frame = new Float32Array(this.config.n_fft);
const frameArray = waveform.subarray(i, i + this.config.n_fft);
if (frameArray.length < this.config.n_fft) {
frame.set(frameArray);
frame.fill(0, frameArray.length, this.config.n_fft)
} else {
frame = frameArray;
}
}
frames.push(frame);
}
return frames;
}
/**
* Generates a Hanning window of length M.
*
* @param {number} M The length of the Hanning window to generate.
* @returns {*} The generated Hanning window.
*/
hanning(M) {
if (M < 1) {
return [];
}
if (M === 1) {
return [1];
}
const denom = M - 1;
const cos_vals = new Float32Array(denom);
for (let i = 0; i < denom; ++i) {
const n = 2 * i - M + 1;
cos_vals[i] = 0.5 + 0.5 * Math.cos(Math.PI * n / denom);
}
return cos_vals;
this.window = window_function(this.config.n_fft, 'hann');
}
/**
@ -1360,80 +1170,28 @@ export class WhisperFeatureExtractor extends FeatureExtractor {
* @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
*/
_extract_fbank_features(waveform) {
// Compute the log-Mel spectrogram of the provided audio
const { data, dims } = spectrogram(
waveform,
this.window, // window
this.config.n_fft, // frame_length
this.config.hop_length, // hop_length
{
power: 2.0,
mel_filters: this.config.mel_filters,
log_mel: 'log10',
const buffer = new Float32Array(this.config.n_samples);
buffer.set(waveform)
const window = this.hanning(this.config.n_fft + 1)
const frames = this.fram_wave(buffer)
const stft = this.stft(frames, window)
const stftData = stft.data;
const d1 = stft.dims[0] - 1; // Ignore last row
const d2 = stft.dims[1] >> 1; // Only need to store real numbers now
// compute magnitudes
// NOTE: Unlike the original implementation, we do not
// transpose since we perform matrix multiplication later
const magnitudes = new Float32Array(d1 * d2);
for (let i = 0; i < d1; ++i) {
for (let j = 0; j < d2; ++j) {
// let outOffset = (j * d1 + i); // transpose
let outOffset = i * d2 + j;
let inOffset = outOffset << 1; // * 2 since complex
let magnitude = stftData[inOffset] ** 2 + stftData[inOffset + 1] ** 2
magnitudes[outOffset] = magnitude;
// Custom
max_num_frames: this.config.nb_max_frames, // 3000
}
)
const maxValue = max(data)[0];
for (let i = 0; i < data.length; ++i) {
data[i] = (Math.max(data[i], maxValue - 8.0) + 4.0) / 4.0;
}
const mel_filters = this.config.mel_filters;
const num_mel_filters = mel_filters.length;
const mel_spec = new Float32Array(num_mel_filters * d1);
let mIndex = 0;
// Perform matrix muliplication:
// mel_spec = filters @ magnitudes
// - filters.shape=(80, 201)
// - magnitudes.shape=(201, 3000)
// - mel_spec.shape=(80, 3000)
for (let i = 0; i < num_mel_filters; ++i) {
const mel_filter = mel_filters[i];
for (let j = 0; j < d1; ++j) {
let sum = 0;
// perform dot product
for (let k = 0; k < d2; ++k) {
sum += mel_filter[k] * magnitudes[j * d2 + k];
}
mel_spec[mIndex++] = sum;
}
}
const a_min = 1e-10;
const log_spec = new Float32Array(mel_spec.length);
let maxLogSpec = 0;
for (let i = 0; i < mel_spec.length; ++i) {
const clipped = Math.max(a_min, mel_spec[i]);
const log10 = Math.log10(clipped);
log_spec[i] = log10;
maxLogSpec = Math.max(log10, maxLogSpec)
}
for (let i = 0; i < log_spec.length; ++i) {
log_spec[i] = Math.max(log_spec[i], maxLogSpec - 8);
log_spec[i] = (log_spec[i] + 4) / 4;
}
return {
data: log_spec,
dims: [num_mel_filters, d1]
};
return { data, dims };
}
/**
@ -1442,29 +1200,28 @@ export class WhisperFeatureExtractor extends FeatureExtractor {
* @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor.
*/
async _call(audio) {
if (!(audio instanceof Float32Array || audio instanceof Float64Array)) {
throw new Error(
// @ts-ignore
`WhisperFeatureExtractor expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead.` +
`If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
)
}
validate_audio_inputs(audio, 'WhisperFeatureExtractor');
let waveform;
if (audio.length > this.config.n_samples) {
console.warn(
"Attempting to extract features for audio longer than 30 seconds. " +
"If using a pipeline to extract transcript from a long audio clip, " +
"remember to specify `chunk_length_s` and/or `stride_length_s`."
);
waveform = audio.slice(0, this.config.n_samples);
} else {
// pad with zeros
waveform = new Float32Array(this.config.n_samples);
waveform.set(audio);
}
let waveform = audio.slice(0, this.config.n_samples);
let features = this._extract_fbank_features(waveform);
const { data, dims } = this._extract_fbank_features(waveform);
return {
input_features: new Tensor('float32',
features.data,
[1, ...features.dims]
data,
[1, ...dims]
)
};
}
@ -1490,14 +1247,8 @@ export class Wav2Vec2FeatureExtractor extends FeatureExtractor {
* @returns {Promise<{ input_values: Tensor; attention_mask: Tensor }>} A Promise resolving to an object containing the extracted input features and attention mask as Tensors.
*/
async _call(audio) {
// TODO: remove duplication
if (!(audio instanceof Float32Array || audio instanceof Float64Array)) {
throw new Error(
// @ts-ignore
`Wav2Vec2FeatureExtractor expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead.` +
`If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
)
}
validate_audio_inputs(audio, 'Wav2Vec2FeatureExtractor');
if (audio instanceof Float64Array) {
audio = new Float32Array(audio);
}
@ -1518,6 +1269,260 @@ export class Wav2Vec2FeatureExtractor extends FeatureExtractor {
}
}
export class ASTFeatureExtractor extends FeatureExtractor {
constructor(config) {
super(config);
const sampling_rate = this.config.sampling_rate;
const mel_filters = mel_filter_bank(
256, // num_frequency_bins
this.config.num_mel_bins, // num_mel_filters
20, // min_frequency
Math.floor(sampling_rate / 2), // max_frequency
sampling_rate, // sampling_rate
null, // norm
"kaldi", // mel_scale
true, // triangularize_in_mel_space
);
// Do padding:
for (let i = 0; i < mel_filters.length; ++i) {
mel_filters[i].push(0);
}
this.mel_filters = mel_filters;
this.window = window_function(400, 'hann', {
periodic: false,
})
this.mean = this.config.mean;
this.std = this.config.std;
}
/**
* Computes the log-Mel spectrogram of the provided audio waveform.
* @param {Float32Array|Float64Array} waveform The audio waveform to process.
* @param {number} max_length The maximum number of frames to return.
* @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
*/
_extract_fbank_features(waveform, max_length) {
// NOTE: We don't pad/truncate since that is passed in as `max_num_frames`
return spectrogram(
waveform,
this.window, // window
400, // frame_length
160, // hop_length
{
fft_length: 512,
power: 2.0,
center: false,
preemphasis: 0.97,
mel_filters: this.mel_filters,
log_mel: 'log',
mel_floor: 1.192092955078125e-07,
remove_dc_offset: true,
// Custom
max_num_frames: max_length,
transpose: true,
}
)
}
/**
* Asynchronously extracts features from a given audio using the provided configuration.
* @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
* @returns {Promise<{ input_values: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor.
*/
async _call(audio) {
validate_audio_inputs(audio, 'ASTFeatureExtractor');
const features = this._extract_fbank_features(audio, this.config.max_length);
if (this.config.do_normalize) {
// Normalize the input audio spectrogram to have mean=0, std=0.5
const denom = this.std * 2;
for (let i = 0; i < features.data.length; ++i) {
features.data[i] = (features.data[i] - this.mean) / denom;
}
}
return {
input_values: new Tensor('float32',
features.data,
[1, ...features.dims]
)
};
}
}
export class ClapFeatureExtractor extends FeatureExtractor {
constructor(config) {
super(config);
this.mel_filters = mel_filter_bank(
this.config.nb_frequency_bins, // num_frequency_bins
this.config.feature_size, // num_mel_filters
this.config.frequency_min, // min_frequency
this.config.frequency_max, // max_frequency
this.config.sampling_rate, // sampling_rate
null, // norm
"htk", // mel_scale
);
this.mel_filters_slaney = mel_filter_bank(
this.config.nb_frequency_bins, // num_frequency_bins
this.config.feature_size, // num_mel_filters
this.config.frequency_min, // min_frequency
this.config.frequency_max, // max_frequency
this.config.sampling_rate, // sampling_rate
"slaney", // norm
"slaney", // mel_scale
);
this.window = window_function(this.config.fft_window_size, 'hann')
}
/**
* Extracts the mel spectrogram and prepares it for the mode based on the `truncation` and `padding` arguments.
*
* Four different paths are possible:
* - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram
* will be computed on the entire audio. 3 random crops and a downsampled version of the full mel spectrogram
* are then stacked together. They will later be used for `feature_fusion`.
* - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is
* padded based on `padding`.
* - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded
* based on `padding`, and is repeated `4` times.
* - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel
* spectrogram will be computed on a random crop of the waveform.
*
* @param {Float32Array|Float64Array} waveform The input waveform.
* @param {number} max_length The maximum length of the waveform.
* @param {string} truncation The truncation strategy to use.
* @param {string} padding The padding strategy to use.
* @returns {{ data: Float32Array; dims: number[]; longer: boolean; }} An object containing the mel spectrogram data as a Float32Array, its dimensions as an array of numbers, and a boolean indicating whether the waveform was longer than the max length.
*/
_get_input_mel(waveform, max_length, truncation, padding) {
/** @type {{ data: Float32Array; dims: number[]}} */
let input_mel;
let longer = false;
const diff = waveform.length - max_length;
if (diff > 0) {
if (truncation === 'rand_trunc') {
longer = true;
const idx = Math.floor(Math.random() * (diff + 1));
waveform = waveform.subarray(idx, idx + max_length);
input_mel = this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples);
input_mel.dims = [1, ...input_mel.dims]; // "unsqueeze"
} else {
// TODO implement fusion strategy
throw new Error(`Truncation strategy "${truncation}" not implemented`)
}
} else {
if (diff < 0) {
let padded = new Float64Array(max_length); // already padded with zeros
padded.set(waveform);
if (padding === 'repeat') {
for (let i = waveform.length; i < max_length; i += waveform.length) {
padded.set(waveform.subarray(0, Math.min(waveform.length, max_length - i)), i);
}
} else if (padding === 'repeatpad') {
for (let i = waveform.length; i < -diff; i += waveform.length) {
padded.set(waveform, i);
}
}
waveform = padded;
}
if (truncation === 'fusion') {
throw new Error(`Truncation strategy "${truncation}" not implemented`)
}
input_mel = this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples);
input_mel.dims = [1, ...input_mel.dims]; // "unsqueeze"
}
return {
...input_mel,
longer,
}
}
/**
* Compute the log-mel spectrogram of the provided `waveform` using the Hann window.
* In CLAP, two different filter banks are used depending on the truncation pattern:
* - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from
* calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
* is set to `"fusion"`.
* - `self.mel_filters_slaney`: they correspond to the default parameters of `librosa`, which uses
* `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
* implementation when the truncation mode is not `"fusion"`.
*
* @param {Float32Array|Float64Array} waveform The audio waveform to process.
* @param {number[][]} mel_filters The mel filters to use.
* @param {number} [max_length=null] The maximum number of frames to return.
* @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
*/
_extract_fbank_features(waveform, mel_filters, max_length = null) {
// NOTE: We don't pad/truncate since that is passed in as `max_num_frames`
return spectrogram(
waveform,
this.window, // window
this.config.fft_window_size, // frame_length
this.config.hop_length, // hop_length
{
power: 2.0,
mel_filters,
log_mel: 'dB',
// Custom
max_num_frames: max_length,
do_pad: false,
transpose: true,
}
)
}
/**
* Asynchronously extracts features from a given audio using the provided configuration.
* @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
* @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor.
*/
async _call(audio, {
max_length = null,
} = {}) {
validate_audio_inputs(audio, 'ClapFeatureExtractor');
// convert to mel spectrogram, truncate and pad if needed.
const padded_inputs = this._get_input_mel(
audio,
max_length ?? this.config.nb_max_samples,
this.config.truncation,
this.config.padding,
);
return {
input_features: new Tensor('float32',
padded_inputs.data,
[1, ...padded_inputs.dims]
)
};
}
}
export class SpeechT5FeatureExtractor extends FeatureExtractor { }
/**
@ -1658,6 +1663,8 @@ export class AutoProcessor {
Swin2SRImageProcessor,
Wav2Vec2FeatureExtractor,
SpeechT5FeatureExtractor,
ASTFeatureExtractor,
ClapFeatureExtractor,
}
static PROCESSOR_CLASS_MAPPING = {

View File

@ -10,7 +10,11 @@
import {
getFile,
} from './hub.js';
import { rfftfreq } from './maths.js';
import { FFT, max } from './maths.js';
import {
calculateReflectOffset,
} from './core.js';
/**
* Helper function to read audio from a path/URL.
@ -57,8 +61,8 @@ export async function read_audio(url, sampling_rate) {
// audio at all, this scaling factor may not be needed.
const SCALING_FACTOR = Math.sqrt(2);
let left = decoded.getChannelData(0);
let right = decoded.getChannelData(1);
const left = decoded.getChannelData(0);
const right = decoded.getChannelData(1);
audio = new Float32Array(left.length);
for (let i = 0; i < decoded.length; ++i) {
@ -74,69 +78,587 @@ export async function read_audio(url, sampling_rate) {
}
/**
* Creates a frequency bin conversion matrix used to obtain a mel spectrogram.
* @param {number} sr Sample rate of the audio waveform.
* @param {number} n_fft Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
* @param {number} n_mels Number of mel filters to generate.
* @returns {number[][]} Projection matrix to go from a spectrogram to a mel spectrogram.
* Generates a Hanning window of length M.
*
* @param {number} M The length of the Hanning window to generate.
* @returns {Float64Array} The generated Hanning window.
*/
export function getMelFilters(sr, n_fft, n_mels = 128) {
n_mels = Math.floor(n_mels);
// Initialize the weights
const mel_size = Math.floor(1 + n_fft / 2);
const weights = new Array(n_mels);
// Center freqs of each FFT bin
const fftfreqs = rfftfreq(n_fft, 1 / sr);
// 'Center freqs' of mel bands - uniformly spaced between limits
const min_mel = 0.0;
const max_mel = 45.245640471924965;
const mel_range = max_mel - min_mel;
const mel_scale = mel_range / (n_mels + 1);
// Fill in the linear scale
const f_min = 0.0;
const f_sp = 200.0 / 3;
const freqs = new Array(n_mels + 2);
// And now the nonlinear scale
const min_log_hz = 1000.0; // beginning of log region (Hz)
const min_log_mel = (min_log_hz - f_min) / f_sp; // same (Mels)
const logstep = Math.log(6.4) / 27.0; // step size for log region
const ramps = new Array(freqs.length);
for (let i = 0; i < freqs.length; ++i) {
const mel = i * mel_scale + min_mel;
if (mel >= min_log_mel) {
freqs[i] = min_log_hz * Math.exp(logstep * (mel - min_log_mel));
} else {
freqs[i] = f_min + f_sp * mel;
}
ramps[i] = fftfreqs.map(k => freqs[i] - k);
export function hanning(M) {
if (M < 1) {
return new Float64Array();
}
const fdiffinv = freqs.slice(1).map((v, i) => 1 / (v - freqs[i]));
for (let i = 0; i < weights.length; ++i) {
weights[i] = new Array(mel_size);
const a = fdiffinv[i];
const b = fdiffinv[i + 1];
const c = ramps[i];
const d = ramps[i + 2];
// Slaney-style mel is scaled to be approx constant energy per channel
const enorm = 2.0 / (freqs[i + 2] - freqs[i]);
for (let j = 0; j < weights[i].length; ++j) {
// lower and upper slopes for all bins
const lower = -c[j] * a;
const upper = d[j] * b;
weights[i][j] = Math.max(0, Math.min(lower, upper)) * enorm;
}
if (M === 1) {
return new Float64Array([1]);
}
return weights;
const denom = M - 1;
const factor = Math.PI / denom;
const cos_vals = new Float64Array(M);
for (let i = 0; i < M; ++i) {
const n = 2 * i - denom;
cos_vals[i] = 0.5 + 0.5 * Math.cos(factor * n);
}
return cos_vals;
}
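A quick numeric check of the symmetric window formula above (values rounded):

```javascript
// Same formula as `hanning` above, evaluated for M = 5:
// n = 2*i - (M - 1) ∈ {-4, -2, 0, 2, 4}, giving 0.5 + 0.5 * cos(pi * n / 4)
const M = 5, denom = M - 1, factor = Math.PI / denom;
const w = Float64Array.from({ length: M }, (_, i) => 0.5 + 0.5 * Math.cos(factor * (2 * i - denom)));
console.log(Array.from(w)); // ≈ [ 0, 0.5, 1, 0.5, 0 ]
```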
const HERTZ_TO_MEL_MAPPING = {
"htk": (/** @type {number} */ freq) => 2595.0 * Math.log10(1.0 + (freq / 700.0)),
"kaldi": (/** @type {number} */ freq) => 1127.0 * Math.log(1.0 + (freq / 700.0)),
"slaney": (/** @type {number} */ freq, min_log_hertz = 1000.0, min_log_mel = 15.0, logstep = 27.0 / Math.log(6.4)) =>
freq >= min_log_hertz
? min_log_mel + Math.log(freq / min_log_hertz) * logstep
: 3.0 * freq / 200.0,
}
/**
* @template {Float32Array|Float64Array|number} T
* @param {T} freq
* @param {string} [mel_scale]
* @returns {T}
*/
function hertz_to_mel(freq, mel_scale = "htk") {
const fn = HERTZ_TO_MEL_MAPPING[mel_scale];
if (!fn) {
throw new Error('mel_scale should be one of "htk", "slaney" or "kaldi".');
}
return typeof freq === 'number' ? fn(freq) : freq.map(x => fn(x));
}
const MEL_TO_HERTZ_MAPPING = {
"htk": (/** @type {number} */ mels) => 700.0 * (10.0 ** (mels / 2595.0) - 1.0),
"kaldi": (/** @type {number} */ mels) => 700.0 * (Math.exp(mels / 1127.0) - 1.0),
"slaney": (/** @type {number} */ mels, min_log_hertz = 1000.0, min_log_mel = 15.0, logstep = Math.log(6.4) / 27.0) => mels >= min_log_mel
? min_log_hertz * Math.exp(logstep * (mels - min_log_mel))
: 200.0 * mels / 3.0,
}
/**
* @template {Float32Array|Float64Array|number} T
* @param {T} mels
* @param {string} [mel_scale]
* @returns {T}
*/
function mel_to_hertz(mels, mel_scale = "htk") {
const fn = MEL_TO_HERTZ_MAPPING[mel_scale];
if (!fn) {
throw new Error('mel_scale should be one of "htk", "slaney" or "kaldi".');
}
return typeof mels === 'number' ? fn(mels) : mels.map(x => fn(x));
}
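For intuition, the `htk` variants of the two mappings above are exact inverses of each other, and 1000 Hz maps to roughly 1000 mel (a worked check; numbers are approximate):

```javascript
// htk scale: m = 2595 * log10(1 + f / 700)  and  f = 700 * (10^(m / 2595) - 1)
const toMel = (f) => 2595.0 * Math.log10(1.0 + f / 700.0);
const toHz = (m) => 700.0 * (10.0 ** (m / 2595.0) - 1.0);

console.log(toMel(1000).toFixed(1));      // ≈ 1000.0
console.log(toMel(8000).toFixed(1));      // ≈ 2840.0
console.log(toHz(toMel(440)).toFixed(1)); // ≈ 440.0 (round trip)
```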
/**
* Creates a triangular filter bank.
*
* Adapted from torchaudio and librosa.
*
* @param {Float64Array} fft_freqs Discrete frequencies of the FFT bins in Hz, of shape `(num_frequency_bins,)`.
* @param {Float64Array} filter_freqs Center frequencies of the triangular filters to create, in Hz, of shape `(num_mel_filters,)`.
* @returns {number[][]} Triangular filter bank of shape `(num_mel_filters, num_frequency_bins)`.
*/
function _create_triangular_filter_bank(fft_freqs, filter_freqs) {
const filter_diff = Float64Array.from(
{ length: filter_freqs.length - 1 },
(_, i) => filter_freqs[i + 1] - filter_freqs[i]
);
const slopes = Array.from({
length: fft_freqs.length
}, () => new Array(filter_freqs.length));
for (let j = 0; j < fft_freqs.length; ++j) {
const slope = slopes[j];
for (let i = 0; i < filter_freqs.length; ++i) {
slope[i] = filter_freqs[i] - fft_freqs[j];
}
}
const numFreqs = filter_freqs.length - 2;
const ret = Array.from({ length: numFreqs }, () => new Array(fft_freqs.length));
for (let j = 0; j < fft_freqs.length; ++j) { // 201
const slope = slopes[j];
for (let i = 0; i < numFreqs; ++i) { // 80
const down = -slope[i] / filter_diff[i];
const up = slope[i + 2] / filter_diff[i + 1];
ret[i][j] = Math.max(0, Math.min(down, up));
}
}
return ret;
}
/**
* Return evenly spaced numbers over a specified interval.
* @param {number} start The starting value of the sequence.
* @param {number} end The end value of the sequence.
* @param {number} num Number of samples to generate.
 * @returns `num` evenly spaced samples, calculated over the interval `[start, end]`.
*/
function linspace(start, end, num) {
const step = (end - start) / (num - 1);
return Float64Array.from({ length: num }, (_, i) => start + step * i);
}
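// Illustrative example:
// linspace(0, 8000, 5) -> Float64Array [0, 2000, 4000, 6000, 8000]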
/**
* Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
* various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters
* are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
* features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
* @param {number} num_frequency_bins Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
* @param {number} num_mel_filters Number of mel filters to generate.
* @param {number} min_frequency Lowest frequency of interest in Hz.
* @param {number} max_frequency Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
* @param {number} sampling_rate Sample rate of the audio waveform.
* @param {string} [norm] If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
 * @param {string} [mel_scale] The mel frequency scale to use: `"htk"`, `"kaldi"` or `"slaney"`.
* @param {boolean} [triangularize_in_mel_space] If this option is enabled, the triangular filter is applied in mel space rather than frequency space.
* This should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
* @returns {number[][]} Triangular filter bank matrix, which is a 2D array of shape (`num_frequency_bins`, `num_mel_filters`).
* This is a projection matrix to go from a spectrogram to a mel spectrogram.
*/
export function mel_filter_bank(
num_frequency_bins,
num_mel_filters,
min_frequency,
max_frequency,
sampling_rate,
norm = null,
mel_scale = "htk",
triangularize_in_mel_space = false,
) {
if (norm !== null && norm !== "slaney") {
throw new Error('norm must be one of null or "slaney"');
}
const mel_min = hertz_to_mel(min_frequency, mel_scale);
const mel_max = hertz_to_mel(max_frequency, mel_scale);
const mel_freqs = linspace(mel_min, mel_max, num_mel_filters + 2);
let filter_freqs = mel_to_hertz(mel_freqs, mel_scale);
let fft_freqs; // frequencies of FFT bins in Hz
if (triangularize_in_mel_space) {
const fft_bin_width = sampling_rate / (num_frequency_bins * 2);
fft_freqs = hertz_to_mel(Float64Array.from({ length: num_frequency_bins }, (_, i) => i * fft_bin_width), mel_scale);
filter_freqs = mel_freqs;
} else {
fft_freqs = linspace(0, Math.floor(sampling_rate / 2), num_frequency_bins);
}
const mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs);
if (norm !== null && norm === "slaney") {
// Slaney-style mel is scaled to be approx constant energy per channel
for (let i = 0; i < num_mel_filters; ++i) {
const filter = mel_filters[i];
const enorm = 2.0 / (filter_freqs[i + 2] - filter_freqs[i]);
for (let j = 0; j < num_frequency_bins; ++j) {
// Apply this enorm to all frequency bins
filter[j] *= enorm;
}
}
}
// TODO warn if there is a zero row
return mel_filters;
}
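// Illustrative sketch: building Whisper-style mel filters (parameters taken from the
// `mel_filter_bank` unit test further down — n_fft = 400, 80 mel bins, 16 kHz audio):
// const mel_filters = mel_filter_bank(
//     Math.floor(1 + 400 / 2), // num_frequency_bins (= 201)
//     80,                      // num_mel_filters
//     0.0,                     // min_frequency
//     8000.0,                  // max_frequency
//     16000,                   // sampling_rate
//     "slaney",                // norm
//     "slaney",                // mel_scale
// );
// // mel_filters.length === 80, mel_filters[0].length === 201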
/**
* @template {Float32Array|Float64Array} T
* Pads an array with a reflected version of itself on both ends.
* @param {T} array The array to pad.
* @param {number} left The amount of padding to add to the left.
* @param {number} right The amount of padding to add to the right.
* @returns {T} The padded array.
*/
function padReflect(array, left, right) {
// @ts-ignore
const padded = new array.constructor(array.length + left + right);
const w = array.length - 1;
for (let i = 0; i < array.length; ++i) {
padded[left + i] = array[i];
}
for (let i = 1; i <= left; ++i) {
padded[left - i] = array[calculateReflectOffset(i, w)];
}
for (let i = 1; i <= right; ++i) {
padded[w + left + i] = array[calculateReflectOffset(w - i, w)];
}
return padded;
}
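// Illustrative example, assuming the usual "reflect" semantics (the edge sample itself
// is not repeated), as provided by `calculateReflectOffset`:
// padReflect(Float64Array.of(1, 2, 3, 4), 2, 2) -> Float64Array [3, 2, 1, 2, 3, 4, 3, 2]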
/**
* Helper function to compute `amplitude_to_db` and `power_to_db`.
* @template {Float32Array|Float64Array} T
* @param {T} spectrogram
* @param {number} factor
* @param {number} reference
* @param {number} min_value
* @param {number} db_range
* @returns {T}
*/
function _db_conversion_helper(spectrogram, factor, reference, min_value, db_range) {
if (reference <= 0) {
throw new Error('reference must be greater than zero');
}
if (min_value <= 0) {
throw new Error('min_value must be greater than zero');
}
reference = Math.max(min_value, reference);
const logReference = Math.log10(reference);
for (let i = 0; i < spectrogram.length; ++i) {
        spectrogram[i] = factor * (Math.log10(Math.max(min_value, spectrogram[i])) - logReference);
}
if (db_range !== null) {
if (db_range <= 0) {
throw new Error('db_range must be greater than zero');
}
const maxValue = max(spectrogram)[0] - db_range;
for (let i = 0; i < spectrogram.length; ++i) {
spectrogram[i] = Math.max(spectrogram[i], maxValue);
}
}
return spectrogram;
}
/**
* Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`,
* using basic logarithm properties for numerical stability. NOTE: Operates in-place.
*
* The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
* linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
* This means that large variations in energy may not sound all that different if the sound is loud to begin with.
* This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
*
* @template {Float32Array|Float64Array} T
* @param {T} spectrogram The input amplitude (mel) spectrogram.
* @param {number} [reference=1.0] Sets the input spectrogram value that corresponds to 0 dB.
* For example, use `np.max(spectrogram)` to set the loudest part to 0 dB. Must be greater than zero.
* @param {number} [min_value=1e-5] The spectrogram will be clipped to this minimum value before conversion to decibels,
* to avoid taking `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
* @param {number} [db_range=null] Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the
* difference between the peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
* @returns {T} The modified spectrogram in decibels.
*/
function amplitude_to_db(spectrogram, reference = 1.0, min_value = 1e-5, db_range = null) {
return _db_conversion_helper(spectrogram, 20.0, reference, min_value, db_range);
}
/**
* Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`,
* using basic logarithm properties for numerical stability. NOTE: Operates in-place.
*
* The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
* linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
* This means that large variations in energy may not sound all that different if the sound is loud to begin with.
* This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
*
* Based on the implementation of `librosa.power_to_db`.
*
* @template {Float32Array|Float64Array} T
* @param {T} spectrogram The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
* @param {number} [reference=1.0] Sets the input spectrogram value that corresponds to 0 dB.
* For example, use `np.max(spectrogram)` to set the loudest part to 0 dB. Must be greater than zero.
* @param {number} [min_value=1e-10] The spectrogram will be clipped to this minimum value before conversion to decibels,
* to avoid taking `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
* @param {number} [db_range=null] Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the
* difference between the peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
* @returns {T} The modified spectrogram in decibels.
*/
function power_to_db(spectrogram, reference = 1.0, min_value = 1e-10, db_range = null) {
return _db_conversion_helper(spectrogram, 10.0, reference, min_value, db_range);
}
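// Illustrative examples (both operate in-place and return the same buffer):
// power_to_db(Float64Array.of(1, 10, 100));      // -> [0, 10, 20] dB
// amplitude_to_db(Float64Array.of(1, 10, 100));  // -> [0, 20, 40] dB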
/**
* Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
*
* This function can create the following kinds of spectrograms:
* - amplitude spectrogram (`power = 1.0`)
* - power spectrogram (`power = 2.0`)
* - complex-valued spectrogram (`power = None`)
* - log spectrogram (use `log_mel` argument)
* - mel spectrogram (provide `mel_filters`)
* - log-mel spectrogram (provide `mel_filters` and `log_mel`)
*
* In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame.
* A padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
* typically the next power of two.
*
* @param {Float32Array|Float64Array} waveform The input waveform of shape `(length,)`. This must be a single real-valued, mono waveform.
* @param {Float32Array|Float64Array} window The windowing function to apply of shape `(frame_length,)`, including zero-padding if necessary. The actual window length may be
* shorter than `frame_length`, but we're assuming the array has already been zero-padded.
* @param {number} frame_length The length of the analysis frames in samples (a.k.a., `fft_length`).
* @param {number} hop_length The stride between successive analysis frames in samples.
* @param {Object} options
* @param {number} [options.fft_length=null] The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
* For optimal speed, this should be a power of two. If `null`, uses `frame_length`.
* @param {number} [options.power=1.0] If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `null`, returns complex numbers.
* @param {boolean} [options.center=true] Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `false`, frame
* `t` will start at time `t * hop_length`.
* @param {string} [options.pad_mode="reflect"] Padding mode used when `center` is `true`. Possible values are: `"constant"` (pad with zeros),
* `"edge"` (pad with edge values), `"reflect"` (pads with mirrored values).
* @param {boolean} [options.onesided=true] If `true`, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
* frequency bins. If `false`, also computes the negative frequencies and returns `fft_length` frequency bins.
 * @param {number} [options.preemphasis=null] Coefficient for the pre-emphasis (first-order high-pass) filter applied to each frame before the DFT.
* @param {number[][]} [options.mel_filters=null] The mel filter bank of shape `(num_freq_bins, num_mel_filters)`.
* If supplied, applies this filter bank to create a mel spectrogram.
* @param {number} [options.mel_floor=1e-10] Minimum value of mel frequency banks.
* @param {string} [options.log_mel=null] How to convert the spectrogram to log scale. Possible options are:
* `null` (don't convert), `"log"` (take the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels).
* Can only be used when `power` is not `null`.
* @param {number} [options.reference=1.0] Sets the input spectrogram value that corresponds to 0 dB. For example, use `max(spectrogram)[0]` to set
* the loudest part to 0 dB. Must be greater than zero.
* @param {number} [options.min_value=1e-10] The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking `log(0)`.
* For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an amplitude spectrogram, the value `1e-5` corresponds to -100 dB.
* Must be greater than zero.
* @param {number} [options.db_range=null] Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
* peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
* @param {boolean} [options.remove_dc_offset=null] Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
* order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
* @param {number} [options.max_num_frames=null] If provided, limits the number of frames to compute to this value.
* @param {boolean} [options.do_pad=true] If `true`, pads the output spectrogram to have `max_num_frames` frames.
* @param {boolean} [options.transpose=false] If `true`, the returned spectrogram will have shape `(num_frames, num_frequency_bins/num_mel_filters)`. If `false`, the returned spectrogram will have shape `(num_frequency_bins/num_mel_filters, num_frames)`.
* @returns {{data: Float32Array, dims: number[]}} Spectrogram of shape `(num_frequency_bins, length)` (regular spectrogram) or shape `(num_mel_filters, length)` (mel spectrogram).
*/
export function spectrogram(
waveform,
window,
frame_length,
hop_length,
{
fft_length = null,
power = 1.0,
center = true,
pad_mode = "reflect",
onesided = true,
preemphasis = null,
mel_filters = null,
mel_floor = 1e-10,
log_mel = null,
reference = 1.0,
min_value = 1e-10,
db_range = null,
remove_dc_offset = null,
// Custom parameters for efficiency reasons
max_num_frames = null,
do_pad = true,
transpose = false,
} = {}
) {
const window_length = window.length;
if (fft_length === null) {
fft_length = frame_length;
}
if (frame_length > fft_length) {
throw Error(`frame_length (${frame_length}) may not be larger than fft_length (${fft_length})`)
}
if (window_length !== frame_length) {
throw new Error(`Length of the window (${window_length}) must equal frame_length (${frame_length})`);
}
if (hop_length <= 0) {
throw new Error("hop_length must be greater than zero");
}
if (center) {
if (pad_mode !== 'reflect') {
throw new Error(`pad_mode="${pad_mode}" not implemented yet.`)
}
const half_window = Math.floor((fft_length - 1) / 2) + 1;
waveform = padReflect(waveform, half_window, half_window);
}
// split waveform into frames of frame_length size
const num_frames = Math.floor(1 + Math.floor((waveform.length - frame_length) / hop_length))
const num_frequency_bins = onesided ? Math.floor(fft_length / 2) + 1 : fft_length
let d1 = num_frames;
let d1Max = num_frames;
// If maximum number of frames is provided, we must either pad or truncate
if (max_num_frames !== null) {
if (max_num_frames > num_frames) { // input is too short, so we pad
if (do_pad) {
d1Max = max_num_frames;
}
} else { // input is too long, so we truncate
d1Max = d1 = max_num_frames;
}
}
// Preallocate arrays to store output.
const fft = new FFT(fft_length);
const inputBuffer = new Float64Array(fft_length);
const outputBuffer = new Float64Array(fft.outputBufferSize);
const magnitudes = new Array(d1);
for (let i = 0; i < d1; ++i) {
// Populate buffer with waveform data
const offset = i * hop_length;
for (let j = 0; j < frame_length; ++j) {
inputBuffer[j] = waveform[offset + j];
}
if (remove_dc_offset) {
let sum = 0;
for (let j = 0; j < frame_length; ++j) {
sum += inputBuffer[j];
}
const mean = sum / frame_length;
for (let j = 0; j < frame_length; ++j) {
inputBuffer[j] -= mean;
}
}
if (preemphasis !== null) {
            // Done in reverse to avoid copies and destructive modification
for (let j = frame_length - 1; j >= 1; --j) {
inputBuffer[j] -= preemphasis * inputBuffer[j - 1];
}
inputBuffer[0] *= 1 - preemphasis;
}
for (let j = 0; j < window.length; ++j) {
inputBuffer[j] *= window[j];
}
fft.realTransform(outputBuffer, inputBuffer);
// compute magnitudes
const row = new Array(num_frequency_bins);
for (let j = 0; j < row.length; ++j) {
const j2 = j << 1;
row[j] = outputBuffer[j2] ** 2 + outputBuffer[j2 + 1] ** 2;
}
magnitudes[i] = row;
}
// TODO what should happen if power is None?
// https://github.com/huggingface/transformers/issues/27772
if (power !== null && power !== 2) {
// slight optimization to not sqrt
const pow = 2 / power; // we use 2 since we already squared
for (let i = 0; i < magnitudes.length; ++i) {
const magnitude = magnitudes[i];
for (let j = 0; j < magnitude.length; ++j) {
magnitude[j] **= pow;
}
}
}
// TODO: What if `mel_filters` is null?
const num_mel_filters = mel_filters.length;
// Only here do we create Float32Array
const mel_spec = new Float32Array(num_mel_filters * d1Max);
    // Perform matrix multiplication:
// mel_spec = mel_filters @ magnitudes.T
// - mel_filters.shape=(80, 201)
// - magnitudes.shape=(3000, 201) => - magnitudes.T.shape=(201, 3000)
// - mel_spec.shape=(80, 3000)
const dims = transpose ? [d1Max, num_mel_filters] : [num_mel_filters, d1Max];
for (let i = 0; i < num_mel_filters; ++i) { // num melfilters (e.g., 80)
const filter = mel_filters[i];
for (let j = 0; j < d1; ++j) { // num frames (e.g., 3000)
const magnitude = magnitudes[j];
let sum = 0;
for (let k = 0; k < num_frequency_bins; ++k) { // num frequency bins (e.g., 201)
sum += filter[k] * magnitude[k];
}
mel_spec[
transpose
? j * num_mel_filters + i
: i * d1 + j
] = Math.max(mel_floor, sum);
}
}
if (power !== null && log_mel !== null) {
const o = Math.min(mel_spec.length, d1 * num_mel_filters);
switch (log_mel) {
case 'log':
for (let i = 0; i < o; ++i) {
mel_spec[i] = Math.log(mel_spec[i]);
}
break;
case 'log10':
for (let i = 0; i < o; ++i) {
mel_spec[i] = Math.log10(mel_spec[i]);
}
break;
case 'dB':
if (power === 1.0) {
// NOTE: operates in-place
amplitude_to_db(mel_spec, reference, min_value, db_range);
} else if (power === 2.0) {
power_to_db(mel_spec, reference, min_value, db_range);
} else {
throw new Error(`Cannot use log_mel option '${log_mel}' with power ${power}`)
}
break;
default:
throw new Error(`log_mel must be one of null, 'log', 'log10' or 'dB'. Got '${log_mel}'`);
}
}
return { data: mel_spec, dims };
}
/**
* Returns an array containing the specified window.
* @param {number} window_length The length of the window in samples.
* @param {string} name The name of the window function.
* @param {Object} options Additional options.
* @param {boolean} [options.periodic=true] Whether the window is periodic or symmetric.
* @param {number} [options.frame_length=null] The length of the analysis frames in samples.
* Provide a value for `frame_length` if the window is smaller than the frame length, so that it will be zero-padded.
* @param {boolean} [options.center=true] Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
* @returns {Float64Array} The window of shape `(window_length,)` or `(frame_length,)`.
*/
export function window_function(window_length, name, {
periodic = true,
frame_length = null,
center = true,
} = {}) {
const length = periodic ? window_length + 1 : window_length;
let window;
switch (name) {
case 'boxcar':
window = new Float64Array(length).fill(1.0);
break;
case 'hann':
case 'hann_window':
window = hanning(length);
break;
default:
throw new Error(`Unknown window type ${name}.`);
}
if (periodic) {
window = window.subarray(0, window_length);
}
if (frame_length === null) {
return window;
}
if (window_length > frame_length) {
throw new Error(`Length of the window (${window_length}) may not be larger than frame_length (${frame_length})`);
}
return window;
}
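// Minimal end-to-end sketch (assumption: Whisper-style parameters — 400-sample frames,
// 160-sample hop, 80 mel bins, 16 kHz mono input; illustrative only, not the exact
// WhisperFeatureExtractor code, which additionally pads/truncates to a fixed number of
// frames via `max_num_frames`):
// const window = window_function(400, 'hann');
// const mel_filters = mel_filter_bank(201, 80, 0.0, 8000.0, 16000, "slaney", "slaney");
// const { data, dims } = spectrogram(waveform, window, 400, 160, {
//     power: 2.0,
//     mel_filters,
//     log_mel: 'log10',
// });
// // `waveform` is assumed to be a mono Float32Array at 16 kHz; dims === [80, num_frames]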

View File

@ -89,8 +89,8 @@ export function interpolate_data(input, [in_channels, in_height, in_width], [out
/**
* Helper method to transpose a `AnyTypedArray` directly
* @param {T} array
* @template {AnyTypedArray} T
* @param {T} array
* @param {number[]} dims
* @param {number[]} axes
* @returns {[T, number[]]} The transposed array and the new shape.
@ -269,48 +269,31 @@ export function max(arr) {
return [max, indexOfMax];
}
/**
* Return the Discrete Fourier Transform sample frequencies.
*
* Code adapted from https://github.com/numpy/numpy/blob/25908cacd19915bf3ddd659c28be28a41bd97a54/numpy/fft/helper.py#L173-L221
* Original Python doc: https://numpy.org/doc/stable/reference/generated/numpy.fft.rfftfreq.html
* @example
* rfftfreq(400, 1 / 16000) // (201) [0, 40, 80, 120, 160, 200, ..., 8000]
* @param {number} n Window length
* @param {number} [d = 1.0] Sample spacing (inverse of the sampling rate). Defaults to 1.
* @throws {TypeError} If n is not an integer.
* @returns {number[]} Array of length `Math.floor(n / 2) + 1;` containing the sample frequencies.
*/
export function rfftfreq(n, d = 1.0) {
if (!Number.isInteger(n)) {
throw new TypeError(`n should be an integer, but ${n} given.`);
}
const val = 1.0 / (n * d);
const len = Math.floor(n / 2) + 1;
const results = new Array(len);
for (let i = 0; i < len; ++i) {
results[i] = i * val;
}
return results;
function isPowerOfTwo(number) {
// Check if the number is greater than 0 and has only one bit set to 1
return (number > 0) && ((number & (number - 1)) === 0);
}
/**
* FFT class provides functionality for performing Fast Fourier Transform on arrays
* Implementation of Radix-4 FFT.
*
* P2FFT class provides functionality for performing Fast Fourier Transform on arrays
* which are a power of two in length.
* Code adapted from https://www.npmjs.com/package/fft.js
*/
export class FFT {
class P2FFT {
/**
* @param {number} size The size of the input array. Must be a power of two and bigger than 1.
* @throws {Error} FFT size must be a power of two and bigger than 1.
* @param {number} size The size of the input array. Must be a power of two larger than 1.
* @throws {Error} FFT size must be a power of two larger than 1.
*/
constructor(size) {
this.size = size | 0; // convert to a 32-bit signed integer
if (this.size <= 1 || (this.size & (this.size - 1)) !== 0)
throw new Error('FFT size must be a power of two and bigger than 1');
if (this.size <= 1 || !isPowerOfTwo(this.size))
throw new Error('FFT size must be a power of two larger than 1');
this._csize = size << 1;
this.table = new Float32Array(this.size * 2);
this.table = new Float64Array(this.size * 2);
for (let i = 0; i < this.table.length; i += 2) {
const angle = Math.PI * i / this.size;
this.table[i] = Math.cos(angle);
@ -341,16 +324,16 @@ export class FFT {
/**
* Create a complex number array with size `2 * size`
*
* @returns {Float32Array} A complex number array with size `2 * size`
* @returns {Float64Array} A complex number array with size `2 * size`
*/
createComplexArray() {
return new Float32Array(this._csize);
return new Float64Array(this._csize);
}
/**
* Converts a complex number representation stored in a Float32Array to an array of real numbers.
* Converts a complex number representation stored in a Float64Array to an array of real numbers.
*
* @param {Float32Array} complex The complex number representation to be converted.
* @param {Float64Array} complex The complex number representation to be converted.
* @param {number[]} [storage] An optional array to store the result in.
* @returns {number[]} An array of real numbers representing the input complex number representation.
*/
@ -363,9 +346,9 @@ export class FFT {
/**
* Convert a real-valued input array to a complex-valued output array.
* @param {Float32Array} input The real-valued input array.
* @param {Float32Array} [storage] Optional buffer to store the output array.
* @returns {Float32Array} The complex-valued output array.
* @param {Float64Array} input The real-valued input array.
* @param {Float64Array} [storage] Optional buffer to store the output array.
* @returns {Float64Array} The complex-valued output array.
*/
toComplexArray(input, storage) {
const res = storage || this.createComplexArray();
@ -378,7 +361,7 @@ export class FFT {
/**
* Completes the spectrum by adding its mirrored negative frequency components.
* @param {Float32Array} spectrum The input spectrum.
* @param {Float64Array} spectrum The input spectrum.
* @returns {void}
*/
completeSpectrum(spectrum) {
@ -393,8 +376,8 @@ export class FFT {
/**
* Performs a Fast Fourier Transform (FFT) on the given input data and stores the result in the output buffer.
*
* @param {Float32Array} out The output buffer to store the result.
* @param {Float32Array} data The input data to transform.
* @param {Float64Array} out The output buffer to store the result.
* @param {Float64Array} data The input data to transform.
*
* @throws {Error} Input and output buffers must be different.
*
@ -412,8 +395,8 @@ export class FFT {
* The input buffer must contain real values only, while the output buffer will contain complex values. The input and
* output buffers must be different.
*
* @param {Float32Array} out The output buffer.
* @param {Float32Array} data The input buffer containing real values.
* @param {Float64Array} out The output buffer.
* @param {Float64Array} data The input buffer containing real values.
*
* @throws {Error} If the input and output buffers are the same.
*/
@ -429,8 +412,8 @@ export class FFT {
* The `out` array must be a different buffer than the `data` array. The `out` array will contain the
* result of the transformation. The `data` array will not be modified.
*
* @param {Float32Array} out The output buffer for the transformed data.
* @param {Float32Array} data The input data to transform.
* @param {Float64Array} out The output buffer for the transformed data.
* @param {Float64Array} data The input data to transform.
* @throws {Error} If `out` and `data` refer to the same buffer.
* @returns {void}
*/
@ -446,8 +429,8 @@ export class FFT {
/**
* Performs a radix-4 implementation of a discrete Fourier transform on a given set of data.
*
* @param {Float32Array} out The output buffer for the transformed data.
* @param {Float32Array} data The input buffer of data to be transformed.
* @param {Float64Array} out The output buffer for the transformed data.
* @param {Float64Array} data The input buffer of data to be transformed.
* @param {number} inv A scaling factor to apply to the transform.
* @returns {void}
*/
@ -463,7 +446,7 @@ export class FFT {
let outOff;
let t;
let bitrev = this._bitrev;
const bitrev = this._bitrev;
if (len === 4) {
for (outOff = 0, t = 0; outOff < size; outOff += len, ++t) {
const off = bitrev[t];
@ -480,12 +463,12 @@ export class FFT {
// Loop through steps in decreasing order
for (step >>= 2; step >= 2; step >>= 2) {
len = (size / step) << 1;
let quarterLen = len >>> 2;
const quarterLen = len >>> 2;
// Loop through offsets in the data
for (outOff = 0; outOff < size; outOff += len) {
// Full case
let limit = outOff + quarterLen;
const limit = outOff + quarterLen - 1;
for (let i = outOff, k = 0; i < limit; i += 2, k += step) {
const A = i;
const B = A + quarterLen;
@ -544,8 +527,8 @@ export class FFT {
/**
* Performs a radix-2 implementation of a discrete Fourier transform on a given set of data.
*
* @param {Float32Array} data The input buffer of data to be transformed.
* @param {Float32Array} out The output buffer for the transformed data.
* @param {Float64Array} data The input buffer of data to be transformed.
* @param {Float64Array} out The output buffer for the transformed data.
* @param {number} outOff The offset at which to write the output data.
* @param {number} off The offset at which to begin reading the input data.
* @param {number} step The step size for indexing the input data.
@ -569,8 +552,8 @@ export class FFT {
/**
* Performs radix-4 transformation on input data of length 8
*
* @param {Float32Array} data Input data array of length 8
* @param {Float32Array} out Output data array of length 8
* @param {Float64Array} data Input data array of length 8
* @param {Float64Array} out Output data array of length 8
* @param {number} outOff Index of output array to start writing from
* @param {number} off Index of input array to start reading from
* @param {number} step Step size between elements in input array
@ -617,8 +600,8 @@ export class FFT {
/**
* Real input radix-4 implementation
* @param {Float32Array} out Output array for the transformed data
* @param {Float32Array} data Input array of real data to be transformed
* @param {Float64Array} out Output array for the transformed data
* @param {Float64Array} data Input array of real data to be transformed
* @param {number} inv The scale factor used to normalize the inverse transform
*/
_realTransform4(out, data, inv) {
@ -630,9 +613,9 @@ export class FFT {
let step = 1 << width;
let len = (size / step) << 1;
var outOff;
var t;
var bitrev = this._bitrev;
let outOff;
let t;
const bitrev = this._bitrev;
if (len === 4) {
for (outOff = 0, t = 0; outOff < size; outOff += len, ++t) {
const off = bitrev[t];
@ -646,17 +629,18 @@ export class FFT {
}
}
// TODO: Optimize once https://github.com/indutny/fft.js/issues/25 is fixed
// Loop through steps in decreasing order
for (step >>= 2; step >= 2; step >>= 2) {
len = (size / step) << 1;
const halfLen = len >>> 1;
const quarterLen = halfLen >>> 1;
const hquarterLen = quarterLen >>> 1;
const quarterLen = len >>> 2;
// Loop through offsets in the data
for (outOff = 0; outOff < size; outOff += len) {
for (let i = 0, k = 0; i <= hquarterLen; i += 2, k += step) {
const A = outOff + i;
// Full case
const limit = outOff + quarterLen - 1;
for (let i = outOff, k = 0; i < limit; i += 2, k += step) {
const A = i;
const B = A + quarterLen;
const C = B + quarterLen;
const D = C + quarterLen;
@ -701,25 +685,10 @@ export class FFT {
out[A + 1] = T0i + T2i;
out[B] = T1r + T3i;
out[B + 1] = T1i - T3r;
// Output final middle point
if (i === 0) {
out[C] = T0r - T2r;
out[C + 1] = T0i - T2i;
continue;
}
// Do not overwrite ourselves
if (i === hquarterLen)
continue;
const SA = outOff + quarterLen - i;
const SB = outOff + halfLen - i;
out[SA] = T1r + -inv * T3i;
out[SA + 1] = -T1i - inv * T3r;
out[SB] = T0r + -inv * T2r;
out[SB + 1] = -T0i + inv * T2i;
out[C] = T0r - T2r;
out[C + 1] = T0i - T2i;
out[D] = T1r - T3i;
out[D + 1] = T1i + T3r;
}
}
}
@ -728,8 +697,8 @@ export class FFT {
/**
* Performs a single real input radix-2 transformation on the provided data
*
* @param {Float32Array} data The input data array
* @param {Float32Array} out The output data array
* @param {Float64Array} data The input data array
* @param {Float64Array} out The output data array
* @param {number} outOff The output offset
* @param {number} off The input offset
* @param {number} step The step
@ -753,8 +722,8 @@ export class FFT {
* Computes a single real-valued transform using radix-4 algorithm.
* This method is only called for len=8.
*
* @param {Float32Array} data The input data array.
* @param {Float32Array} out The output data array.
* @param {Float64Array} data The input data array.
* @param {Float64Array} out The output data array.
* @param {number} outOff The offset into the output array.
* @param {number} off The offset into the input array.
* @param {number} step The step size for the input array.
@ -790,6 +759,148 @@ export class FFT {
}
}
/**
* NP2FFT class provides functionality for performing Fast Fourier Transform on arrays
* which are not a power of two in length. In such cases, the chirp-z transform is used.
*
* For more information, see: https://math.stackexchange.com/questions/77118/non-power-of-2-ffts/77156#77156
*/
class NP2FFT {
/**
* Constructs a new NP2FFT object.
* @param {number} fft_length The length of the FFT
*/
constructor(fft_length) {
// Helper variables
const a = 2 * (fft_length - 1);
const b = 2 * (2 * fft_length - 1);
const nextP2 = 2 ** (Math.ceil(Math.log2(b)))
this.bufferSize = nextP2;
this._a = a;
// Define buffers
// Compute chirp for transform
const chirp = new Float64Array(b);
const ichirp = new Float64Array(nextP2);
this._chirpBuffer = new Float64Array(nextP2);
this._buffer1 = new Float64Array(nextP2);
this._buffer2 = new Float64Array(nextP2);
this._outBuffer1 = new Float64Array(nextP2);
this._outBuffer2 = new Float64Array(nextP2);
// Compute complex exponentiation
const theta = -2 * Math.PI / fft_length;
const baseR = Math.cos(theta);
const baseI = Math.sin(theta);
// Precompute helper for chirp-z transform
for (let i = 0; i < b >> 1; ++i) {
// Compute complex power:
const e = (i + 1 - fft_length) ** 2 / 2.0;
// Compute the modulus and argument of the result
const result_mod = Math.sqrt(baseR ** 2 + baseI ** 2) ** e;
const result_arg = e * Math.atan2(baseI, baseR);
// Convert the result back to rectangular form
// and assign to chirp and ichirp
const i2 = 2 * i;
chirp[i2] = result_mod * Math.cos(result_arg);
chirp[i2 + 1] = result_mod * Math.sin(result_arg);
// conjugate
ichirp[i2] = chirp[i2];
ichirp[i2 + 1] = - chirp[i2 + 1];
}
this._slicedChirpBuffer = chirp.subarray(a, b);
// create object to perform Fast Fourier Transforms
// with `nextP2` complex numbers
this._f = new P2FFT(nextP2 >> 1);
this._f.transform(this._chirpBuffer, ichirp);
}
_transform(output, input, real) {
const ib1 = this._buffer1;
const ib2 = this._buffer2;
const ob2 = this._outBuffer1;
const ob3 = this._outBuffer2;
const cb = this._chirpBuffer;
const sb = this._slicedChirpBuffer;
const a = this._a;
if (real) {
// Real multiplication
for (let j = 0; j < sb.length; j += 2) {
const j2 = j + 1
const j3 = j >> 1;
const a_real = input[j3];
ib1[j] = a_real * sb[j];
ib1[j2] = a_real * sb[j2];
}
} else {
// Complex multiplication
for (let j = 0; j < sb.length; j += 2) {
const j2 = j + 1
ib1[j] = input[j] * sb[j] - input[j2] * sb[j2];
ib1[j2] = input[j] * sb[j2] + input[j2] * sb[j];
}
}
this._f.transform(ob2, ib1);
for (let j = 0; j < cb.length; j += 2) {
const j2 = j + 1;
ib2[j] = ob2[j] * cb[j] - ob2[j2] * cb[j2];
ib2[j2] = ob2[j] * cb[j2] + ob2[j2] * cb[j];
}
this._f.inverseTransform(ob3, ib2);
for (let j = 0; j < ob3.length; j += 2) {
const a_real = ob3[j + a];
const a_imag = ob3[j + a + 1];
const b_real = sb[j];
const b_imag = sb[j + 1];
output[j] = a_real * b_real - a_imag * b_imag;
output[j + 1] = a_real * b_imag + a_imag * b_real;
}
}
transform(output, input) {
this._transform(output, input, false);
}
realTransform(output, input) {
this._transform(output, input, true);
}
}
export class FFT {
constructor(fft_length) {
this.fft_length = fft_length;
this.isPowerOfTwo = isPowerOfTwo(fft_length);
if (this.isPowerOfTwo) {
this.fft = new P2FFT(fft_length);
this.outputBufferSize = 2 * fft_length;
} else {
this.fft = new NP2FFT(fft_length);
this.outputBufferSize = this.fft.bufferSize;
}
}
realTransform(out, input) {
this.fft.realTransform(out, input);
}
transform(out, input) {
this.fft.transform(out, input);
}
}
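// Illustrative usage (mirrors the `fft` helper used in the unit tests below): the wrapper
// selects P2FFT or NP2FFT depending on whether the length is a power of two, and
// `outputBufferSize` gives the required size of the interleaved real/imaginary output.
// const fft = new FFT(3);                             // 3 is not a power of two -> NP2FFT
// const out = new Float64Array(fft.outputBufferSize);
// fft.realTransform(out, Float64Array.of(1, 2, 3));
// out.slice(0, 6);  // ≈ [6, 0, -1.5, 0.866, -1.5, -0.866], matching np.fft.fft([1, 2, 3])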
/**
* Performs median filter on the provided data. Padding is done by mirroring the data.
* @param {AnyTypedArray} data The input array

View File

@ -5,6 +5,7 @@ import json
import os
from transformers import AutoTokenizer, AutoConfig
import numpy as np
from scripts.supported_models import SUPPORTED_MODELS
@ -198,6 +199,37 @@ def generate_config_tests():
return results
ARRAY_SIZES = sorted(set([2 ** i for i in range(1, 10)]) \
| set([3 ** i for i in range(1, 8)]) \
| set([5 ** i for i in range(1, 6)]) \
| set([7 ** i for i in range(1, 4)]))
def serialize_complex_array(arr):
return [float(x) for y in arr for x in [y.real, y.imag]]
def serialize_real_array(arr):
return arr.tolist()
def generate_fft_tests():
np.random.seed(0)
tests = {}
for complex in [False, True]:
serialize_fn = serialize_complex_array if complex else serialize_real_array
for size in ARRAY_SIZES:
arr = np.random.randn(size).astype(np.complex64 if complex else np.float64)
if complex:
arr += np.random.randn(size) * 1j
tests[f"fft_{size}_{'complex' if complex else 'real'}"] = {
"complex": complex,
"input": serialize_fn(arr),
"output": serialize_complex_array(np.fft.fft(arr)),
}
return tests
def main():
# TODO add option to cache generated data + force build tests
@ -213,6 +245,9 @@ def main():
with open(os.path.join(data_dir, "config_tests.json"), "w", encoding="utf-8") as fp:
json.dump(config_tests, fp)
fft_tests = generate_fft_tests()
with open(os.path.join(data_dir, "fft_tests.json"), "w", encoding="utf-8") as fp:
json.dump(fft_tests, fp)
if __name__ == "__main__":
main()

View File

@ -1,7 +1,29 @@
import { compare } from './test_utils.js';
import { medianFilter } from '../src/utils/maths.js';
import { getFile } from '../src/utils/hub.js';
import { FFT, medianFilter } from '../src/utils/maths.js';
const fft = (arr, complex = false) => {
let output;
let fft;
if (complex) {
fft = new FFT(arr.length / 2);
output = new Float64Array(fft.outputBufferSize);
fft.transform(output, arr);
} else {
fft = new FFT(arr.length);
output = new Float64Array(fft.outputBufferSize);
fft.realTransform(output, arr);
}
if (!fft.isPowerOfTwo) {
output = output.slice(0, complex ? arr.length : 2 * arr.length);
}
return output;
}
const fftTestsData = await (await getFile('./tests/data/fft_tests.json')).json()
describe('Mathematical operations', () => {
@ -11,8 +33,8 @@ describe('Mathematical operations', () => {
it('should compute median filter', async () => {
const t1 = new Float32Array([5, 12, 2, 6, 3, 10, 9, 1, 4, 8, 11, 7]);
const window = 3;
const target = new Float32Array([12, 5, 6, 3, 6, 9, 9, 4, 4, 8, 8, 11]);
const target = new Float32Array([12, 5, 6, 3, 6, 9, 9, 4, 4, 8, 8, 11]);
const output = medianFilter(t1, window);
compare(output, target, 1e-3);
@ -22,4 +44,83 @@ describe('Mathematical operations', () => {
// TODO add tests for errors
});
describe('FFT', () => {
// Should match output of numpy fft
it('should compute real FFT for power of two', () => {
{ // size = 4
// np.fft.fft([1,2,3,4]) == array([10.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
const input = new Float32Array([1, 2, 3, 4]);
const target = new Float32Array([10, 0, -2, 2, -2, 0, -2, -2]);
const output = fft(input);
compare(output, target, 1e-3);
}
{ // size = 16
// np.fft.fft([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
// == array([136. +0.j , -8.+40.21871594j, -8.+19.3137085j ,
// -8.+11.9728461j , -8. +8.j , -8. +5.3454291j ,
// -8. +3.3137085j , -8. +1.59129894j, -8. +0.j ,
// -8. -1.59129894j, -8. -3.3137085j , -8. -5.3454291j ,
// -8. -8.j , -8.-11.9728461j , -8.-19.3137085j ,
// -8.-40.21871594j])
const input = new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
const target = new Float32Array([136.0, 0.0, -8.0, 40.218715937006785, -8.0, 19.31370849898476, -8.0, 11.972846101323912, -8.0, 8.0, -8.0, 5.345429103354389, -8.0, 3.313708498984761, -8.0, 1.5912989390372658, -8.0, 0.0, -8.0, -1.5912989390372658, -8.0, -3.313708498984761, -8.0, -5.345429103354389, -8.0, -8.0, -8.0, -11.972846101323912, -8.0, -19.31370849898476, -8.0, -40.218715937006785]);
const output = fft(input);
compare(output, target, 1e-3);
}
});
it('should compute real FFT for non-power of two', () => {
{ // size = 3
// np.fft.fft([1,2,3]) == array([ 6. +0.j, -1.5+0.8660254j, -1.5-0.8660254j])
const input = new Float32Array([1, 2, 3]);
const target = new Float32Array([6, 0, -1.5, 0.8660254, -1.5, -0.8660254]);
const output = fft(input);
compare(output, target, 1e-3);
}
});
it('should compute complex FFT for non-power of two', () => {
{ // size = 3
// np.fft.fft([1+3j,2-2j,3+1j]) == array([ 6. +2.j, -4.09807621+4.3660254j, 1.09807621+2.6339746j])
const input = new Float32Array([1, 3, 2, -2, 3, 1]);
const target = new Float32Array([6, 2, -4.09807621, 4.3660254, 1.09807621, 2.6339746]);
const output = fft(input, true);
compare(output, target, 1e-3);
}
});
it('should compute complex FFT for power of two', () => {
{ // size = 4
// np.fft.fft([1+4j, 2-3j,3+2j, 4-1j]) == array([10. +2.j, -4. +4.j, -2.+10.j, 0. +0.j])
const input = new Float32Array([1, 4, 2, -3, 3, 2, 4, -1]);
const target = new Float32Array([10, 2, -4, 4, -2, 10, 0, 0]);
const output = fft(input, true);
compare(output, target, 1e-3);
}
});
})
describe('FFT (dynamic)', () => {
// Should match output of numpy fft
for (const [name, test] of Object.entries(fftTestsData)) {
// if (test.input.length > 5) continue;
it(name, () => {
const output = fft(test.input, test.complex);
if (output.map((v, i) => Math.abs(v - test.output[i])).some(v => v > 1e-4)) {
console.log('input', test.input)
console.log('output', output)
console.log('target', test.output)
}
compare(output, test.output, 1e-4);
});
}
});
});

View File

@ -337,4 +337,103 @@ describe('Processors', () => {
}
}, MAX_TEST_EXECUTION_TIME);
});
describe('Audio processors', () => {
const audioPromise = new Promise(async (resolve) => {
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.npy';
const buffer = await (await fetch(url)).arrayBuffer();
const audio = Float32Array.from(new Float64Array(buffer));
resolve(audio);
});
it('WhisperFeatureExtractor', async () => {
const audio = await audioPromise;
const processor = await AutoProcessor.from_pretrained('Xenova/whisper-tiny.en');
const { input_features } = await processor(audio);
compare(input_features.dims, [1, 80, 3000]);
expect(avg(input_features.data)).toBeCloseTo(-0.2813588131551941);
expect(input_features.data[0]).toBeCloseTo(0.33168578147888184);
expect(input_features.data[1]).toBeCloseTo(0.30986475944519043);
expect(input_features.data[81]).toBeCloseTo(0.10727232694625854);
expect(input_features.data[3001]).toBeCloseTo(0.2555035352706909);
}, MAX_TEST_EXECUTION_TIME);
it('ASTFeatureExtractor', async () => {
const audio = await audioPromise;
const processor = await AutoProcessor.from_pretrained('Xenova/ast-finetuned-audioset-10-10-0.4593');
{ // truncation
const { input_values } = await processor(audio);
compare(input_values.dims, [1, 1024, 128]);
expect(avg(input_values.data)).toBeCloseTo(-0.04054912979309085);
expect(input_values.data[0]).toBeCloseTo(-0.5662586092948914);
expect(input_values.data[1]).toBeCloseTo(-1.0300861597061157);
expect(input_values.data[129]).toBeCloseTo(-1.084834098815918);
expect(input_values.data[1025]).toBeCloseTo(-1.1204065084457397);
}
{ // padding
const { input_values } = await processor(audio.slice(0, 1000));
compare(input_values.dims, [1, 1024, 128]); // [1, 4, 128] -> (padded to) -> [1, 1024, 128]
expect(avg(input_values.data)).toBeCloseTo(0.4647964835166931);
expect(input_values.data[0]).toBeCloseTo(-0.5662586092948914);
expect(input_values.data[1]).toBeCloseTo(-1.0300861597061157);
expect(input_values.data[129]).toBeCloseTo(-1.084834098815918);
// padded values
expect(input_values.data[1025]).toBeCloseTo(0.46703237295150757);
expect(input_values.data[2049]).toBeCloseTo(0.46703237295150757);
expect(input_values.data[10000]).toBeCloseTo(0.46703237295150757);
}
}, MAX_TEST_EXECUTION_TIME);
it('ClapFeatureExtractor', async () => {
const audio = await audioPromise;
const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused');
{ // truncation
// Since truncation uses a random strategy, we override
// Math.random to ensure that the test is deterministic
const originalRandom = Math.random;
Math.random = () => 0.5;
let long_audio = new Float32Array(500000);
long_audio.set(audio);
long_audio.set(audio, long_audio.length - audio.length);
const { input_features } = await processor(long_audio);
compare(input_features.dims, [1, 1, 1001, 64]);
expect(avg(input_features.data)).toBeCloseTo(-37.94569396972656);
expect(input_features.data[0]).toBeCloseTo(-53.32647705078125);
expect(input_features.data[1]).toBeCloseTo(-47.76755142211914);
expect(input_features.data[65]).toBeCloseTo(-36.32261276245117);
expect(input_features.data[1002]).toBeCloseTo(-28.0314884185791);
expect(input_features.data[10000]).toBeCloseTo(-21.905902862548828);
expect(input_features.data[60000]).toBeCloseTo(-14.877863883972168);
expect(input_features.data[64062]).toBeCloseTo(-37.9784049987793);
expect(input_features.data[64063]).toBeCloseTo(-37.73963928222656);
// Reset Math.random
Math.random = originalRandom;
}
{ // padding
const { input_features } = await processor(audio);
compare(input_features.dims, [1, 1, 1001, 64]);
expect(avg(input_features.data)).toBeCloseTo(-34.99049377441406);
expect(input_features.data[0]).toBeCloseTo(-21.32573890686035);
expect(input_features.data[1]).toBeCloseTo(-26.168411254882812);
expect(input_features.data[65]).toBeCloseTo(-29.716018676757812);
expect(input_features.data[1002]).toBeCloseTo(-32.16273498535156);
expect(input_features.data[10000]).toBeCloseTo(-19.9283390045166);
// padded values
expect(input_features.data[60000]).toBeCloseTo(-100.0);
expect(input_features.data[64062]).toBeCloseTo(-100.0);
expect(input_features.data[64063]).toBeCloseTo(-100.0);
}
}, MAX_TEST_EXECUTION_TIME);
});
});

View File

@ -1,8 +1,8 @@
import { AutoProcessor } from '../src/transformers.js';
import { getMelFilters } from '../src/utils/audio.js';
import { mel_filter_bank } from '../src/utils/audio.js';
import { MAX_TEST_EXECUTION_TIME, m } from './init.js';
import { MAX_TEST_EXECUTION_TIME } from './init.js';
describe('Utilities', () => {
@ -11,28 +11,32 @@ describe('Utilities', () => {
it('should calculate MEL filters', async () => {
// NOTE: Uses official HF implementation as reference:
let processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en');
let config = processor.feature_extractor.config;
let maxdiff = 0;
const processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en');
const config = processor.feature_extractor.config;
// True MEL filters
let original_mel_filters = config.mel_filters;
const original_mel_filters = config.mel_filters;
// Calculated MEL filters
let calculated_mel_filters = getMelFilters(config.sampling_rate, config.n_fft, config.feature_size);
const calculated_mel_filters = mel_filter_bank(
Math.floor(1 + config.n_fft / 2), // num_frequency_bins
config.feature_size, // num_mel_filters
0.0, // min_frequency
8000.0, // max_frequency
config.sampling_rate, // sampling_rate
"slaney", // norm
"slaney", // mel_scale
);
for (let i = 0; i < original_mel_filters.length; ++i) {
for (let j = 0; j < original_mel_filters[i].length; ++j) {
const expected = original_mel_filters[i][j];
const calculated = calculated_mel_filters[i][j];
const diff = Math.abs(expected - calculated);
maxdiff = Math.max(maxdiff, diff);
}
}
const original = original_mel_filters.flat();
const calculated = calculated_mel_filters.flat();
// Compute max difference
const maxdiff = original.reduce((maxdiff, _, i) => {
const diff = Math.abs(original[i] - calculated[i]);
return Math.max(maxdiff, diff);
}, -Infinity);
expect(maxdiff).toBeGreaterThanOrEqual(0);
expect(maxdiff).toBeLessThan(1e-6);
}, MAX_TEST_EXECUTION_TIME);