Add support for SigLIP models

Joshua Lochner 2023-12-24 02:22:40 +02:00
parent 6f05572854
commit 0123984132
9 changed files with 215 additions and 12 deletions


@@ -309,6 +309,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.


@@ -50,6 +50,7 @@
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.


@@ -353,6 +353,23 @@ def main():
            device=conv_args.device,
        )
    elif config.model_type == 'siglip' and conv_args.split_modalities:
        # Handle special case for exporting text and vision models separately
        from .extra.siglip import SiglipTextModelOnnxConfig, SiglipVisionModelOnnxConfig
        from transformers.models.siglip import SiglipTextModel, SiglipVisionModel

        text_model = SiglipTextModel.from_pretrained(model_id)
        vision_model = SiglipVisionModel.from_pretrained(model_id)

        export_models(
            models_and_onnx_configs={
                "text_model": (text_model, SiglipTextModelOnnxConfig(text_model.config)),
                "vision_model": (vision_model, SiglipVisionModelOnnxConfig(vision_model.config)),
            },
            output_dir=output_model_folder,
            opset=conv_args.opset,
            device=conv_args.device,
        )
    else:
        main_export(**export_kwargs)

scripts/extra/siglip.py (new file, 33 lines)

@@ -0,0 +1,33 @@
# Support exporting vision and text models separately:
# Adapted from https://github.com/huggingface/optimum/issues/1186#issuecomment-1637641760
from optimum.exporters.onnx.model_configs import SiglipTextOnnxConfig, ViTOnnxConfig

from typing import Dict

class SiglipVisionOnnxConfig(ViTOnnxConfig):
    pass

class SiglipTextModelOnnxConfig(SiglipTextOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
            "pooler_output": {0: "batch_size"},
        }

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

        if framework == "pt":
            import torch
            dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int64)
        return dummy_inputs

class SiglipVisionModelOnnxConfig(SiglipVisionOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "last_hidden_state": {0: "batch_size"},
            "pooler_output": {0: "batch_size"},
        }


@@ -445,6 +445,12 @@ SUPPORTED_MODELS = {
    # 'facebook/sam-vit-large',
    # 'facebook/sam-vit-huge',
    # ],
    'siglip': [
        # Zero-shot image classification and feature extraction
        # (with and without `--split_modalities`)
        # NOTE: requires --opset 13
        'nielsr/siglip-base-patch16-224',
    ],
    'speecht5': [
        # Text-to-speech
        'microsoft/speecht5_tts',


@@ -2860,7 +2860,128 @@ export class CLIPVisionModelWithProjection extends CLIPPreTrainedModel {
return super.from_pretrained(pretrained_model_name_or_path, options);
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// SigLIP models
export class SiglipPreTrainedModel extends PreTrainedModel { }
/**
* SigLIP Text and Vision Model with projection layers on top
*
* **Example:** Perform zero-shot image classification with a `SiglipModel`.
*
* ```javascript
* import { AutoTokenizer, AutoProcessor, SiglipModel, RawImage } from '@xenova/transformers';
*
* // Load tokenizer, processor, and model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/siglip-base-patch16-224');
* const processor = await AutoProcessor.from_pretrained('Xenova/siglip-base-patch16-224');
* const model = await SiglipModel.from_pretrained('Xenova/siglip-base-patch16-224');
*
* // Run tokenization
* const texts = ['a photo of 2 cats', 'a photo of 2 dogs'];
* const text_inputs = tokenizer(texts, { padding: 'max_length', truncation: true });
*
* // Read image and run processor
* const image = await RawImage.read('http://images.cocodataset.org/val2017/000000039769.jpg');
* const image_inputs = await processor(image);
*
* // Run model with both text and pixel inputs
* const output = await model({ ...text_inputs, ...image_inputs });
* // {
* // logits_per_image: Tensor {
* // dims: [ 1, 2 ],
* // data: Float32Array(2) [ -1.6019744873046875, -10.720091819763184 ],
* // },
* // logits_per_text: Tensor {
* // dims: [ 2, 1 ],
* // data: Float32Array(2) [ -1.6019744873046875, -10.720091819763184 ],
* // },
* // text_embeds: Tensor {
* // dims: [ 2, 768 ],
* // data: Float32Array(1536) [ ... ],
* // },
* // image_embeds: Tensor {
* // dims: [ 1, 768 ],
* // data: Float32Array(768) [ ... ],
* // }
* // }
* ```
*/
export class SiglipModel extends SiglipPreTrainedModel { }
/**
* The text model from SigLIP without any head or projection on top.
*
* **Example:** Compute text embeddings with `SiglipTextModel`.
*
* ```javascript
* import { AutoTokenizer, SiglipTextModel } from '@xenova/transformers';
*
* // Load tokenizer and text model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/siglip-base-patch16-224');
* const text_model = await SiglipTextModel.from_pretrained('Xenova/siglip-base-patch16-224');
*
* // Run tokenization
* const texts = ['a photo of 2 cats', 'a photo of 2 dogs'];
* const text_inputs = tokenizer(texts, { padding: 'max_length', truncation: true });
*
* // Compute embeddings
* const { pooler_output } = await text_model(text_inputs);
* // Tensor {
* // dims: [ 2, 768 ],
* // type: 'float32',
* // data: Float32Array(1536) [ ... ],
* // size: 1536
* // }
* ```
*/
export class SiglipTextModel extends SiglipPreTrainedModel {
/** @type {PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
}
}
/**
* The vision model from SigLIP without any head or projection on top.
*
* **Example:** Compute vision embeddings with `SiglipVisionModel`.
*
* ```javascript
* import { AutoProcessor, SiglipVisionModel, RawImage} from '@xenova/transformers';
*
* // Load processor and vision model
* const processor = await AutoProcessor.from_pretrained('Xenova/siglip-base-patch16-224');
* const vision_model = await SiglipVisionModel.from_pretrained('Xenova/siglip-base-patch16-224');
*
* // Read image and run processor
* const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg');
* const image_inputs = await processor(image);
*
* // Compute embeddings
* const { pooler_output } = await vision_model(image_inputs);
* // Tensor {
* // dims: [ 1, 768 ],
* // type: 'float32',
* // data: Float32Array(768) [ ... ],
* // size: 768
* // }
* ```
*/
export class SiglipVisionModel extends SiglipPreTrainedModel {
/** @type {PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'vision_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
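Note on the `SiglipModel` example above: SigLIP is trained with a sigmoid (rather than softmax) loss, so each image-text logit maps to an independent probability with a plain sigmoid. A minimal sketch, reusing the `output` object from that example:

```javascript
// Turn the raw logits_per_image from the SiglipModel example into independent
// probabilities. Unlike CLIP, the scores are not normalized to sum to 1.
const probabilities = output.logits_per_image.data.map(x => 1 / (1 + Math.exp(-x)));
// e.g. [ ~0.17, ~0.00002 ] for 'a photo of 2 cats' vs. 'a photo of 2 dogs'
```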
@@ -4147,6 +4268,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['xlm', ['XLMModel', XLMModel]],
['xlm-roberta', ['XLMRobertaModel', XLMRobertaModel]],
['clip', ['CLIPModel', CLIPModel]],
['siglip', ['SiglipModel', SiglipModel]],
['mobilebert', ['MobileBertModel', MobileBertModel]],
['squeezebert', ['SqueezeBertModel', SqueezeBertModel]],
['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]],
@@ -4391,6 +4513,9 @@ for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
const CUSTOM_MAPPING = [
['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly],
['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection, MODEL_TYPES.EncoderOnly],
['SiglipTextModel', SiglipTextModel, MODEL_TYPES.EncoderOnly],
['SiglipVisionModel', SiglipVisionModel, MODEL_TYPES.EncoderOnly],
]
for (const [name, model, type] of CUSTOM_MAPPING) {
MODEL_TYPE_MAPPING.set(name, type);
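For completeness, a hedged sketch of how the separately exported text and vision models relate to the combined `SiglipModel`: the pooled embeddings can be compared with cosine similarity, which `SiglipModel` then maps to logits via a learned scale and bias before the sigmoid shown earlier. The names `text_embeddings` and `image_embeddings` below are placeholders for the `pooler_output` tensors computed in the `SiglipTextModel` and `SiglipVisionModel` examples above.

```javascript
// Cosine similarity between two embedding vectors (plain typed arrays).
function cosineSimilarity(a, b) {
    let dot = 0, normA = 0, normB = 0;
    for (let i = 0; i < a.length; ++i) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }
    return dot / Math.sqrt(normA * normB);
}

// Score each caption embedding ([2, 768]) against the single image embedding ([1, 768]).
const [numTexts, dim] = text_embeddings.dims;
const similarities = [];
for (let i = 0; i < numTexts; ++i) {
    const textVector = text_embeddings.data.subarray(i * dim, (i + 1) * dim);
    similarities.push(cosineSimilarity(textVector, image_embeddings.data));
}
```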


@@ -1659,7 +1659,7 @@ export class ZeroShotImageClassificationPipeline extends Pipeline {
        // Run tokenization
        let text_inputs = this.tokenizer(texts, {
-            padding: true,
+            padding: this.tokenizer.padding ?? true,
            truncation: true
        });
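The one-line change above lets a tokenizer's own padding default (such as the `'max_length'` padding that `SiglipTokenizer` sets in tokenizers.js below) flow through to zero-shot image classification, while other tokenizers keep the previous behaviour of padding to the longest sequence. A minimal usage sketch; the checkpoint id comes from the examples above and the printed scores are illustrative only:

```javascript
import { pipeline } from '@xenova/transformers';

// The pipeline now uses `this.tokenizer.padding` when set ('max_length' for SigLIP),
// falling back to `true` for tokenizers that do not define a default.
const classifier = await pipeline('zero-shot-image-classification', 'Xenova/siglip-base-patch16-224');

const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
const output = await classifier(url, ['a photo of 2 cats', 'a photo of 2 dogs']);
// e.g. [ { label: 'a photo of 2 cats', score: ... }, { label: 'a photo of 2 dogs', score: ... } ]
```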


@@ -193,8 +193,8 @@ export class ImageFeatureExtractor extends FeatureExtractor {
    constructor(config) {
        super(config);
-        this.image_mean = this.config.image_mean;
-        this.image_std = this.config.image_std;
+        this.image_mean = this.config.image_mean ?? this.config.mean;
+        this.image_std = this.config.image_std ?? this.config.std;
        this.resample = this.config.resample ?? 2; // 2 => bilinear
        this.do_rescale = this.config.do_rescale ?? true;
@@ -378,6 +378,17 @@ export class ImageFeatureExtractor extends FeatureExtractor {
return [pixelData, imgDims];
}
/**
* Rescale the image's pixel values by `this.rescale_factor`.
* @param {Float32Array} pixelData The pixel data to rescale.
* @returns {void}
*/
rescale(pixelData) {
for (let i = 0; i < pixelData.length; ++i) {
pixelData[i] = this.rescale_factor * pixelData[i];
}
}
/**
* @typedef {object} PreprocessedImage
* @property {HeightWidth} original_size The original size of the image.
@@ -507,9 +518,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
        let imgDims = [image.height, image.width, image.channels];
        if (this.do_rescale) {
-            for (let i = 0; i < pixelData.length; ++i) {
-                pixelData[i] = this.rescale_factor * pixelData[i];
-            }
+            this.rescale(pixelData);
        }
if (this.do_normalize) {
@@ -591,6 +600,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
export class DPTFeatureExtractor extends ImageFeatureExtractor { }
export class GLPNFeatureExtractor extends ImageFeatureExtractor { }
export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
export class SiglipImageProcessor extends ImageFeatureExtractor { }
export class ConvNextFeatureExtractor extends ImageFeatureExtractor { }
export class ConvNextImageProcessor extends ConvNextFeatureExtractor { } // NOTE extends ConvNextFeatureExtractor
export class ViTFeatureExtractor extends ImageFeatureExtractor { }
@@ -1645,6 +1655,7 @@ export class AutoProcessor {
MobileViTFeatureExtractor,
OwlViTFeatureExtractor,
CLIPFeatureExtractor,
SiglipImageProcessor,
ConvNextFeatureExtractor,
ConvNextImageProcessor,
DPTFeatureExtractor,


@@ -2275,6 +2275,7 @@ export class PreTrainedTokenizer extends Callable {
this.do_lowercase_and_remove_accent = tokenizerConfig.do_lowercase_and_remove_accent ?? false;
// TODO allow user to change this
this.padding = null;
this.padding_side = 'right';
}
@@ -2347,7 +2348,7 @@
* @param {string|string[]} text The text to tokenize.
* @param {Object} options An optional object containing the following properties:
* @param {string|string[]} [options.text_pair=null] Optional second sequence to be encoded. If set, must be the same type as text.
- * @param {boolean} [options.padding=false] Whether to pad the input sequences.
+ * @param {boolean|'max_length'} [options.padding=false] Whether to pad the input sequences.
* @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
* @param {boolean} [options.truncation=null] Whether to truncate the input sequences.
* @param {number} [options.max_length=null] Maximum length of the returned list and optionally padding length.
@@ -2408,11 +2409,13 @@
        // At this point, tokens is batched: [batch_size, tokens]
        // However, array may be jagged. So, we pad to max_length
-        let maxLengthOfBatch = max(tokens.map(x => x.length))[0];
-        // If null, we calculate max length from sequences
        if (max_length === null) {
-            max_length = maxLengthOfBatch;
+            if (padding === 'max_length') {
+                max_length = this.model_max_length;
+            } else {
+                // Calculate max length from sequences
+                max_length = max(tokens.map(x => x.length))[0];
+            }
        }
// Ensure it is less than model max length
@@ -3780,7 +3783,12 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
}

export class CodeGenTokenizer extends PreTrainedTokenizer { }
export class CLIPTokenizer extends PreTrainedTokenizer { }

export class SiglipTokenizer extends PreTrainedTokenizer {
    constructor(tokenizerJSON, tokenizerConfig) {
        super(tokenizerJSON, tokenizerConfig);
        this.padding = 'max_length';
    }
}
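`SiglipTokenizer` opts into the new `'max_length'` padding mode by default so that text inputs always match the fixed sequence length the ONNX text model was exported with. A small sketch of calling the tokenizer directly; the checkpoint id comes from the examples above, and the padded length of 64 assumes that is this checkpoint's `model_max_length`:

```javascript
import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/siglip-base-patch16-224');

// Pads every sequence to the tokenizer's model_max_length (assumed 64 here)
// rather than to the longest sequence in the batch.
const { input_ids } = tokenizer(['a photo of 2 cats', 'a photo of 2 dogs'], {
    padding: 'max_length',
    truncation: true,
});
// input_ids.dims -> [ 2, 64 ] (assuming model_max_length === 64)
```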
/**
* @todo This model is not yet supported by Hugging Face's "fast" tokenizers library (https://github.com/huggingface/tokenizers).
@@ -3872,6 +3880,7 @@ export class AutoTokenizer {
WhisperTokenizer,
CodeGenTokenizer,
CLIPTokenizer,
SiglipTokenizer,
MarianTokenizer,
BloomTokenizer,
NllbTokenizer,