Add `DeiT`, `Swin`, and `Yolos` vision models (#262)

* Add support `DeiT` models

* Add `Swin` models for image classification

* Add support for `yolos` models

* Add `YolosFeatureExtractor`

* Remove unused import

* Update list of supported models

* Remove SAM for now

Move SAM support to next release
This commit is contained in:
Joshua Lochner 2023-08-28 17:29:15 +02:00 committed by GitHub
parent f0573175fd
commit 09cf91abd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 195 additions and 74 deletions

View File

@ -260,6 +260,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
@ -276,11 +277,13 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
1. **[YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.

View File

@ -8,6 +8,7 @@
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
@ -24,10 +25,12 @@
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
1. **[XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
1. **[YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.

View File

@ -102,6 +102,12 @@ SUPPORTED_MODELS = {
'sileod/deberta-v3-base-tasksource-nli',
'sileod/deberta-v3-large-tasksource-nli',
],
'deit': [
'facebook/deit-tiny-distilled-patch16-224',
'facebook/deit-small-distilled-patch16-224',
'facebook/deit-base-distilled-patch16-224',
'facebook/deit-base-distilled-patch16-384',
],
'detr': [
'facebook/detr-resnet-50',
'facebook/detr-resnet-101',
@ -185,15 +191,27 @@ SUPPORTED_MODELS = {
'sentence-transformers/all-roberta-large-v1',
'julien-c/EsperBERTo-small-pos',
],
'sam': [
'facebook/sam-vit-base',
'facebook/sam-vit-large',
'facebook/sam-vit-huge',
],
# 'sam': [
# 'facebook/sam-vit-base',
# 'facebook/sam-vit-large',
# 'facebook/sam-vit-huge',
# ],
'squeezebert': [
'squeezebert/squeezebert-uncased',
'squeezebert/squeezebert-mnli',
],
'swin': [
'microsoft/swin-tiny-patch4-window7-224',
'microsoft/swin-base-patch4-window7-224',
'microsoft/swin-large-patch4-window12-384-in22k',
'microsoft/swin-base-patch4-window7-224-in22k',
'microsoft/swin-base-patch4-window12-384-in22k',
'microsoft/swin-base-patch4-window12-384',
'microsoft/swin-large-patch4-window7-224',
'microsoft/swin-small-patch4-window7-224',
'microsoft/swin-large-patch4-window7-224-in22k',
'microsoft/swin-large-patch4-window12-384',
],
't5': [
't5-small',
't5-base',
@ -260,6 +278,13 @@ SUPPORTED_MODELS = {
'openai/whisper-large',
'openai/whisper-large-v2',
],
'yolos': [
'hustvl/yolos-tiny',
'hustvl/yolos-small',
'hustvl/yolos-base',
'hustvl/yolos-small-dwr',
'hustvl/yolos-small-300',
]
}

View File

@ -2987,6 +2987,7 @@ export class LlamaForCausalLM extends LlamaPreTrainedModel {
//////////////////////////////////////////////////
export class ViTPreTrainedModel extends PreTrainedModel { }
export class ViTModel extends ViTPreTrainedModel { }
export class ViTForImageClassification extends ViTPreTrainedModel {
/**
* @param {any} model_inputs
@ -2999,6 +3000,7 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
//////////////////////////////////////////////////
export class MobileViTPreTrainedModel extends PreTrainedModel { }
export class MobileViTModel extends MobileViTPreTrainedModel { }
export class MobileViTForImageClassification extends MobileViTPreTrainedModel {
/**
* @param {any} model_inputs
@ -3015,6 +3017,7 @@ export class MobileViTForImageClassification extends MobileViTPreTrainedModel {
//////////////////////////////////////////////////
export class DetrPreTrainedModel extends PreTrainedModel { }
export class DetrModel extends DetrPreTrainedModel { }
export class DetrForObjectDetection extends DetrPreTrainedModel {
/**
* @param {any} model_inputs
@ -3066,6 +3069,60 @@ export class DetrSegmentationOutput extends ModelOutput {
//////////////////////////////////////////////////
//////////////////////////////////////////////////
export class DeiTPreTrainedModel extends PreTrainedModel { }
export class DeiTModel extends DeiTPreTrainedModel { }
export class DeiTForImageClassification extends DeiTPreTrainedModel {
/**
* @param {any} model_inputs
*/
async _call(model_inputs) {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
export class SwinPreTrainedModel extends PreTrainedModel { }
export class SwinModel extends SwinPreTrainedModel { }
export class SwinForImageClassification extends SwinPreTrainedModel {
/**
* @param {any} model_inputs
*/
async _call(model_inputs) {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
export class YolosPreTrainedModel extends PreTrainedModel { }
export class YolosModel extends YolosPreTrainedModel { }
export class YolosForObjectDetection extends YolosPreTrainedModel {
/**
* @param {any} model_inputs
*/
async _call(model_inputs) {
return new YolosObjectDetectionOutput(await super._call(model_inputs));
}
}
export class YolosObjectDetectionOutput extends ModelOutput {
/**
* @param {Object} output The output of the model.
* @param {Tensor} output.logits Classification logits (including no-object) for all queries.
* @param {Tensor} output.pred_boxes Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height).
* These values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding possible padding).
*/
constructor({ logits, pred_boxes }) {
super();
this.logits = logits;
this.pred_boxes = pred_boxes;
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
export class SamPreTrainedModel extends PreTrainedModel { }
export class SamModel extends SamPreTrainedModel {
@ -3400,6 +3457,13 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['squeezebert', SqueezeBertModel],
['wav2vec2', Wav2Vec2Model],
['detr', DetrModel],
['vit', ViTModel],
['mobilevit', MobileViTModel],
['deit', DeiTModel],
['swin', SwinModel],
['yolos', YolosModel],
['sam', SamModel], // TODO change to encoder-decoder when model is split correctly
]);
@ -3495,10 +3559,13 @@ const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([
const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([
['vit', ViTForImageClassification],
['mobilevit', MobileViTForImageClassification],
['deit', DeiTForImageClassification],
['swin', SwinForImageClassification],
]);
const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([
['detr', DetrForObjectDetection],
['yolos', YolosForObjectDetection],
]);
const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([

View File

@ -41,6 +41,85 @@ import { RawImage } from './utils/image.js';
import { getMelFilters } from './utils/audio.js';
// Helper functions
/**
* Converts bounding boxes from center format to corners format.
*
* @param {number[]} arr The coordinate for the center of the box and its width, height dimensions (center_x, center_y, width, height)
* @returns {number[]} The coodinates for the top-left and bottom-right corners of the box (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
*/
function center_to_corners_format([centerX, centerY, width, height]) {
return [
centerX - width / 2,
centerY - height / 2,
centerX + width / 2,
centerY + height / 2
];
}
/**
* Post-processes the outputs of the model (for object detection).
* @param {Object} outputs The outputs of the model that must be post-processed
* @param {Tensor} outputs.logits The logits
* @param {Tensor} outputs.pred_boxes The predicted boxes.
* @return {Object[]} An array of objects containing the post-processed outputs.
*/
function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null) {
const out_logits = outputs.logits;
const out_bbox = outputs.pred_boxes;
const [batch_size, num_boxes, num_classes] = out_logits.dims;
if (target_sizes !== null && target_sizes.length !== batch_size) {
throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
}
let toReturn = [];
for (let i = 0; i < batch_size; ++i) {
let target_size = target_sizes !== null ? target_sizes[i] : null;
let info = {
boxes: [],
classes: [],
scores: []
}
let logits = out_logits[i];
let bbox = out_bbox[i];
for (let j = 0; j < num_boxes; ++j) {
let logit = logits[j];
// Get most probable class
let maxIndex = max(logit.data)[1];
if (maxIndex === num_classes - 1) {
// This is the background class, skip it
continue;
}
// Compute softmax over classes
let probs = softmax(logit.data);
let score = probs[maxIndex];
if (score > threshold) {
// Some class has a high enough probability
/** @type {number[]} */
let box = bbox[j].data;
// convert to [x0, y0, x1, y1] format
box = center_to_corners_format(box)
if (target_size !== null) {
box = box.map((x, i) => x * target_size[(i + 1) % 2])
}
info.boxes.push(box);
info.classes.push(maxIndex);
info.scores.push(score);
}
}
toReturn.push(info);
}
return toReturn;
}
/**
* Base class for feature extractors.
*
@ -289,9 +368,8 @@ export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
*/
export class DetrFeatureExtractor extends ImageFeatureExtractor {
/**
* Calls the feature extraction process on an array of image
* URLs, preprocesses each image, and concatenates the resulting
* features into a single Tensor.
* Calls the feature extraction process on an array of image URLs, preprocesses
* each image, and concatenates the resulting features into a single Tensor.
* @param {any} urls The URL(s) of the image(s) to extract features from.
* @returns {Promise<Object>} An object containing the concatenated pixel values of the preprocessed images.
*/
@ -312,19 +390,6 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor {
return result;
}
/**
* @param {number[]} arr The URL(s) of the image(s) to extract features from.
* @returns {number[]} An object containing the concatenated pixel values of the preprocessed images.
*/
center_to_corners_format([centerX, centerY, width, height]) {
return [
centerX - width / 2,
centerY - height / 2,
centerX + width / 2,
centerY + height / 2
];
}
/**
* Post-processes the outputs of the model (for object detection).
* @param {Object} outputs The outputs of the model that must be post-processed
@ -332,59 +397,10 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor {
* @param {Tensor} outputs.pred_boxes The predicted boxes.
* @return {Object[]} An array of objects containing the post-processed outputs.
*/
post_process_object_detection(outputs, threshold = 0.5, target_sizes = null) {
const out_logits = outputs.logits;
const out_bbox = outputs.pred_boxes;
const [batch_size, num_boxes, num_classes] = out_logits.dims;
if (target_sizes !== null && target_sizes.length !== batch_size) {
throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
}
let toReturn = [];
for (let i = 0; i < batch_size; ++i) {
let target_size = target_sizes !== null ? target_sizes[i] : null;
let info = {
boxes: [],
classes: [],
scores: []
}
let logits = out_logits[i];
let bbox = out_bbox[i];
for (let j = 0; j < num_boxes; ++j) {
let logit = logits[j];
// Get most probable class
let maxIndex = max(logit.data)[1];
if (maxIndex === num_classes - 1) {
// This is the background class, skip it
continue;
}
// Compute softmax over classes
let probs = softmax(logit.data);
let score = probs[maxIndex];
if (score > threshold) {
// Some class has a high enough probability
/** @type {number[]} */
let box = bbox[j].data;
// convert to [x0, y0, x1, y1] format
box = this.center_to_corners_format(box)
if (target_size !== null) {
box = box.map((x, i) => x * target_size[(i + 1) % 2])
}
info.boxes.push(box);
info.classes.push(maxIndex);
info.scores.push(score);
}
}
toReturn.push(info);
}
return toReturn;
/** @type {post_process_object_detection} */
post_process_object_detection(...args) {
return post_process_object_detection(...args);
}
/**
@ -664,6 +680,13 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor {
}
}
export class YolosFeatureExtractor extends ImageFeatureExtractor {
/** @type {post_process_object_detection} */
post_process_object_detection(...args) {
return post_process_object_detection(...args);
}
}
export class SamImageProcessor extends ImageFeatureExtractor {
async _call(images, input_points) {
let {
@ -1301,6 +1324,7 @@ export class AutoProcessor {
'MobileViTFeatureExtractor': MobileViTFeatureExtractor,
'DeiTFeatureExtractor': DeiTFeatureExtractor,
'DetrFeatureExtractor': DetrFeatureExtractor,
'YolosFeatureExtractor': YolosFeatureExtractor,
'SamImageProcessor': SamImageProcessor,
'Wav2Vec2FeatureExtractor': Wav2Vec2FeatureExtractor,

View File

@ -8,7 +8,6 @@
* @module utils/image
*/
import fs from 'fs';
import { isString } from './core.js';
import { getFile } from './hub.js';
import { env } from '../env.js';