Only run image models with required inputs

This commit is contained in:
Joshua Lochner 2023-05-16 17:32:31 +02:00
parent 5411a4599d
commit b57c2a9027
1 changed files with 13 additions and 10 deletions

View File

@ -870,7 +870,7 @@ export class ImageToTextPipeline extends Pipeline {
images = await prepareImages(images);
let pixel_values = (await this.processor(images)).pixel_values;
let { pixel_values } = await this.processor(images);
let toReturn = [];
for (let batch of pixel_values) {
@ -918,8 +918,8 @@ export class ImageClassificationPipeline extends Pipeline {
let isBatched = Array.isArray(images);
images = await prepareImages(images);
let inputs = await this.processor(images);
let output = await this.model(inputs);
let { pixel_values } = await this.processor(images);
let output = await this.model({ pixel_values });
let id2label = this.model.config.id2label;
let toReturn = [];
@ -997,8 +997,8 @@ export class ImageSegmentationPipeline extends Pipeline {
images = await prepareImages(images);
let imageSizes = images.map(x => [x.height, x.width]);
let inputs = await this.processor(images);
let output = await this.model(inputs);
let { pixel_values, pixel_mask} = await this.processor(images);
let output = await this.model({ pixel_values, pixel_mask});
let fn = null;
if (subtask !== null) {
@ -1104,10 +1104,13 @@ export class ZeroShotImageClassificationPipeline extends Pipeline {
truncation: true
});
// Compare each image with each candidate label
let image_inputs = await this.processor(images);
let output = await this.model({ ...text_inputs, ...image_inputs });
// Run processor
let { pixel_values } = await this.processor(images);
// Run model with both text and pixel inputs
let output = await this.model({ ...text_inputs, pixel_values });
// Compare each image with each candidate label
let toReturn = [];
for (let batch of output.logits_per_image) {
// Compute softmax per image
@ -1162,8 +1165,8 @@ export class ObjectDetectionPipeline extends Pipeline {
let imageSizes = percentage ? null : images.map(x => [x.height, x.width]);
let inputs = await this.processor(images);
let output = await this.model(inputs);
let { pixel_values, pixel_mask } = await this.processor(images);
let output = await this.model({ pixel_values, pixel_mask });
// @ts-ignore
let processed = this.processor.feature_extractor.post_process_object_detection(output, threshold, imageSizes);