Add support for Segment Anything Model (#510)

* Update SamModel

* Make `AutoModel.from_pretrained` work with SamModel (see the sketch below)

* List SAM (Segment Anything Model) as a supported model

* Update types of `calculateDimensions`

* Throw error if reading image from tensor with dims.length != 3

* Make SamProcessor input points optional

* Fix type errors

* `let` -> `const`

* `cat` -> `stack`

* Expose `reshape_input_points` in `SamProcessor`

* Add `input_labels` input parameter for SAM

* Add `input_labels` to sam processor

* Update SAM unit tests

* Remove TODOs

* Update JSDoc
Joshua Lochner 2024-01-10 17:47:21 +02:00 committed by GitHub
parent 4d1d4d3346
commit cdcbfc125c
8 changed files with 277 additions and 66 deletions
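
Taken together, the changes below wire SAM into the existing auto classes and processors. A minimal end-to-end sketch of what this enables (the `Xenova/sam-vit-base` checkpoint, image URL and example point are taken from the `SamModel` JSDoc added further down; treat the snippet as illustrative, not as the commit's own documentation):

```javascript
import { AutoModel, AutoProcessor, RawImage } from '@xenova/transformers';

// Auto classes now resolve SAM checkpoints (see the mapping changes below).
const model = await AutoModel.from_pretrained('Xenova/sam-vit-base');         // -> SamModel
const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base'); // -> SamProcessor

const image = await RawImage.read('https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png');

// Input points are now optional, and per-point labels are new
// (1 = point contains the object of interest, 0 = it does not).
const inputs = await processor(image, [[[450, 600]]], [[1]]);
const outputs = await model(inputs);

const masks = await processor.post_process_masks(
    outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes,
);
```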

View File

@@ -328,6 +328,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.

View File

@@ -63,6 +63,7 @@
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.

View File

@@ -745,11 +745,20 @@ SUPPORTED_MODELS = {
'distilroberta-base',
],
},
# 'sam': [
# 'facebook/sam-vit-base',
# 'facebook/sam-vit-large',
# 'facebook/sam-vit-huge',
# ],
'sam': {
# Mask generation
'mask-generation': [
# SAM
'facebook/sam-vit-base',
'facebook/sam-vit-large',
'facebook/sam-vit-huge',
'wanglab/medsam-vit-base',
# SlimSAM
'nielsr/slimsam-50-uniform',
'nielsr/slimsam-77-uniform',
],
},
'segformer': {
# Image segmentation
'image-segmentation': [

View File

@@ -95,6 +95,7 @@ const MODEL_TYPES = {
Seq2Seq: 2,
Vision2Seq: 3,
DecoderOnly: 4,
MaskGeneration: 5,
}
//////////////////////////////////////////////////
@@ -771,6 +772,13 @@ export class PreTrainedModel extends Callable {
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);
} else if (modelType === MODEL_TYPES.MaskGeneration) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, 'vision_encoder', options),
constructSession(pretrained_model_name_or_path, 'prompt_encoder_mask_decoder', options),
]);
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
@@ -4242,12 +4250,130 @@ export class YolosObjectDetectionOutput extends ModelOutput {
//////////////////////////////////////////////////
export class SamPreTrainedModel extends PreTrainedModel { }
/**
* Segment Anything Model (SAM) for generating segmentation masks, given an input image
* and optional 2D location and bounding boxes.
*
* **Example:** Perform mask generation w/ `Xenova/sam-vit-base`.
* ```javascript
* import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';
*
* const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
* const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');
*
* const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png';
* const raw_image = await RawImage.read(img_url);
* const input_points = [[[450, 600]]] // 2D localization of a window
*
* const inputs = await processor(raw_image, input_points);
* const outputs = await model(inputs);
*
* const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes);
* // [
* // Tensor {
* // dims: [ 1, 3, 1764, 2646 ],
* // type: 'bool',
* // data: Uint8Array(14002632) [ ... ],
* // size: 14002632
* // }
* // ]
* const scores = outputs.iou_scores;
* // Tensor {
* // dims: [ 1, 1, 3 ],
* // type: 'float32',
* // data: Float32Array(3) [
* // 0.8892380595207214,
* // 0.9311248064041138,
* // 0.983696699142456
* // ],
* // size: 3
* // }
* ```
*/
export class SamModel extends SamPreTrainedModel {
/**
* @param {Object} model_inputs
* @param {Tensor} model_inputs.pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`.
* @param {Tensor} model_inputs.input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`. This is used by the prompt encoder to encode the prompt.
* @todo Add support for `input_labels`, `input_boxes`, `input_masks`, and `image_embeddings`.
* Creates a new instance of the `SamModel` class.
* @param {Object} config The configuration object specifying the hyperparameters and other model settings.
* @param {Object} vision_encoder The ONNX session containing the vision encoder model.
* @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model.
*/
constructor(config, vision_encoder, prompt_encoder_mask_decoder) {
super(config, vision_encoder);
this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder;
}
/**
* Compute image embeddings and positional image embeddings, given the pixel values of an image.
* @param {Object} model_inputs Object containing the model inputs.
* @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `SamProcessor`.
* @returns {Promise<{ image_embeddings: Tensor, image_positional_embeddings: Tensor }>} The image embeddings and positional image embeddings.
*/
async get_image_embeddings({ pixel_values }) {
// in:
// - pixel_values: tensor.float32[batch_size,3,1024,1024]
//
// out:
// - image_embeddings: tensor.float32[batch_size,256,64,64]
// - image_positional_embeddings: tensor.float32[batch_size,256,64,64]
return await encoderForward(this, { pixel_values })
}
/**
* @typedef {Object} SamModelInputs Object containing the model inputs.
* @property {Tensor} pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`.
* These can be obtained using a `SamProcessor`.
* @property {Tensor} input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`.
* This is used by the prompt encoder to encode the prompt.
* @property {Tensor} [input_labels] Input labels for the points, as a Tensor of shape `(batch_size, point_batch_size, num_points)`.
* This is used by the prompt encoder to encode the prompt. There are 4 types of labels:
* - `1`: the point is a point that contains the object of interest
* - `0`: the point is a point that does not contain the object of interest
* - `-1`: the point corresponds to the background
* - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
* @property {Tensor} [image_embeddings] Image embeddings used by the mask decoder.
* @property {Tensor} [image_positional_embeddings] Image positional embeddings used by the mask decoder.
*/
/**
* @param {SamModelInputs} model_inputs Object containing the model inputs.
* @returns {Promise<Object>} The output of the model.
*/
async forward(model_inputs) {
if (!model_inputs.image_embeddings || !model_inputs.image_positional_embeddings) {
// Compute the image embeddings if they are missing
model_inputs = {
...model_inputs,
...(await this.get_image_embeddings(model_inputs))
}
}
if (!model_inputs.input_labels) {
// Set default input labels if they are missing
const shape = model_inputs.input_points.dims.slice(0, -1);
const numElements = shape.reduce((a, b) => a * b, 1);
model_inputs.input_labels = new Tensor(
'int64',
new BigInt64Array(numElements).fill(1n),
shape
);
}
// Returns:
// - iou_scores: tensor.float32[batch_size,point_batch_size,3]
// - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256]
return await sessionRun(this.prompt_encoder_mask_decoder, {
input_points: model_inputs.input_points,
input_labels: model_inputs.input_labels,
image_embeddings: model_inputs.image_embeddings,
image_positional_embeddings: model_inputs.image_positional_embeddings,
});
}
/**
* Runs the model with the provided inputs
* @param {Object} model_inputs Model inputs
* @returns {Promise<SamImageSegmentationOutput>} Object containing segmentation outputs
*/
async _call(model_inputs) {
return new SamImageSegmentationOutput(await super._call(model_inputs));
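
Since the vision encoder and the prompt encoder/mask decoder now run in separate ONNX sessions, the expensive image embeddings can be computed once with `get_image_embeddings` and reused for many prompts. A sketch using the same checkpoint and image as the JSDoc example above (the second query point is hypothetical):

```javascript
import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';

const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');

const image = await RawImage.read('https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png');

// Preprocess without points (now allowed) and run the vision encoder once.
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);
const embeddings = await model.get_image_embeddings({ pixel_values });

// Prompt the mask decoder repeatedly without re-encoding the image.
for (const points of [[[[450, 600]]], [[[1000, 500]]]]) { // second point is hypothetical
    const input_points = processor.reshape_input_points(points, original_sizes, reshaped_input_sizes);
    const outputs = await model({ ...embeddings, input_points });
    const masks = await processor.post_process_masks(outputs.pred_masks, original_sizes, reshaped_input_sizes);
}
```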
@@ -5049,7 +5175,6 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],
['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly
]);
const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -5290,7 +5415,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
[MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
@@ -5329,7 +5454,9 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
* let model = await AutoModel.from_pretrained('bert-base-uncased');
*/
export class AutoModel extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_MAPPING_NAMES_DECODER_ONLY];
/** @type {Map<string, Object>[]} */
// @ts-ignore
static MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map(x => x[0]);
static BASE_IF_FAIL = true;
}
@@ -5493,7 +5620,7 @@ export class AutoModelForZeroShotObjectDetection extends PretrainedMixin {
/**
* Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
* Helper class which is used to instantiate pretrained mask generation models with the `from_pretrained` function.
* The chosen model class is determined by the type specified in the model config.
*
* @example

View File

@@ -1122,24 +1122,23 @@ export class YolosFeatureExtractor extends ImageFeatureExtractor {
* @property {Tensor} pixel_values
* @property {HeightWidth[]} original_sizes
* @property {HeightWidth[]} reshaped_input_sizes
* @property {Tensor} input_points
* @property {Tensor} [input_points]
* @property {Tensor} [input_labels]
*/
export class SamImageProcessor extends ImageFeatureExtractor {
/**
* @param {RawImage[]} images The image(s) to extract features from.
* @param {*} input_points A 3D or 4D array, representing the input points provided by the user.
* - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1.
* - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`.
* @returns {Promise<SamImageProcessorResult>}
*/
async _call(images, input_points) {
let {
pixel_values,
original_sizes,
reshaped_input_sizes,
} = await super._call(images);
/**
*
* @param {any} input_points
* @param {HeightWidth[]} original_sizes
* @param {HeightWidth[]} reshaped_input_sizes
* @returns {Tensor}
*/
reshape_input_points(input_points, original_sizes, reshaped_input_sizes) {
// Make deep copy to avoid altering user's input
input_points = structuredClone(input_points);
let shape = calculateDimensions(input_points);
// TODO: add support for 2D input_points
@@ -1170,26 +1169,68 @@ export class SamImageProcessor extends ImageFeatureExtractor {
}
}
let input_points_tensor = new Tensor(
'int64',
BigInt64Array.from(input_points.flat(Infinity)
.map(x => BigInt(Math.round(x)))),
return new Tensor(
'float32',
Float32Array.from(input_points.flat(Infinity)),
shape
)
// TODO: allowed to be floats?
// let input_points_tensor = new Tensor(
// 'float32',
// Float32Array.from(input_points.flat(Infinity)),
// shape
// )
}
return {
pixel_values,
original_sizes: original_sizes,
reshaped_input_sizes: reshaped_input_sizes,
input_points: input_points_tensor
/**
*
* @param {any} input_labels
* @param {Tensor} input_points
* @returns {Tensor}
*/
add_input_labels(input_labels, input_points) {
let shape = calculateDimensions(input_labels);
if (shape.length === 2) {
// Correct user's input
shape = [1, ...shape];
input_labels = [input_labels];
} else if (shape.length !== 3) {
throw Error("The input_points must be a 4D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.")
}
if (shape.some((x, i) => x !== input_points.dims[i])) {
throw Error(`The first ${shape.length} dimensions of 'input_points' and 'input_labels' must be the same.`)
}
return new Tensor(
'int64',
input_labels.flat(Infinity).map(BigInt),
shape,
)
}
/**
* @param {any[]} images The URL(s) of the image(s) to extract features from.
* @param {any} [input_points] A 3D or 4D array, representing the input points provided by the user.
* - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1.
* - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`.
* @param {any} [input_labels] A 2D or 3D array, representing the input labels for the points, used by the prompt encoder to encode the prompt.
* - 2D: `[point_batch_size, nb_points_per_image]`. In this case, `batch_size` is assumed to be 1.
* - 3D: `[batch_size, point_batch_size, nb_points_per_image]`.
* @returns {Promise<SamImageProcessorResult>}
*/
async _call(images, input_points = null, input_labels = null) {
// TODO allow user to use preprocessed images
/** @type {SamImageProcessorResult} */
const processed = await super._call(images);
if (input_points) {
processed.input_points = this.reshape_input_points(
input_points, processed.original_sizes, processed.reshaped_input_sizes
);
}
if (input_labels) {
if (!processed.input_points) {
throw Error("`input_points` must be provided if `input_labels` are provided.")
}
processed.input_labels = this.add_input_labels(input_labels, processed.input_points);
}
return processed;
}
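
The new validation in `_call` makes the optional arguments explicit: points may be omitted, but labels cannot be passed without points. A short sketch (checkpoint and image reused from the JSDoc example in the model changes above):

```javascript
import { AutoProcessor, RawImage } from '@xenova/transformers';

const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');
const image = await RawImage.read('https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png');

await processor(image);                        // OK: no prompt, just pixel_values + sizes
await processor(image, [[[450, 600]]]);        // OK: points only
await processor(image, [[[450, 600]]], [[1]]); // OK: points + labels
// await processor(image, null, [[1]]);        // throws: `input_points` must be provided if `input_labels` are provided.
```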
/**
@@ -1212,22 +1253,22 @@ export class SamImageProcessor extends ImageFeatureExtractor {
} = {}) {
// masks: [1, 1, 3, 256, 256]
let output_masks = [];
const output_masks = [];
pad_size = pad_size ?? this.pad_size;
let target_image_size = [pad_size.height, pad_size.width];
const target_image_size = [pad_size.height, pad_size.width];
for (let i = 0; i < original_sizes.length; ++i) {
let original_size = original_sizes[i];
let reshaped_input_size = reshaped_input_sizes[i];
const original_size = original_sizes[i];
const reshaped_input_size = reshaped_input_sizes[i];
let mask = masks[i]; // [b, c, h, w]
const mask = masks[i]; // [b, c, h, w]
// TODO: improve
let interpolated_masks = [];
const interpolated_masks = [];
for (let j = 0; j < mask.dims[0]; ++j) {
let m = mask[j]; // 3d tensor
const m = mask[j]; // 3d tensor
// Upscale mask to padded size
let interpolated_mask = interpolate(m, target_image_size, 'bilinear', false);
@@ -1236,28 +1277,29 @@ export class SamImageProcessor extends ImageFeatureExtractor {
interpolated_mask = interpolated_mask.slice(null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]);
// Downscale mask
interpolated_mask = interpolate(mask, original_size, 'bilinear', false);
interpolated_mask = interpolate(interpolated_mask, original_size, 'bilinear', false);
if (binarize) {
const binarizedMaskData = new Uint8Array(interpolated_mask.data.length);
for (let i = 0; i < interpolated_mask.data.length; ++i) {
if (interpolated_mask.data[i] > mask_threshold) {
binarizedMaskData[i] = 1;
}
}
interpolated_mask = new Tensor(
'bool',
Array.from(interpolated_mask.data).map(x => x > mask_threshold),
binarizedMaskData,
interpolated_mask.dims
)
}
// add back batch dim for concat
interpolated_mask.dims = [1, ...interpolated_mask.dims];
interpolated_masks.push(interpolated_mask);
}
let concatenated = cat(interpolated_masks);
output_masks.push(concatenated);
output_masks.push(stack(interpolated_masks));
}
return output_masks;
}
}
@@ -1732,12 +1774,10 @@ export class Processor extends Callable {
export class SamProcessor extends Processor {
/**
* @param {*} images
* @param {*} input_points
* @returns {Promise<any>}
* @borrows SamImageProcessor#_call as _call
*/
async _call(images, input_points) {
return await this.feature_extractor(images, input_points);
async _call(...args) {
return await this.feature_extractor(...args);
}
/**
@@ -1747,6 +1787,13 @@ export class SamProcessor extends Processor {
// @ts-ignore
return this.feature_extractor.post_process_masks(...args);
}
/**
* @borrows SamImageProcessor#reshape_input_points as reshape_input_points
*/
reshape_input_points(...args) {
// @ts-ignore
return this.feature_extractor.reshape_input_points(...args);
}
}
/**

View File

@@ -109,8 +109,8 @@ export function exists(x) {
/**
* Calculates the dimensions of a nested array.
*
* @param {Array} arr The nested array to calculate dimensions for.
* @returns {Array} An array containing the dimensions of the input array.
* @param {any[]} arr The nested array to calculate dimensions for.
* @returns {number[]} An array containing the dimensions of the input array.
*/
export function calculateDimensions(arr) {
const dimensions = [];
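
The body of `calculateDimensions` is cut off by the hunk; for context, a sketch of the behaviour the updated types describe (assumed implementation, inferred from its use in `reshape_input_points` and `add_input_labels` above):

```javascript
// Assumed sketch: walk the nested array, recording each level's length.
function calculateDimensions(arr) {
    const dimensions = [];
    let current = arr;
    while (Array.isArray(current)) {
        dimensions.push(current.length);
        current = current[0];
    }
    return dimensions;
}

// One image, two (x, y) points:
calculateDimensions([[[450, 600], [1000, 500]]]); // -> [1, 2, 2]
```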

View File

@@ -169,6 +169,10 @@ export class RawImage {
* @param {import('./tensor.js').Tensor} tensor
*/
static fromTensor(tensor, channel_format = 'CHW') {
if (tensor.dims.length !== 3) {
throw new Error(`Tensor should have 3 dimensions, but has ${tensor.dims.length} dimensions.`);
}
if (channel_format === 'CHW') {
tensor = tensor.transpose(1, 2, 0);
} else if (channel_format === 'HWC') {

View File

@@ -89,11 +89,33 @@ describe('Processors', () => {
it(MODELS.sam, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.sam))
{ // Basic test
{ // without input points
const image = await load_image(TEST_IMAGES.pattern_3x3);
const { pixel_values } = await processor(image, [[[0, 0]]]);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);
compare(pixel_values.dims, [1, 3, 1024, 1024]);
compare(avg(pixel_values.data), -0.4505715670146813);
compare(original_sizes, [[3, 3]]);
compare(reshaped_input_sizes, [[1024, 1024]]);
}
{ // with input points
const image = await load_image(TEST_IMAGES.pattern_3x3);
const { original_sizes, reshaped_input_sizes, input_points } = await processor(image, [[[1, 2]]]);
compare(original_sizes, [[3, 3]]);
compare(reshaped_input_sizes, [[1024, 1024]]);
compare(input_points.tolist(), [[[[341.3333, 682.6667]]]]);
}
{ // multiple points with labels
const image = await load_image(TEST_IMAGES.pattern_3x3);
const { original_sizes, reshaped_input_sizes, input_points, input_labels } = await processor(image, [[[1, 2], [2, 1]]], [[1, 0]]);
compare(original_sizes, [[3, 3]]);
compare(reshaped_input_sizes, [[1024, 1024]]);
compare(input_points.tolist(), [[[[341.3333, 682.6667], [682.6667, 341.3333]]]]);
compare(input_labels.tolist(), [[[1n, 0n]]]);
}
}, MAX_TEST_EXECUTION_TIME);
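
The expected coordinates in these tests follow directly from SAM preprocessing: the 3x3 test image is resized so that its longest side becomes 1024, so `reshape_input_points` scales every coordinate by 1024/3. A quick arithmetic check (no library code involved):

```javascript
const scale = 1024 / 3;      // ≈ 341.3333
[1 * scale, 2 * scale];      // ≈ [341.3333, 682.6667]  -> expected for point [1, 2]
[2 * scale, 1 * scale];      // ≈ [682.6667, 341.3333]  -> expected for point [2, 1]
```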