transformers.js/scripts/convert.py

import json
import os
import shutil
from dataclasses import dataclass, field
from typing import Optional, Set
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser
)
import onnx
from optimum.exporters.onnx import main_export, export_models
from optimum.exporters.tasks import TasksManager
from onnxruntime.quantization import (
    quantize_dynamic,
    QuantType
)

DEFAULT_QUANTIZE_PARAMS = {
    'per_channel': True,
    'reduce_range': True,
}

MODEL_SPECIFIC_QUANTIZE_PARAMS = {
    # Decoder-only models
    'codegen': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gpt2': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gpt_bigcode': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gptj': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gpt-neo': {
        'per_channel': False,
        'reduce_range': False,
    },
    'gpt-neox': {
        'per_channel': False,
        'reduce_range': False,
    },
    'mpt': {
        'per_channel': False,
        'reduce_range': False,
    },
    'bloom': {
        'per_channel': False,
        'reduce_range': False,
    },
    'llama': {
        'per_channel': False,
        'reduce_range': False,
    },
    'opt': {
        'per_channel': False,
        'reduce_range': False,
    },
    'mistral': {
        'per_channel': False,
        'reduce_range': False,
    },
    'falcon': {
        'per_channel': False,
        'reduce_range': False,
    },

    # Encoder-decoder models
    'whisper': {
        'per_channel': False,
        'reduce_range': False,
    },
    'vision-encoder-decoder': {
        'per_channel': False,
        'reduce_range': False,
    },
}
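
# NOTE (illustrative): at conversion time, a model type is looked up in
# MODEL_SPECIFIC_QUANTIZE_PARAMS and falls back to DEFAULT_QUANTIZE_PARAMS
# when absent, along the lines of:
#   quantize_params = MODEL_SPECIFIC_QUANTIZE_PARAMS.get(
#       config.model_type, DEFAULT_QUANTIZE_PARAMS)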

MODELS_WITHOUT_TOKENIZERS = [
    'wav2vec2',
    'wavlm',
    'hubert',
]


@dataclass
class ConversionArguments:
    """
    Arguments used for converting HuggingFace models to onnx.
    """

    model_id: str = field(
        metadata={
            "help": "Model identifier"
        }
    )
    quantize: bool = field(
        default=False,
        metadata={
            "help": "Whether to quantize the model."
        }
    )
    output_parent_dir: str = field(
        default='./models/',
        metadata={
            "help": "Path where the converted model will be saved."
        }
    )

    task: Optional[str] = field(
        default='auto',
        metadata={
            "help": (
                "The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
                f" {str(TasksManager.get_all_tasks())}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
            )
        }
    )

    opset: int = field(
        default=None,
        metadata={
            "help": (
                "If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used."
            )
        }
    )

    device: str = field(
        default='cpu',
        metadata={
            "help": 'The device to use to do the export.'
        }
    )
    skip_validation: bool = field(
        default=False,
        metadata={
            "help": "Whether to skip validation of the converted model"
        }
    )

    per_channel: bool = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights per channel"
        }
    )
    reduce_range: bool = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights with 7 bits. It may improve the accuracy for some models running on non-VNNI machines, especially in per-channel mode."
        }
    )

    output_attentions: bool = field(
        default=False,
        metadata={
            "help": "Whether to output attentions from the model. NOTE: This is only supported for whisper models right now."
        }
    )

    split_modalities: bool = field(
        default=False,
        metadata={
            "help": "Whether to split multimodal models. NOTE: This is only supported for CLIP models right now."
        }
    )
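
# Illustrative CLI invocation (hypothetical model id; flag names follow the
# dataclass fields above via HfArgumentParser):
#   python -m scripts.convert --model_id bert-base-uncased --quantize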
def get_operators(model: onnx.ModelProto) -> Set[str]:
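    """
    Collect the set of all ONNX operator types used in the model's graph,
    recursing into nested subgraphs (e.g. the bodies of `If`/`Loop` nodes).
    """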
    operators = set()
    def traverse_graph(graph):
        for node in graph.node:
            operators.add(node.op_type)
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    subgraph = attr.g
                    traverse_graph(subgraph)

    traverse_graph(model.graph)
    return operators


def quantize(model_names_or_paths, **quantize_kwargs):
    """
    Quantize the weights of the model from float32 to int8 to allow very
    efficient inference on modern CPUs.

    Uses unsigned ints for activation values and signed ints for weights, per
    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection,
    as this is faster on most CPU architectures.

    Args:
        model_names_or_paths: Paths to the exported ONNX models to quantize.
    """
    quantize_config = dict(
        **quantize_kwargs,
        per_model_config={}
    )

    for model in tqdm(model_names_or_paths, desc='Quantizing'):
        directory_path = os.path.dirname(model)
        file_name_without_extension = os.path.splitext(
            os.path.basename(model))[0]
        # NOTE:
        # As of 2023/04/20, the current latest version of onnxruntime-web is 1.14.0, and does not support INT8 weights for Conv layers.
        # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web.
        #
        # As per the docs, the signed weight type (QInt8) is faster on most CPUs, so we use that unless the model contains a Conv layer.
        # For more information, see:
        # - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
        # - https://github.com/microsoft/onnxruntime/issues/2339
        loaded_model = onnx.load_model(model)
        op_types = get_operators(loaded_model)
        weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
        quantize_dynamic(
            model_input=model,
            model_output=os.path.join(
                directory_path, f'{file_name_without_extension}_quantized.onnx'),
            weight_type=weight_type,
            optimize_model=False,

            # TODO allow user to specify these
            # op_types_to_quantize=['MatMul', 'Add', 'Conv'],
            extra_options=dict(
                EnableSubgraph=True
            ),
            **quantize_kwargs
        )

        quantize_config['per_model_config'][file_name_without_extension] = dict(
            op_types=list(op_types),
            weight_type=str(weight_type),
        )

    # Save quantization config
    with open(os.path.join(directory_path, 'quantize_config.json'), 'w') as fp:
        json.dump(quantize_config, fp, indent=4)
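
# Example usage (hypothetical path; the script's real entry point is `main()`):
#   quantize(['./models/my-model/model.onnx'], per_channel=True, reduce_range=True)
# This writes `model_quantized.onnx` next to the input model and saves a
# `quantize_config.json` describing the settings used.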


def main():
    parser = HfArgumentParser(
        (ConversionArguments, )
    )
    conv_args, = parser.parse_args_into_dataclasses()

    model_id = conv_args.model_id
    output_model_folder = os.path.join(conv_args.output_parent_dir, model_id)
    # Create output folder
    os.makedirs(output_model_folder, exist_ok=True)
    # Load the model config
    config = AutoConfig.from_pretrained(model_id)
    tokenizer = None
    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)

    except KeyError:
        pass  # No Tokenizer
    except Exception as e:
        if config.model_type not in MODELS_WITHOUT_TOKENIZERS:
            raise e
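    # Tokenizer loading is best-effort: failures are tolerated for model
    # types listed in MODELS_WITHOUT_TOKENIZERS (e.g. wav2vec2); any other
    # failure is re-raised.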

    export_kwargs = dict(
        model_name_or_path=model_id,
        output=output_model_folder,
        task=conv_args.task,
        opset=conv_args.opset,
        device=conv_args.device,
        do_validation=not conv_args.skip_validation,
    )
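
    # These kwargs mirror the CLI arguments and, after the model-specific
    # adjustments below, are intended for `optimum.exporters.onnx.main_export`
    # (imported at the top of this script).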

    # Handle special cases
    if config.model_type == 'marian':
        from .extra.marian import generate_tokenizer_json
        tokenizer_json = generate_tokenizer_json(model_id, tokenizer)

        with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp, indent=4)

    elif config.model_type == 'whisper':
if conv_args.output_attentions:
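            # Exporting with cross-attention outputs enables word-level
            # timestamp extraction from the decoder's attention weights.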
from .extra.whisper import get_main_export_kwargs
export_kwargs.update(
**get_main_export_kwargs(config, "automatic-speech-recognition")
)

    elif config.model_type in ('wav2vec2', 'hubert'):
if tokenizer is not None:
from .extra.wav2vec2 import generate_tokenizer_json
tokenizer_json = generate_tokenizer_json(tokenizer)
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp, indent=4)
elif config.model_type == 'speecht5':
# TODO allow user to specify vocoder path
export_kwargs["model_kwargs"] = {"vocoder": "microsoft/speecht5_hifigan"}
if tokenizer is not None:
from .extra.speecht5 import generate_tokenizer_json
tokenizer_json = generate_tokenizer_json(tokenizer)
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp, indent=4)
elif config.model_type == 'owlvit':
        # Override the default batch size to 1; this is required because
        # non-maximum suppression is performed during export.
        # For more information, see https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
export_kwargs['batch_size'] = 1

    else:
pass # TODO
    # Step 1. Convert the Hugging Face model to ONNX
if config.model_type == 'clip' and conv_args.split_modalities:
# Handle special case for exporting text and vision models separately
from .extra.clip import CLIPTextModelWithProjectionOnnxConfig, CLIPVisionModelWithProjectionOnnxConfig
from transformers.models.clip import CLIPTextModelWithProjection, CLIPVisionModelWithProjection
text_model = CLIPTextModelWithProjection.from_pretrained(model_id)
vision_model = CLIPVisionModelWithProjection.from_pretrained(model_id)
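        # Exporting the two towers as standalone ONNX graphs lets downstream
        # applications load only the modality they need (e.g. text-only embedding search).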
export_models(
models_and_onnx_configs={
"text_model": (text_model, CLIPTextModelWithProjectionOnnxConfig(text_model.config)),
"vision_model": (vision_model, CLIPVisionModelWithProjectionOnnxConfig(vision_model.config)),
},
output_dir=output_model_folder,
opset=conv_args.opset,
device=conv_args.device,
)
# TODO: Enable once https://github.com/huggingface/optimum/pull/1552 is merged
# elif config.model_type == 'clap' and conv_args.split_modalities:
# # Handle special case for exporting text and audio models separately
# from .extra.clap import ClapTextModelWithProjectionOnnxConfig, ClapAudioModelWithProjectionOnnxConfig
# from transformers.models.clap import ClapTextModelWithProjection, ClapAudioModelWithProjection
# text_model = ClapTextModelWithProjection.from_pretrained(model_id)
# audio_model = ClapAudioModelWithProjection.from_pretrained(model_id)
# export_models(
# models_and_onnx_configs={
# "text_model": (text_model, ClapTextModelWithProjectionOnnxConfig(text_model.config)),
# "audio_model": (audio_model, ClapAudioModelWithProjectionOnnxConfig(audio_model.config)),
# },
# output_dir=output_model_folder,
# opset=conv_args.opset,
# device=conv_args.device,
# )
else:
main_export(**export_kwargs)

    # Step 2. (Optional, recommended) Quantize the converted model for faster inference and a smaller file size.
if conv_args.quantize:
# Update quantize config with model specific defaults
quantize_config = MODEL_SPECIFIC_QUANTIZE_PARAMS.get(
config.model_type, DEFAULT_QUANTIZE_PARAMS)

        # Override with any user-specified values
if conv_args.per_channel is not None:
quantize_config['per_channel'] = conv_args.per_channel
if conv_args.reduce_range is not None:
quantize_config['reduce_range'] = conv_args.reduce_range
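
        # The `quantize` helper (defined earlier in this script) wraps
        # onnxruntime's dynamic quantization; files that are already
        # quantized are excluded by the filename filter below.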
quantize([
os.path.join(output_model_folder, x)
for x in os.listdir(output_model_folder)
if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
], **quantize_config)
# Step 3. Move .onnx files to the 'onnx' subfolder
os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
for file in os.listdir(output_model_folder):
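        # .onnx_data files hold external tensor data for models that exceed
        # the 2GB protobuf limit, so they must be moved alongside their .onnx file.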
if file.endswith(('.onnx', '.onnx_data')):
shutil.move(os.path.join(output_model_folder, file),
os.path.join(output_model_folder, 'onnx', file))

    # Step 4. Update the generation config if necessary
if config.model_type == 'whisper':
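        # Alignment heads mark the cross-attention heads whose weights track
        # word timings; they are required for word-level timestamps.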
from transformers import GenerationConfig
from .extra.whisper import get_alignment_heads
generation_config = GenerationConfig.from_pretrained(model_id)
generation_config.alignment_heads = get_alignment_heads(config)
generation_config.save_pretrained(output_model_folder)


if __name__ == '__main__':
main()
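
# Example invocation (the model id below is illustrative; see the argument
# definitions above for the full set of flags):
#   python -m scripts.convert --quantize --model_id bert-base-uncased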