From cdcbfc125ce1dcaaf62a6bb3050dc774ef08b74c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 10 Jan 2024 17:47:21 +0200 Subject: [PATCH] Add support for Segment Anything Model (#510) * Update SamModel * Make `AutoModel.from_pretrained` work with SamModel * Add listed support for SAM (Segment Anything Model) * Update types of `calculateDimensions` * Throw error if reading image from tensor with dims.length != 3 * Make SamProcessor input points optional * Fix type errors * `let` -> `const` * `cat` -> `stack` * Expose `reshape_input_points` in `SamProcessor` * Add `input_labels` input parameter for SAM * Add `input_labels` to sam processor * Update SAM unit tests * Remove TODOs * Update JSDoc --- README.md | 1 + docs/snippets/6_supported-models.snippet | 1 + scripts/supported_models.py | 19 ++- src/models.js | 143 ++++++++++++++++++++-- src/processors.js | 145 +++++++++++++++-------- src/utils/core.js | 4 +- src/utils/image.js | 4 + tests/processors.test.js | 26 +++- 8 files changed, 277 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index b2994d6..00b4d39 100644 --- a/README.md +++ b/README.md @@ -328,6 +328,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. +1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer. 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index a5fe13a..d2054b9 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -63,6 +63,7 @@ 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. +1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer. 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. diff --git a/scripts/supported_models.py b/scripts/supported_models.py index b5d0912..0415c8c 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -745,11 +745,20 @@ SUPPORTED_MODELS = { 'distilroberta-base', ], }, - # 'sam': [ - # 'facebook/sam-vit-base', - # 'facebook/sam-vit-large', - # 'facebook/sam-vit-huge', - # ], + 'sam': { + # Mask generation + 'mask-generation': [ + # SAM + 'facebook/sam-vit-base', + 'facebook/sam-vit-large', + 'facebook/sam-vit-huge', + 'wanglab/medsam-vit-base', + + # SlimSAM + 'nielsr/slimsam-50-uniform', + 'nielsr/slimsam-77-uniform', + ], + }, 'segformer': { # Image segmentation 'image-segmentation': [ diff --git a/src/models.js b/src/models.js index 802f057..ae56ea9 100644 --- a/src/models.js +++ b/src/models.js @@ -95,6 +95,7 @@ const MODEL_TYPES = { Seq2Seq: 2, Vision2Seq: 3, DecoderOnly: 4, + MaskGeneration: 5, } ////////////////////////////////////////////////// @@ -771,6 +772,13 @@ export class PreTrainedModel extends Callable { getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); + } else if (modelType === MODEL_TYPES.MaskGeneration) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'vision_encoder', options), + constructSession(pretrained_model_name_or_path, 'prompt_encoder_mask_decoder', options), + ]); + } else if (modelType === MODEL_TYPES.EncoderDecoder) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), @@ -4242,12 +4250,130 @@ export class YolosObjectDetectionOutput extends ModelOutput { ////////////////////////////////////////////////// export class SamPreTrainedModel extends PreTrainedModel { } + +/** + * Segment Anything Model (SAM) for generating segmentation masks, given an input image + * and optional 2D location and bounding boxes. + * + * **Example:** Perform mask generation w/ `Xenova/sam-vit-base`. + * ```javascript + * import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers'; + * + * const model = await SamModel.from_pretrained('Xenova/sam-vit-base'); + * const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base'); + * + * const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png'; + * const raw_image = await RawImage.read(img_url); + * const input_points = [[[450, 600]]] // 2D localization of a window + * + * const inputs = await processor(raw_image, input_points); + * const outputs = await model(inputs); + * + * const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes); + * // [ + * // Tensor { + * // dims: [ 1, 3, 1764, 2646 ], + * // type: 'bool', + * // data: Uint8Array(14002632) [ ... ], + * // size: 14002632 + * // } + * // ] + * const scores = outputs.iou_scores; + * // Tensor { + * // dims: [ 1, 1, 3 ], + * // type: 'float32', + * // data: Float32Array(3) [ + * // 0.8892380595207214, + * // 0.9311248064041138, + * // 0.983696699142456 + * // ], + * // size: 3 + * // } + * ``` + */ export class SamModel extends SamPreTrainedModel { /** - * @param {Object} model_inputs - * @param {Tensor} model_inputs.pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`. - * @param {Tensor} model_inputs.input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`. This is used by the prompt encoder to encode the prompt. - * @todo Add support for `input_labels`, `input_boxes`, `input_masks`, and `image_embeddings`. + * Creates a new instance of the `SamModel` class. + * @param {Object} config The configuration object specifying the hyperparameters and other model settings. + * @param {Object} vision_encoder The ONNX session containing the vision encoder model. + * @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model. + */ + constructor(config, vision_encoder, prompt_encoder_mask_decoder) { + super(config, vision_encoder); + this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder; + } + + /** + * Compute image embeddings and positional image embeddings, given the pixel values of an image. + * @param {Object} model_inputs Object containing the model inputs. + * @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `SamProcessor`. + * @returns {Promise<{ image_embeddings: Tensor, image_positional_embeddings: Tensor }>} The image embeddings and positional image embeddings. + */ + async get_image_embeddings({ pixel_values }) { + // in: + // - pixel_values: tensor.float32[batch_size,3,1024,1024] + // + // out: + // - image_embeddings: tensor.float32[batch_size,256,64,64] + // - image_positional_embeddings: tensor.float32[batch_size,256,64,64] + return await encoderForward(this, { pixel_values }) + } + + /** + * @typedef {Object} SamModelInputs Object containing the model inputs. + * @property {Tensor} pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`. + * These can be obtained using a `SamProcessor`. + * @property {Tensor} input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`. + * This is used by the prompt encoder to encode the prompt. + * @property {Tensor} [input_labels] Input labels for the points, as a Tensor of shape `(batch_size, point_batch_size, num_points)`. + * This is used by the prompt encoder to encode the prompt. There are 4 types of labels: + * - `1`: the point is a point that contains the object of interest + * - `0`: the point is a point that does not contain the object of interest + * - `-1`: the point corresponds to the background + * - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + * @property {Tensor} [image_embeddings] Image embeddings used by the mask decoder. + * @property {Tensor} [image_positional_embeddings] Image positional embeddings used by the mask decoder. + */ + + /** + * @param {SamModelInputs} model_inputs Object containing the model inputs. + * @returns {Promise} The output of the model. + */ + async forward(model_inputs) { + if (!model_inputs.image_embeddings || !model_inputs.image_positional_embeddings) { + // Compute the image embeddings if they are missing + model_inputs = { + ...model_inputs, + ...(await this.get_image_embeddings(model_inputs)) + } + } + + if (!model_inputs.input_labels) { + // Set default input labels if they are missing + const shape = model_inputs.input_points.dims.slice(0, -1); + const numElements = shape.reduce((a, b) => a * b, 1); + model_inputs.input_labels = new Tensor( + 'int64', + new BigInt64Array(numElements).fill(1n), + shape + ); + } + + // Returns: + // - iou_scores: tensor.float32[batch_size,point_batch_size,3] + // - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256] + return await sessionRun(this.prompt_encoder_mask_decoder, { + input_points: model_inputs.input_points, + input_labels: model_inputs.input_labels, + image_embeddings: model_inputs.image_embeddings, + image_positional_embeddings: model_inputs.image_positional_embeddings, + }); + } + + /** + * Runs the model with the provided inputs + * @param {Object} model_inputs Model inputs + * @returns {Promise} Object containing segmentation outputs */ async _call(model_inputs) { return new SamImageSegmentationOutput(await super._call(model_inputs)); @@ -5049,7 +5175,6 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]], - ['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly ]); const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ @@ -5290,7 +5415,7 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], - [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], + [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration], [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], @@ -5329,7 +5454,9 @@ for (const [name, model, type] of CUSTOM_MAPPING) { * let model = await AutoModel.from_pretrained('bert-base-uncased'); */ export class AutoModel extends PretrainedMixin { - static MODEL_CLASS_MAPPINGS = [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_MAPPING_NAMES_DECODER_ONLY]; + /** @type {Map[]} */ + // @ts-ignore + static MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map(x => x[0]); static BASE_IF_FAIL = true; } @@ -5493,7 +5620,7 @@ export class AutoModelForZeroShotObjectDetection extends PretrainedMixin { /** - * Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function. + * Helper class which is used to instantiate pretrained mask generation models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. * * @example diff --git a/src/processors.js b/src/processors.js index 499b952..6165e27 100644 --- a/src/processors.js +++ b/src/processors.js @@ -1122,24 +1122,23 @@ export class YolosFeatureExtractor extends ImageFeatureExtractor { * @property {Tensor} pixel_values * @property {HeightWidth[]} original_sizes * @property {HeightWidth[]} reshaped_input_sizes - * @property {Tensor} input_points + * @property {Tensor} [input_points] + * @property {Tensor} [input_labels] */ export class SamImageProcessor extends ImageFeatureExtractor { - /** - * @param {RawImage[]} images The image(s) to extract features from. - * @param {*} input_points A 3D or 4D array, representing the input points provided by the user. - * - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1. - * - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`. - * @returns {Promise} - */ - async _call(images, input_points) { - let { - pixel_values, - original_sizes, - reshaped_input_sizes, - } = await super._call(images); + /** + * + * @param {any} input_points + * @param {HeightWidth[]} original_sizes + * @param {HeightWidth[]} reshaped_input_sizes + * @returns {Tensor} + */ + reshape_input_points(input_points, original_sizes, reshaped_input_sizes) { + + // Make deep copy to avoid altering user's input + input_points = structuredClone(input_points); let shape = calculateDimensions(input_points); // TODO: add support for 2D input_points @@ -1170,26 +1169,68 @@ export class SamImageProcessor extends ImageFeatureExtractor { } } - let input_points_tensor = new Tensor( - 'int64', - BigInt64Array.from(input_points.flat(Infinity) - .map(x => BigInt(Math.round(x)))), + return new Tensor( + 'float32', + Float32Array.from(input_points.flat(Infinity)), shape ) - // TODO: allowed to be floats? - // let input_points_tensor = new Tensor( - // 'float32', - // Float32Array.from(input_points.flat(Infinity)), - // shape - // ) + } - return { - pixel_values, - original_sizes: original_sizes, - reshaped_input_sizes: reshaped_input_sizes, - input_points: input_points_tensor + /** + * + * @param {any} input_labels + * @param {Tensor} input_points + * @returns {Tensor} + */ + add_input_labels(input_labels, input_points) { + let shape = calculateDimensions(input_labels); + if (shape.length === 2) { + // Correct user's input + shape = [1, ...shape]; + input_labels = [input_labels]; + } else if (shape.length !== 3) { + throw Error("The input_points must be a 4D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.") } + + if (shape.some((x, i) => x !== input_points.dims[i])) { + throw Error(`The first ${shape.length} dimensions of 'input_points' and 'input_labels' must be the same.`) + } + return new Tensor( + 'int64', + input_labels.flat(Infinity).map(BigInt), + shape, + ) + } + /** + * @param {any[]} images The URL(s) of the image(s) to extract features from. + * @param {any} [input_points] A 3D or 4D array, representing the input points provided by the user. + * - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1. + * - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`. + * @param {any} [input_labels] A 2D or 3D array, representing the input labels for the points, used by the prompt encoder to encode the prompt. + * - 2D: `[point_batch_size, nb_points_per_image]`. In this case, `batch_size` is assumed to be 1. + * - 3D: `[batch_size, point_batch_size, nb_points_per_image]`. + * @returns {Promise} + */ + async _call(images, input_points = null, input_labels = null) { + // TODO allow user to use preprocessed images + /** @type {SamImageProcessorResult} */ + const processed = await super._call(images); + + if (input_points) { + processed.input_points = this.reshape_input_points( + input_points, processed.original_sizes, processed.reshaped_input_sizes + ); + } + + if (input_labels) { + if (!processed.input_points) { + throw Error("`input_points` must be provided if `input_labels` are provided.") + } + processed.input_labels = this.add_input_labels(input_labels, processed.input_points); + } + + return processed; } /** @@ -1212,22 +1253,22 @@ export class SamImageProcessor extends ImageFeatureExtractor { } = {}) { // masks: [1, 1, 3, 256, 256] - let output_masks = []; + const output_masks = []; pad_size = pad_size ?? this.pad_size; - let target_image_size = [pad_size.height, pad_size.width]; + const target_image_size = [pad_size.height, pad_size.width]; for (let i = 0; i < original_sizes.length; ++i) { - let original_size = original_sizes[i]; - let reshaped_input_size = reshaped_input_sizes[i]; + const original_size = original_sizes[i]; + const reshaped_input_size = reshaped_input_sizes[i]; - let mask = masks[i]; // [b, c, h, w] + const mask = masks[i]; // [b, c, h, w] // TODO: improve - let interpolated_masks = []; + const interpolated_masks = []; for (let j = 0; j < mask.dims[0]; ++j) { - let m = mask[j]; // 3d tensor + const m = mask[j]; // 3d tensor // Upscale mask to padded size let interpolated_mask = interpolate(m, target_image_size, 'bilinear', false); @@ -1236,28 +1277,29 @@ export class SamImageProcessor extends ImageFeatureExtractor { interpolated_mask = interpolated_mask.slice(null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]); // Downscale mask - interpolated_mask = interpolate(mask, original_size, 'bilinear', false); + interpolated_mask = interpolate(interpolated_mask, original_size, 'bilinear', false); if (binarize) { + const binarizedMaskData = new Uint8Array(interpolated_mask.data.length); + for (let i = 0; i < interpolated_mask.data.length; ++i) { + if (interpolated_mask.data[i] > mask_threshold) { + binarizedMaskData[i] = 1; + } + } interpolated_mask = new Tensor( 'bool', - Array.from(interpolated_mask.data).map(x => x > mask_threshold), + binarizedMaskData, interpolated_mask.dims ) } - // add back batch dim for concat - interpolated_mask.dims = [1, ...interpolated_mask.dims]; - interpolated_masks.push(interpolated_mask); } - let concatenated = cat(interpolated_masks); - output_masks.push(concatenated); + output_masks.push(stack(interpolated_masks)); } return output_masks; - } } @@ -1732,12 +1774,10 @@ export class Processor extends Callable { export class SamProcessor extends Processor { /** - * @param {*} images - * @param {*} input_points - * @returns {Promise} + * @borrows SamImageProcessor#_call as _call */ - async _call(images, input_points) { - return await this.feature_extractor(images, input_points); + async _call(...args) { + return await this.feature_extractor(...args); } /** @@ -1747,6 +1787,13 @@ export class SamProcessor extends Processor { // @ts-ignore return this.feature_extractor.post_process_masks(...args); } + /** + * @borrows SamImageProcessor#reshape_input_points as reshape_input_points + */ + reshape_input_points(...args) { + // @ts-ignore + return this.feature_extractor.reshape_input_points(...args); + } } /** diff --git a/src/utils/core.js b/src/utils/core.js index d05ade9..4ed0f15 100644 --- a/src/utils/core.js +++ b/src/utils/core.js @@ -109,8 +109,8 @@ export function exists(x) { /** * Calculates the dimensions of a nested array. * - * @param {Array} arr The nested array to calculate dimensions for. - * @returns {Array} An array containing the dimensions of the input array. + * @param {any[]} arr The nested array to calculate dimensions for. + * @returns {number[]} An array containing the dimensions of the input array. */ export function calculateDimensions(arr) { const dimensions = []; diff --git a/src/utils/image.js b/src/utils/image.js index 369670a..2d12cb8 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -169,6 +169,10 @@ export class RawImage { * @param {import('./tensor.js').Tensor} tensor */ static fromTensor(tensor, channel_format = 'CHW') { + if (tensor.dims.length !== 3) { + throw new Error(`Tensor should have 3 dimensions, but has ${tensor.dims.length} dimensions.`); + } + if (channel_format === 'CHW') { tensor = tensor.transpose(1, 2, 0); } else if (channel_format === 'HWC') { diff --git a/tests/processors.test.js b/tests/processors.test.js index 9b25113..9bbc7f8 100644 --- a/tests/processors.test.js +++ b/tests/processors.test.js @@ -89,11 +89,33 @@ describe('Processors', () => { it(MODELS.sam, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.sam)) - { // Basic test + { // without input points const image = await load_image(TEST_IMAGES.pattern_3x3); - const { pixel_values } = await processor(image, [[[0, 0]]]); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); compare(pixel_values.dims, [1, 3, 1024, 1024]); compare(avg(pixel_values.data), -0.4505715670146813); + + compare(original_sizes, [[3, 3]]); + compare(reshaped_input_sizes, [[1024, 1024]]); + } + + { // with input points + const image = await load_image(TEST_IMAGES.pattern_3x3); + const { original_sizes, reshaped_input_sizes, input_points } = await processor(image, [[[1, 2]]]); + + compare(original_sizes, [[3, 3]]); + compare(reshaped_input_sizes, [[1024, 1024]]); + compare(input_points.tolist(), [[[[341.3333, 682.6667]]]]); + } + + { // multiple points with labels + const image = await load_image(TEST_IMAGES.pattern_3x3); + const { original_sizes, reshaped_input_sizes, input_points, input_labels } = await processor(image, [[[1, 2], [2, 1]]], [[1, 0]]); + + compare(original_sizes, [[3, 3]]); + compare(reshaped_input_sizes, [[1024, 1024]]); + compare(input_points.tolist(), [[[[341.3333, 682.6667], [682.6667, 341.3333]]]]); + compare(input_labels.tolist(), [[[1n, 0n]]]); } }, MAX_TEST_EXECUTION_TIME);