Add support for Segment Anything Model (#510)
* Update SamModel
* Make `AutoModel.from_pretrained` work with SamModel
* Add listed support for SAM (Segment Anything Model)
* Update types of `calculateDimensions`
* Throw error if reading image from tensor with dims.length != 3
* Make SamProcessor input points optional
* Fix type errors
* `let` -> `const`
* `cat` -> `stack`
* Expose `reshape_input_points` in `SamProcessor`
* Add `input_labels` input parameter for SAM
* Add `input_labels` to sam processor
* Update SAM unit tests
* Remove TODOs
* Update JSDoc
This commit is contained in:
parent 4d1d4d3346
commit cdcbfc125c
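For context, the end-to-end flow this commit enables, condensed from the JSDoc example added to `src/models.js` below (the checkpoint, image URL, and point prompt all come from that example):

```js
import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';

// Two ONNX sessions are loaded: the vision encoder and the prompt encoder/mask decoder
const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');

// Read an image and prompt the model with a single 2D point
const raw_image = await RawImage.read('https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png');
const inputs = await processor(raw_image, [[[450, 600]]]);

// Predict masks, then upscale them back to the original image size and binarize
const outputs = await model(inputs);
const masks = await processor.post_process_masks(
    outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes
);
console.log(masks[0].dims); // e.g. [ 1, 3, 1764, 2646 ]
```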
@@ -328,6 +328,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
+1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
 1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
@@ -63,6 +63,7 @@
 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
+1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
 1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
@@ -745,11 +745,20 @@ SUPPORTED_MODELS = {
             'distilroberta-base',
         ],
     },
-    # 'sam': [
-    #     'facebook/sam-vit-base',
-    #     'facebook/sam-vit-large',
-    #     'facebook/sam-vit-huge',
-    # ],
+    'sam': {
+        # Mask generation
+        'mask-generation': [
+            # SAM
+            'facebook/sam-vit-base',
+            'facebook/sam-vit-large',
+            'facebook/sam-vit-huge',
+            'wanglab/medsam-vit-base',
+
+            # SlimSAM
+            'nielsr/slimsam-50-uniform',
+            'nielsr/slimsam-77-uniform',
+        ],
+    },
     'segformer': {
         # Image segmentation
         'image-segmentation': [
src/models.js (143 changed lines)
@@ -95,6 +95,7 @@ const MODEL_TYPES = {
     Seq2Seq: 2,
     Vision2Seq: 3,
     DecoderOnly: 4,
+    MaskGeneration: 5,
 }
 //////////////////////////////////////////////////
 
@@ -771,6 +772,13 @@ export class PreTrainedModel extends Callable {
                 getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
             ]);
 
+        } else if (modelType === MODEL_TYPES.MaskGeneration) {
+            info = await Promise.all([
+                AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
+                constructSession(pretrained_model_name_or_path, 'vision_encoder', options),
+                constructSession(pretrained_model_name_or_path, 'prompt_encoder_mask_decoder', options),
+            ]);
+
         } else if (modelType === MODEL_TYPES.EncoderDecoder) {
             info = await Promise.all([
                 AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
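With the new `MaskGeneration` type wired into `from_pretrained`, a SAM checkpoint loads two ONNX sessions instead of one, and `AutoModel` now resolves it as well (one of this commit's stated goals). A minimal sketch, reusing the `Xenova/sam-vit-base` checkpoint from the JSDoc example below:

```js
import { AutoModel } from '@xenova/transformers';

// config.model_type 'sam' maps to SamModel via the mask-generation mapping,
// which is typed MaskGeneration, so both the 'vision_encoder' and
// 'prompt_encoder_mask_decoder' sessions above are constructed.
const model = await AutoModel.from_pretrained('Xenova/sam-vit-base');
```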
@@ -4242,12 +4250,130 @@ export class YolosObjectDetectionOutput extends ModelOutput {
 
 //////////////////////////////////////////////////
 export class SamPreTrainedModel extends PreTrainedModel { }
+
+/**
+ * Segment Anything Model (SAM) for generating segmentation masks, given an input image
+ * and optional 2D location and bounding boxes.
+ *
+ * **Example:** Perform mask generation w/ `Xenova/sam-vit-base`.
+ * ```javascript
+ * import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';
+ *
+ * const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
+ * const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');
+ *
+ * const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png';
+ * const raw_image = await RawImage.read(img_url);
+ * const input_points = [[[450, 600]]] // 2D localization of a window
+ *
+ * const inputs = await processor(raw_image, input_points);
+ * const outputs = await model(inputs);
+ *
+ * const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes);
+ * // [
+ * //   Tensor {
+ * //     dims: [ 1, 3, 1764, 2646 ],
+ * //     type: 'bool',
+ * //     data: Uint8Array(14002632) [ ... ],
+ * //     size: 14002632
+ * //   }
+ * // ]
+ * const scores = outputs.iou_scores;
+ * // Tensor {
+ * //   dims: [ 1, 1, 3 ],
+ * //   type: 'float32',
+ * //   data: Float32Array(3) [
+ * //     0.8892380595207214,
+ * //     0.9311248064041138,
+ * //     0.983696699142456
+ * //   ],
+ * //   size: 3
+ * // }
+ * ```
+ */
 export class SamModel extends SamPreTrainedModel {
     /**
-     * @param {Object} model_inputs
-     * @param {Tensor} model_inputs.pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`.
-     * @param {Tensor} model_inputs.input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`. This is used by the prompt encoder to encode the prompt.
-     * @todo Add support for `input_labels`, `input_boxes`, `input_masks`, and `image_embeddings`.
+     * Creates a new instance of the `SamModel` class.
+     * @param {Object} config The configuration object specifying the hyperparameters and other model settings.
+     * @param {Object} vision_encoder The ONNX session containing the vision encoder model.
+     * @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model.
      */
+    constructor(config, vision_encoder, prompt_encoder_mask_decoder) {
+        super(config, vision_encoder);
+        this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder;
+    }
+
+    /**
+     * Compute image embeddings and positional image embeddings, given the pixel values of an image.
+     * @param {Object} model_inputs Object containing the model inputs.
+     * @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `SamProcessor`.
+     * @returns {Promise<{ image_embeddings: Tensor, image_positional_embeddings: Tensor }>} The image embeddings and positional image embeddings.
+     */
+    async get_image_embeddings({ pixel_values }) {
+        // in:
+        //  - pixel_values: tensor.float32[batch_size,3,1024,1024]
+        //
+        // out:
+        //  - image_embeddings: tensor.float32[batch_size,256,64,64]
+        //  - image_positional_embeddings: tensor.float32[batch_size,256,64,64]
+        return await encoderForward(this, { pixel_values })
+    }
+
+    /**
+     * @typedef {Object} SamModelInputs Object containing the model inputs.
+     * @property {Tensor} pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`.
+     * These can be obtained using a `SamProcessor`.
+     * @property {Tensor} input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`.
+     * This is used by the prompt encoder to encode the prompt.
+     * @property {Tensor} [input_labels] Input labels for the points, as a Tensor of shape `(batch_size, point_batch_size, num_points)`.
+     * This is used by the prompt encoder to encode the prompt. There are 4 types of labels:
+     *  - `1`: the point is a point that contains the object of interest
+     *  - `0`: the point is a point that does not contain the object of interest
+     *  - `-1`: the point corresponds to the background
+     *  - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+     * @property {Tensor} [image_embeddings] Image embeddings used by the mask decoder.
+     * @property {Tensor} [image_positional_embeddings] Image positional embeddings used by the mask decoder.
+     */
+
+    /**
+     * @param {SamModelInputs} model_inputs Object containing the model inputs.
+     * @returns {Promise<Object>} The output of the model.
+     */
+    async forward(model_inputs) {
+        if (!model_inputs.image_embeddings || !model_inputs.image_positional_embeddings) {
+            // Compute the image embeddings if they are missing
+            model_inputs = {
+                ...model_inputs,
+                ...(await this.get_image_embeddings(model_inputs))
+            }
+        }
+
+        if (!model_inputs.input_labels) {
+            // Set default input labels if they are missing
+            const shape = model_inputs.input_points.dims.slice(0, -1);
+            const numElements = shape.reduce((a, b) => a * b, 1);
+            model_inputs.input_labels = new Tensor(
+                'int64',
+                new BigInt64Array(numElements).fill(1n),
+                shape
+            );
+        }
+
+        // Returns:
+        //  - iou_scores: tensor.float32[batch_size,point_batch_size,3]
+        //  - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256]
+        return await sessionRun(this.prompt_encoder_mask_decoder, {
+            input_points: model_inputs.input_points,
+            input_labels: model_inputs.input_labels,
+            image_embeddings: model_inputs.image_embeddings,
+            image_positional_embeddings: model_inputs.image_positional_embeddings,
+        });
+    }
+
+    /**
+     * Runs the model with the provided inputs
+     * @param {Object} model_inputs Model inputs
+     * @returns {Promise<SamImageSegmentationOutput>} Object containing segmentation outputs
+     */
     async _call(model_inputs) {
         return new SamImageSegmentationOutput(await super._call(model_inputs));
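Because `forward` fills in `input_labels` with ones when they are missing, omitting them is equivalent to labelling every prompt point as foreground. A sketch of the explicit form, reusing `model` and `inputs` from the class example above (`Tensor` is exported by the package):

```js
import { Tensor } from '@xenova/transformers';

// input_points has dims [batch_size, point_batch_size, nb_points, 2];
// the labels tensor drops the trailing coordinate dimension.
const shape = inputs.input_points.dims.slice(0, -1);
const input_labels = new Tensor(
    'int64',
    new BigInt64Array(shape.reduce((a, b) => a * b, 1)).fill(1n), // 1 = foreground
    shape,
);

// Equivalent to `await model(inputs)` when no labels are given
const outputs = await model({ ...inputs, input_labels });
```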
@@ -5049,7 +5175,6 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
 
     ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],
 
-    ['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly
 ]);
 
 const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -5290,7 +5415,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
-    [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
     [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
@@ -5329,7 +5454,9 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
 * let model = await AutoModel.from_pretrained('bert-base-uncased');
 */
 export class AutoModel extends PretrainedMixin {
-    static MODEL_CLASS_MAPPINGS = [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_MAPPING_NAMES_DECODER_ONLY];
+    /** @type {Map<string, Object>[]} */
+    // @ts-ignore
+    static MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map(x => x[0]);
     static BASE_IF_FAIL = true;
 }
 
@@ -5493,7 +5620,7 @@ export class AutoModelForZeroShotObjectDetection extends PretrainedMixin {
 
 
 /**
- * Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
+ * Helper class which is used to instantiate pretrained mask generation models with the `from_pretrained` function.
 * The chosen model class is determined by the type specified in the model config.
 *
 * @example
@@ -1122,24 +1122,23 @@ export class YolosFeatureExtractor extends ImageFeatureExtractor {
 * @property {Tensor} pixel_values
 * @property {HeightWidth[]} original_sizes
 * @property {HeightWidth[]} reshaped_input_sizes
- * @property {Tensor} input_points
+ * @property {Tensor} [input_points]
+ * @property {Tensor} [input_labels]
 */
 
 export class SamImageProcessor extends ImageFeatureExtractor {
-    /**
-     * @param {RawImage[]} images The image(s) to extract features from.
-     * @param {*} input_points A 3D or 4D array, representing the input points provided by the user.
-     *  - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1.
-     *  - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`.
-     * @returns {Promise<SamImageProcessorResult>}
-     */
-    async _call(images, input_points) {
-        let {
-            pixel_values,
-            original_sizes,
-            reshaped_input_sizes,
-        } = await super._call(images);
+    /**
+     *
+     * @param {any} input_points
+     * @param {HeightWidth[]} original_sizes
+     * @param {HeightWidth[]} reshaped_input_sizes
+     * @returns {Tensor}
+     */
+    reshape_input_points(input_points, original_sizes, reshaped_input_sizes) {
 
         // Make deep copy to avoid altering user's input
         input_points = structuredClone(input_points);
         let shape = calculateDimensions(input_points);
 
         // TODO: add support for 2D input_points
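For intuition, `reshape_input_points` maps user points from original-image coordinates into the processor's reshaped input size (1024×1024 for SAM). With the 3×3 test image used in the unit tests at the bottom of this commit, the per-axis scale is 1024/3, reusing that test's `processor` and `image`:

```js
// [1, 2] -> [1 * 1024 / 3, 2 * 1024 / 3] ≈ [341.3333, 682.6667]
const { input_points } = await processor(image, [[[1, 2]]]);
console.log(input_points.tolist()); // [[[[341.3333, 682.6667]]]]
```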
@@ -1170,26 +1169,68 @@ export class SamImageProcessor extends ImageFeatureExtractor {
             }
         }
 
-        let input_points_tensor = new Tensor(
-            'int64',
-            BigInt64Array.from(input_points.flat(Infinity)
-                .map(x => BigInt(Math.round(x)))),
+        return new Tensor(
+            'float32',
+            Float32Array.from(input_points.flat(Infinity)),
             shape
         )
-
-        // TODO: allowed to be floats?
-        // let input_points_tensor = new Tensor(
-        //     'float32',
-        //     Float32Array.from(input_points.flat(Infinity)),
-        //     shape
-        // )
     }
 
-        return {
-            pixel_values,
-            original_sizes: original_sizes,
-            reshaped_input_sizes: reshaped_input_sizes,
-            input_points: input_points_tensor
-        }
+    /**
+     *
+     * @param {any} input_labels
+     * @param {Tensor} input_points
+     * @returns {Tensor}
+     */
+    add_input_labels(input_labels, input_points) {
+        let shape = calculateDimensions(input_labels);
+        if (shape.length === 2) {
+            // Correct user's input
+            shape = [1, ...shape];
+            input_labels = [input_labels];
+        } else if (shape.length !== 3) {
+            throw Error("The input_points must be a 4D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.")
+        }
+
+        if (shape.some((x, i) => x !== input_points.dims[i])) {
+            throw Error(`The first ${shape.length} dimensions of 'input_points' and 'input_labels' must be the same.`)
+        }
+        return new Tensor(
+            'int64',
+            input_labels.flat(Infinity).map(BigInt),
+            shape,
+        )
+    }
+
+    /**
+     * @param {any[]} images The URL(s) of the image(s) to extract features from.
+     * @param {any} [input_points] A 3D or 4D array, representing the input points provided by the user.
+     *  - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1.
+     *  - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`.
+     * @param {any} [input_labels] A 2D or 3D array, representing the input labels for the points, used by the prompt encoder to encode the prompt.
+     *  - 2D: `[point_batch_size, nb_points_per_image]`. In this case, `batch_size` is assumed to be 1.
+     *  - 3D: `[batch_size, point_batch_size, nb_points_per_image]`.
+     * @returns {Promise<SamImageProcessorResult>}
+     */
+    async _call(images, input_points = null, input_labels = null) {
+        // TODO allow user to use preprocessed images
+        /** @type {SamImageProcessorResult} */
+        const processed = await super._call(images);
+
+        if (input_points) {
+            processed.input_points = this.reshape_input_points(
+                input_points, processed.original_sizes, processed.reshaped_input_sizes
+            );
+        }
+
+        if (input_labels) {
+            if (!processed.input_points) {
+                throw Error("`input_points` must be provided if `input_labels` are provided.")
+            }
+            processed.input_labels = this.add_input_labels(input_labels, processed.input_points);
+        }
+
+        return processed;
+    }
 
     /**
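With `input_points` now optional and `input_labels` supported, the processor accepts three shapes of request; the arguments below mirror the updated unit tests (again reusing their `processor` and `image`):

```js
// 1. Image only: pixel_values and size metadata, no prompt tensors
const a = await processor(image);

// 2. Image + points (3D input, so batch_size is assumed to be 1)
const b = await processor(image, [[[1, 2]]]);

// 3. Image + points + per-point labels (1 = object of interest, 0 = not)
const c = await processor(image, [[[1, 2], [2, 1]]], [[1, 0]]);
```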
@@ -1212,22 +1253,22 @@ export class SamImageProcessor extends ImageFeatureExtractor {
     } = {}) {
         // masks: [1, 1, 3, 256, 256]
 
-        let output_masks = [];
+        const output_masks = [];
 
         pad_size = pad_size ?? this.pad_size;
 
-        let target_image_size = [pad_size.height, pad_size.width];
+        const target_image_size = [pad_size.height, pad_size.width];
 
         for (let i = 0; i < original_sizes.length; ++i) {
-            let original_size = original_sizes[i];
-            let reshaped_input_size = reshaped_input_sizes[i];
+            const original_size = original_sizes[i];
+            const reshaped_input_size = reshaped_input_sizes[i];
 
-            let mask = masks[i]; // [b, c, h, w]
+            const mask = masks[i]; // [b, c, h, w]
 
-            // TODO: improve
-            let interpolated_masks = [];
+            const interpolated_masks = [];
             for (let j = 0; j < mask.dims[0]; ++j) {
-                let m = mask[j]; // 3d tensor
+                const m = mask[j]; // 3d tensor
 
                 // Upscale mask to padded size
                 let interpolated_mask = interpolate(m, target_image_size, 'bilinear', false);
@@ -1236,28 +1277,29 @@ export class SamImageProcessor extends ImageFeatureExtractor {
                 interpolated_mask = interpolated_mask.slice(null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]);
 
                 // Downscale mask
-                interpolated_mask = interpolate(mask, original_size, 'bilinear', false);
+                interpolated_mask = interpolate(interpolated_mask, original_size, 'bilinear', false);
 
                 if (binarize) {
+                    const binarizedMaskData = new Uint8Array(interpolated_mask.data.length);
+                    for (let i = 0; i < interpolated_mask.data.length; ++i) {
+                        if (interpolated_mask.data[i] > mask_threshold) {
+                            binarizedMaskData[i] = 1;
+                        }
+                    }
                     interpolated_mask = new Tensor(
                         'bool',
-                        Array.from(interpolated_mask.data).map(x => x > mask_threshold),
+                        binarizedMaskData,
                         interpolated_mask.dims
                     )
                 }
 
-                // add back batch dim for concat
-                interpolated_mask.dims = [1, ...interpolated_mask.dims];
-
                 interpolated_masks.push(interpolated_mask);
             }
 
-            let concatenated = cat(interpolated_masks);
-            output_masks.push(concatenated);
+            output_masks.push(stack(interpolated_masks));
         }
 
         return output_masks;
 
     }
 }
 
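The `interpolate` fix above corrects which tensor is downscaled, and `stack` (which introduces the batch dimension itself) replaces the manual dim-prepend plus `cat`. The call site is unchanged; per the JSDoc example earlier in this commit:

```js
const masks = await processor.post_process_masks(
    outputs.pred_masks,           // [batch_size, point_batch_size, 3, 256, 256]
    inputs.original_sizes,
    inputs.reshaped_input_sizes,
);
// One boolean Tensor per input image, resized back to the original resolution,
// e.g. dims [ 1, 3, 1764, 2646 ] for the example image.
```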
@@ -1732,12 +1774,10 @@ export class Processor extends Callable {
 
 export class SamProcessor extends Processor {
     /**
-     * @param {*} images
-     * @param {*} input_points
-     * @returns {Promise<any>}
+     * @borrows SamImageProcessor#_call as _call
      */
-    async _call(images, input_points) {
-        return await this.feature_extractor(images, input_points);
+    async _call(...args) {
+        return await this.feature_extractor(...args);
     }
 
     /**
@@ -1747,6 +1787,13 @@ export class SamProcessor extends Processor {
         // @ts-ignore
         return this.feature_extractor.post_process_masks(...args);
     }
+    /**
+     * @borrows SamImageProcessor#reshape_input_points as reshape_input_points
+     */
+    reshape_input_points(...args) {
+        // @ts-ignore
+        return this.feature_extractor.reshape_input_points(...args);
+    }
 }
 
 /**
@@ -109,8 +109,8 @@ export function exists(x) {
 /**
 * Calculates the dimensions of a nested array.
 *
- * @param {Array} arr The nested array to calculate dimensions for.
- * @returns {Array} An array containing the dimensions of the input array.
+ * @param {any[]} arr The nested array to calculate dimensions for.
+ * @returns {number[]} An array containing the dimensions of the input array.
 */
 export function calculateDimensions(arr) {
     const dimensions = [];
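A quick sketch of the behaviour the tightened types describe. This is a behaviourally-equivalent standalone version (the library's implementation in `src/utils/core.js` may differ in detail): the function collects the array length at each nesting level by following first elements.

```js
// Standalone sketch of calculateDimensions
function calculateDimensions(arr) {
    const dimensions = [];
    let current = arr;
    while (Array.isArray(current)) {
        dimensions.push(current.length); // record this level's length
        current = current[0];            // descend into the first element
    }
    return dimensions;
}

console.log(calculateDimensions([[[450, 600]]]));     // [1, 1, 2]
console.log(calculateDimensions([[[1, 2], [2, 1]]])); // [1, 2, 2]
```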
@@ -169,6 +169,10 @@ export class RawImage {
     * @param {import('./tensor.js').Tensor} tensor
     */
    static fromTensor(tensor, channel_format = 'CHW') {
+        if (tensor.dims.length !== 3) {
+            throw new Error(`Tensor should have 3 dimensions, but has ${tensor.dims.length} dimensions.`);
+        }
+
        if (channel_format === 'CHW') {
            tensor = tensor.transpose(1, 2, 0);
        } else if (channel_format === 'HWC') {
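A small sketch of the new guard (the tensor values are hypothetical; `RawImage` and `Tensor` are both exported by the package):

```js
import { RawImage, Tensor } from '@xenova/transformers';

// A 2x2 white RGB image in CHW layout: dims [channels, height, width]
const chw = new Tensor('uint8', new Uint8Array(3 * 2 * 2).fill(255), [3, 2, 2]);
const ok = RawImage.fromTensor(chw); // 3 dims: accepted, transposed to HWC internally

// A batched tensor now fails fast instead of silently misreading its dims:
const batched = new Tensor('uint8', new Uint8Array(12), [1, 3, 2, 2]);
// RawImage.fromTensor(batched);
// -> Error: Tensor should have 3 dimensions, but has 4 dimensions.
```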
@@ -89,11 +89,33 @@ describe('Processors', () => {
     it(MODELS.sam, async () => {
         const processor = await AutoProcessor.from_pretrained(m(MODELS.sam))
 
-        { // Basic test
+        { // without input points
             const image = await load_image(TEST_IMAGES.pattern_3x3);
-            const { pixel_values } = await processor(image, [[[0, 0]]]);
+            const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);
             compare(pixel_values.dims, [1, 3, 1024, 1024]);
             compare(avg(pixel_values.data), -0.4505715670146813);
+
+            compare(original_sizes, [[3, 3]]);
+            compare(reshaped_input_sizes, [[1024, 1024]]);
         }
+
+        { // with input points
+            const image = await load_image(TEST_IMAGES.pattern_3x3);
+            const { original_sizes, reshaped_input_sizes, input_points } = await processor(image, [[[1, 2]]]);
+
+            compare(original_sizes, [[3, 3]]);
+            compare(reshaped_input_sizes, [[1024, 1024]]);
+            compare(input_points.tolist(), [[[[341.3333, 682.6667]]]]);
+        }
+
+        { // multiple points with labels
+            const image = await load_image(TEST_IMAGES.pattern_3x3);
+            const { original_sizes, reshaped_input_sizes, input_points, input_labels } = await processor(image, [[[1, 2], [2, 1]]], [[1, 0]]);
+
+            compare(original_sizes, [[3, 3]]);
+            compare(reshaped_input_sizes, [[1024, 1024]]);
+            compare(input_points.tolist(), [[[[341.3333, 682.6667], [682.6667, 341.3333]]]]);
+            compare(input_labels.tolist(), [[[1n, 0n]]]);
+        }
     }, MAX_TEST_EXECUTION_TIME);
 