transformers.js/scripts/convert.py

import json
import os
import shutil
from dataclasses import dataclass, field
from typing import Optional, Set

from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser
)

import onnx
from optimum.exporters.onnx import main_export, export_models
from optimum.exporters.tasks import TasksManager
from onnxruntime.quantization import (
    quantize_dynamic,
    QuantType
)

DEFAULT_QUANTIZE_PARAMS = {
    'per_channel': True,
    'reduce_range': True,
}

MODEL_SPECIFIC_QUANTIZE_PARAMS = {
    # Decoder-only models
    'codegen': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gpt2': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gpt_bigcode': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gptj': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gpt-neo': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gpt-neox': {
        'per_channel': False,
        'reduce_range': False,
    },
    'mpt': {
        'per_channel': False,
        'reduce_range': False,
    },
    'bloom': {
        'per_channel': False,
        'reduce_range': False,
    },
    'llama': {
        'per_channel': False,
        'reduce_range': False,
    },
    'opt': {
        'per_channel': False,
        'reduce_range': False,
    },
    'mistral': {
        'per_channel': False,
        'reduce_range': False,
    },
    'falcon': {
        'per_channel': False,
        'reduce_range': False,
    },
    'phi': {
        'per_channel': False,
        'reduce_range': False,
    },
    'qwen2': {
        'per_channel': False,
        'reduce_range': False,
    },
    'stablelm': {
        'per_channel': False,
        'reduce_range': False,
    },
    'starcoder2': {
        'per_channel': False,
        'reduce_range': False,
    },

    # Encoder-decoder models
    'whisper': {
        'per_channel': False,
        'reduce_range': False,
    },
    'vision-encoder-decoder': {
        'per_channel': False,
        'reduce_range': False,
    },

    # Encoder-only models
    'owlv2': {
        'per_channel': False,
        'reduce_range': False,
    },
    'wavlm': {
        'per_channel': False,
        'reduce_range': False,
    },
    'wav2vec2': {
        'per_channel': False,
        'reduce_range': False,
    },
    'unispeech': {
        'per_channel': False,
        'reduce_range': False,
    },
    'unispeech-sat': {
        'per_channel': False,
        'reduce_range': False,
    },
}

MODELS_WITHOUT_TOKENIZERS = [
    'wav2vec2',
    'wav2vec2-bert',
    'wavlm',
    'hubert',
    'unispeech',
    'unispeech-sat',
]


@dataclass
class ConversionArguments:
    """
    Arguments used for converting HuggingFace models to onnx.
    """

    model_id: str = field(
        metadata={
            "help": "Model identifier"
        }
    )
    tokenizer_id: str = field(
        default=None,
        metadata={
            "help": "Tokenizer identifier (if different to `model_id`)"
        }
    )
    quantize: bool = field(
        default=False,
        metadata={
            "help": "Whether to quantize the model."
        }
    )
    output_parent_dir: str = field(
        default='./models/',
        metadata={
            "help": "Path where the converted model will be saved to."
        }
    )

    task: Optional[str] = field(
        default='auto',
        metadata={
            "help": (
                "The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
                f" {str(TasksManager.get_all_tasks())}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
            )
        }
    )

    opset: int = field(
        default=None,
        metadata={
            "help": (
                "If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used."
            )
        }
    )

    device: str = field(
        default='cpu',
        metadata={
            "help": 'The device to use to do the export.'
        }
    )
    skip_validation: bool = field(
        default=False,
        metadata={
            "help": "Whether to skip validation of the converted model"
        }
    )

    per_channel: bool = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights per channel"
        }
    )
    reduce_range: bool = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights with 7 bits. It may improve accuracy for some models running on non-VNNI machines, especially in per-channel mode"
        }
    )

    output_attentions: bool = field(
        default=False,
        metadata={
            "help": "Whether to output attentions from the model. NOTE: This is only supported for whisper models right now."
        }
    )

    split_modalities: bool = field(
        default=False,
        metadata={
            "help": "Whether to split multimodal models. NOTE: This is only supported for CLIP models right now."
        }
    )

    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": "Allows the use of custom modeling code hosted in the model repository. This option should only be set for repositories "
            "you trust and whose code you have read, as it will execute arbitrary code from the model repository on your local machine."
        }
    )

    custom_onnx_configs: str = field(
        default=None,
        metadata={
            "help": "Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users "
            "that desire a finer-grained control on the export."
        }
    )

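
# Illustrative sketch only: this script imports `HfArgumentParser`, which turns
# dataclass fields into command-line flags. The stdlib-only stand-in below shows
# the general idea with `argparse` and `dataclasses.fields`; it is not how
# `HfArgumentParser` is implemented. `ExampleArguments`, `example_build_parser`
# and the model id 'some-org/some-model' are hypothetical names.

```python
import argparse
from dataclasses import MISSING, dataclass, field, fields
from typing import Optional


@dataclass
class ExampleArguments:
    """Trimmed, hypothetical subset of the conversion arguments."""
    model_id: str = field(metadata={"help": "Model identifier"})
    output_parent_dir: str = field(
        default='./models/',
        metadata={"help": "Path where the converted model will be saved to."},
    )
    opset: Optional[int] = field(
        default=None,
        metadata={"help": "ONNX opset version to export the model with."},
    )


def example_build_parser() -> argparse.ArgumentParser:
    """Create one `--flag` per dataclass field, reusing its help text."""
    parser = argparse.ArgumentParser()
    for f in fields(ExampleArguments):
        # A field without a default must be supplied on the command line.
        required = f.default is MISSING
        # Simplification for this sketch: only `opset` is parsed as an int.
        arg_type = int if f.name == 'opset' else str
        parser.add_argument(
            f'--{f.name}',
            type=arg_type,
            required=required,
            default=None if required else f.default,
            help=f.metadata.get('help'),
        )
    return parser


example_namespace = example_build_parser().parse_args(
    ['--model_id', 'some-org/some-model']
)
example_args = ExampleArguments(**vars(example_namespace))
print(example_args)
```

# Keeping the help text in `metadata` means the dataclass is the single source
# of truth for both the argument names and their documentation.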

def get_operators(model: onnx.ModelProto) -> Set[str]:
    operators = set()

    def traverse_graph(graph):
        for node in graph.node:
            operators.add(node.op_type)
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    subgraph = attr.g
                    traverse_graph(subgraph)

    traverse_graph(model.graph)
    return operators
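
# Illustrative sketch only: the same recursive traversal as `get_operators`,
# run on plain stand-in objects so it works without `onnx` installed. The
# `EXAMPLE_GRAPH_ATTR` constant stands in for `onnx.AttributeProto.GRAPH`, and
# all `example_*` names are hypothetical.

```python
from types import SimpleNamespace
from typing import Set

# Stand-in for onnx.AttributeProto.GRAPH (assumption for this sketch).
EXAMPLE_GRAPH_ATTR = 'GRAPH'


def example_get_operators(graph) -> Set[str]:
    """Collect op types from a graph, descending into nested subgraphs."""
    operators = set()

    def traverse_graph(g):
        for node in g.node:
            operators.add(node.op_type)
            for attr in node.attribute:
                # Control-flow nodes carry their bodies as graph attributes.
                if attr.type == EXAMPLE_GRAPH_ATTR:
                    traverse_graph(attr.g)

    traverse_graph(graph)
    return operators


# An `If` node whose branch subgraph contains an `Add` node.
example_branch = SimpleNamespace(
    node=[SimpleNamespace(op_type='Add', attribute=[])]
)
example_if_node = SimpleNamespace(
    op_type='If',
    attribute=[SimpleNamespace(type=EXAMPLE_GRAPH_ATTR, g=example_branch)],
)
example_top = SimpleNamespace(
    node=[SimpleNamespace(op_type='MatMul', attribute=[]), example_if_node]
)

print(sorted(example_get_operators(example_top)))  # ['Add', 'If', 'MatMul']
```

# Without the recursion, `Add` would be missed because it only appears inside
# the subgraph attached to the `If` node.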
def quantize(model_names_or_paths, **quantize_kwargs):
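    """
    Quantize the weights of each model from float32 to int8 to allow very efficient inference on modern CPUs.

    Uses unsigned ints for activation values and signed ints for weights because, per
    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection,
    this is faster on most CPU architectures. Models that contain a Conv layer
    use unsigned weights instead (see the note in the loop below).

    Args:
        model_names_or_paths: Paths to the exported ONNX models to quantize.
        **quantize_kwargs: Extra keyword arguments forwarded to `quantize_dynamic`.

    Returns: None. Each quantized model is written next to its source file as
        `<name>_quantized.onnx`, and the settings used are saved to `quantize_config.json`.
    """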
    quantize_config = dict(
        **quantize_kwargs,
        per_model_config={}
    )

    for model in tqdm(model_names_or_paths, desc='Quantizing'):
        directory_path = os.path.dirname(model)
        file_name_without_extension = os.path.splitext(
            os.path.basename(model))[0]
        # NOTE:
        # As of 2023/04/20, the latest version of onnxruntime-web is 1.14.0, and it does not support INT8 weights for Conv layers.
        # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web.
        #
        # As per the docs, the signed weight type (QInt8) is faster on most CPUs, so we use that unless the model contains a Conv layer.
        # For more information, see:
        #  - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
        #  - https://github.com/microsoft/onnxruntime/issues/2339
        loaded_model = onnx.load_model(model)
        op_types = get_operators(loaded_model)
        weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
        quantize_dynamic(
            model_input=model,
            model_output=os.path.join(
                directory_path, f'{file_name_without_extension}_quantized.onnx'),
            weight_type=weight_type,
            optimize_model=False,

            # TODO allow user to specify these
            # op_types_to_quantize=['MatMul', 'Add', 'Conv'],
            extra_options=dict(
                EnableSubgraph=True
            ),
            **quantize_kwargs
        )
        quantize_config['per_model_config'][file_name_without_extension] = dict(
            op_types=list(op_types),
            weight_type=str(weight_type),
        )

    # Save quantization config
    with open(os.path.join(directory_path, 'quantize_config.json'), 'w') as fp:
        json.dump(quantize_config, fp, indent=4)
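

# Illustrative sketch (not called by the conversion pipeline): the two helpers
# below restate, in isolation, how `quantize` above picks a weight type and
# names its output file. The helper names are hypothetical. The weight type is
# returned as the name of the `QuantType` member (a plain string), which is an
# assumption made so that the sketch runs without onnxruntime installed.
import os  # already imported at the top of this script; repeated so the sketch stands alone


def _sketch_weight_type_name(op_types):
    # Unsigned weights when a Conv layer is present, signed weights otherwise.
    return 'QUInt8' if 'Conv' in op_types else 'QInt8'


def _sketch_quantized_path(model_path):
    # The quantized model is written next to its source file, with a
    # `_quantized` suffix inserted before the `.onnx` extension.
    directory_path = os.path.dirname(model_path)
    file_name_without_extension = os.path.splitext(
        os.path.basename(model_path))[0]
    return os.path.join(
        directory_path, f'{file_name_without_extension}_quantized.onnx')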
def main():
    parser = HfArgumentParser(
        (ConversionArguments, )
    )
    conv_args, = parser.parse_args_into_dataclasses()

    model_id = conv_args.model_id
    tokenizer_id = conv_args.tokenizer_id or model_id
    output_model_folder = os.path.join(conv_args.output_parent_dir, model_id)
    # Create output folder
    os.makedirs(output_model_folder, exist_ok=True)

    from_pretrained_kwargs = dict(
        trust_remote_code=conv_args.trust_remote_code,
    )
    # Saving the model config
    config = AutoConfig.from_pretrained(model_id, **from_pretrained_kwargs)

    custom_kwargs = {}
    if conv_args.custom_onnx_configs is not None:
        if conv_args.task == 'auto':
            raise Exception('`--task` must be set when exporting with `--custom_onnx_configs`')
        custom_onnx_configs = json.loads(conv_args.custom_onnx_configs)

        for key in custom_onnx_configs:
            onnx_configs = TasksManager._SUPPORTED_MODEL_TYPE[custom_onnx_configs[key]]['onnx']
            mapping = onnx_configs[conv_args.task]
            custom_onnx_configs[key] = mapping.func(config, **mapping.keywords)

        custom_kwargs['custom_onnx_configs'] = custom_onnx_configs
    tokenizer = None
    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, **from_pretrained_kwargs)
        # To avoid inserting all chat templates into tokenizers.js, we save the chat template
        # to the tokenizer_config.json file, and load it when the tokenizer is loaded.
        if getattr(tokenizer, 'chat_template', None) is None and \
                getattr(tokenizer, 'use_default_system_prompt', False):
            # No chat template specified, and we use the default
            setattr(tokenizer, 'chat_template', tokenizer.default_chat_template)

    except KeyError:
        pass  # No Tokenizer
    except Exception as e:
        if config.model_type not in MODELS_WITHOUT_TOKENIZERS:
            raise e

    core_export_kwargs = dict(
        opset=conv_args.opset,
        device=conv_args.device,
        trust_remote_code=conv_args.trust_remote_code,
        **custom_kwargs,
    )
    export_kwargs = dict(
        model_name_or_path=model_id,
        output=output_model_folder,
        task=conv_args.task,
        do_validation=not conv_args.skip_validation,
        library_name='transformers',
        **core_export_kwargs,
    )
    # Handle special cases
    if config.model_type == 'marian':
        from .extra.marian import generate_tokenizer_json
        tokenizer_json = generate_tokenizer_json(model_id, tokenizer)

        with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
            json.dump(tokenizer_json, fp, indent=4)
    elif config.model_type == 'esm':
        from .extra.esm import generate_fast_tokenizer
        fast_tokenizer = generate_fast_tokenizer(tokenizer)
        fast_tokenizer.save(os.path.join(output_model_folder, 'tokenizer.json'))
    elif config.model_type == 'whisper':
        if conv_args.output_attentions:
            from .extra.whisper import get_main_export_kwargs

            export_kwargs.update(
                **get_main_export_kwargs(config, "automatic-speech-recognition")
            )
    elif config.model_type in ('wav2vec2', 'wav2vec2-bert', 'hubert', 'unispeech', 'unispeech-sat'):
        if tokenizer is not None:
            from .extra.wav2vec2 import generate_tokenizer_json
            tokenizer_json = generate_tokenizer_json(tokenizer)

            with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
                json.dump(tokenizer_json, fp, indent=4)

    elif config.model_type == 'vits':
        if tokenizer is not None:
            from .extra.vits import generate_tokenizer_json
            tokenizer_json = generate_tokenizer_json(tokenizer)
            with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
                json.dump(tokenizer_json, fp, indent=4)

    elif config.model_type == 'speecht5':
        # TODO allow user to specify vocoder path
        export_kwargs["model_kwargs"] = {"vocoder": "microsoft/speecht5_hifigan"}

        if tokenizer is not None:
            from .extra.speecht5 import generate_tokenizer_json
            tokenizer_json = generate_tokenizer_json(tokenizer)

            with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
                json.dump(tokenizer_json, fp, indent=4)
    elif config.model_type in ('owlvit', 'owlv2'):
        # Override default batch size to 1, needed because non-maximum suppression is performed for exporting.
        # For more information, see https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
        export_kwargs['batch_size'] = 1
    else:
        pass  # TODO

    # Step 1. convert huggingface model to onnx
    if not conv_args.split_modalities:
        main_export(**export_kwargs)
    else:
        custom_export_kwargs = dict(
            output_dir=output_model_folder,
            **core_export_kwargs,
        )

        if config.model_type == 'clip':
            # Handle special case for exporting text and vision models separately
            from .extra.clip import CLIPTextModelWithProjectionOnnxConfig, CLIPVisionModelWithProjectionOnnxConfig
            from transformers.models.clip import CLIPTextModelWithProjection, CLIPVisionModelWithProjection

            text_model = CLIPTextModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs)
            vision_model = CLIPVisionModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs)

            export_models(
                models_and_onnx_configs={
                    "text_model": (text_model, CLIPTextModelWithProjectionOnnxConfig(text_model.config)),
                    "vision_model": (vision_model, CLIPVisionModelWithProjectionOnnxConfig(vision_model.config)),
                },
                **custom_export_kwargs,
            )

        elif config.model_type == 'siglip':
            # Handle special case for exporting text and vision models separately
            from .extra.siglip import SiglipTextModelOnnxConfig, SiglipVisionModelOnnxConfig
            from transformers.models.siglip import SiglipTextModel, SiglipVisionModel

            text_model = SiglipTextModel.from_pretrained(model_id, **from_pretrained_kwargs)
            vision_model = SiglipVisionModel.from_pretrained(model_id, **from_pretrained_kwargs)

            export_models(
                models_and_onnx_configs={
                    "text_model": (text_model, SiglipTextModelOnnxConfig(text_model.config)),
                    "vision_model": (vision_model, SiglipVisionModelOnnxConfig(vision_model.config)),
                },
                **custom_export_kwargs,
            )

        # TODO: Enable once https://github.com/huggingface/optimum/pull/1552 is merged
        # elif config.model_type == 'clap':
        #     # Handle special case for exporting text and audio models separately
        #     from .extra.clap import ClapTextModelWithProjectionOnnxConfig, ClapAudioModelWithProjectionOnnxConfig
        #     from transformers.models.clap import ClapTextModelWithProjection, ClapAudioModelWithProjection
        #
        #     text_model = ClapTextModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs)
        #     audio_model = ClapAudioModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs)
        #
        #     export_models(
        #         models_and_onnx_configs={
        #             "text_model": (text_model, ClapTextModelWithProjectionOnnxConfig(text_model.config)),
        #             "audio_model": (audio_model, ClapAudioModelWithProjectionOnnxConfig(audio_model.config)),
        #         },
        #         **custom_export_kwargs,
        #     )

        else:
            raise Exception(f'Unable to export {config.model_type} model with `--split_modalities`.')
    # Step 2. (optional, recommended) quantize the converted model for fast inference and to reduce model size.
    if conv_args.quantize:
        # Update quantize config with model specific defaults
        quantize_config = MODEL_SPECIFIC_QUANTIZE_PARAMS.get(
            config.model_type, DEFAULT_QUANTIZE_PARAMS)
        # Update if user specified values
        if conv_args.per_channel is not None:
            quantize_config['per_channel'] = conv_args.per_channel

        if conv_args.reduce_range is not None:
            quantize_config['reduce_range'] = conv_args.reduce_range

        quantize([
            os.path.join(output_model_folder, x)
            for x in os.listdir(output_model_folder)
            if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
        ], **quantize_config)

    # Step 3. Move .onnx files to the 'onnx' subfolder
    os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
    for file in os.listdir(output_model_folder):
        if file.endswith(('.onnx', '.onnx_data')):
            shutil.move(os.path.join(output_model_folder, file),
                        os.path.join(output_model_folder, 'onnx', file))
    # Step 4. Update the generation config if necessary
    if config.model_type == 'whisper':
        from transformers import GenerationConfig
        from .extra.whisper import get_alignment_heads

        generation_config = GenerationConfig.from_pretrained(model_id, **from_pretrained_kwargs)
        generation_config.alignment_heads = get_alignment_heads(config)
        generation_config.save_pretrained(output_model_folder)
if __name__ == '__main__':
main()