Use interpolation method specified in processor's config

This commit is contained in:
Joshua Lochner 2023-05-02 06:49:54 +02:00
parent 8de18bc1d5
commit 4a282bf632
3 changed files with 102 additions and 68 deletions

View File

@ -37,6 +37,16 @@ if (BROWSER_ENV) {
} }
// Defined here: https://github.com/python-pillow/Pillow/blob/a405e8406b83f8bfb8916e93971edc7407b8b1ff/src/libImaging/Imaging.h#L262-L268
const RESAMPLING_MAPPING = {
0: 'nearest',
1: 'lanczos',
2: 'bilinear',
3: 'bicubic',
4: 'box',
5: 'hamming',
}
export class CustomImage { export class CustomImage {
/** /**
@ -202,14 +212,19 @@ export class CustomImage {
* @param {number} width - The width of the new image. * @param {number} width - The width of the new image.
* @param {number} height - The height of the new image. * @param {number} height - The height of the new image.
* @param {object} options - Additional options for resizing. * @param {object} options - Additional options for resizing.
* @param {string} [options.resample] - The resampling method to use. Can be one of `nearest`, `bilinear`, `bicubic`. * @param {0|1|2|3|4|5|string} [options.resample] - The resampling method to use.
* @returns {Promise<CustomImage>} - `this` to support chaining. * @returns {Promise<CustomImage>} - `this` to support chaining.
*/ */
async resize(width, height, { async resize(width, height, {
// TODO: Use `resample` resample = 2,
resample = 'bilinear',
} = {}) { } = {}) {
// Ensure resample method is a string
let resampleMethod = RESAMPLING_MAPPING[resample] ?? resample;
if (BROWSER_ENV) { if (BROWSER_ENV) {
// TODO use `resample` in browser environment
// Store number of channels before resizing // Store number of channels before resizing
let numChannels = this.channels; let numChannels = this.channels;
@ -236,12 +251,40 @@ export class CustomImage {
height: this.height, height: this.height,
channels: this.channels channels: this.channels
} }
}).resize({
// https://github.com/lovell/sharp/blob/main/docs/api-resize.md
width, height,
fit: 'fill',
kernel: 'cubic'
}); });
switch (resampleMethod) {
case 'box':
case 'hamming':
if (resampleMethod === 'box' || resampleMethod === 'hamming') {
console.warn(`Resampling method ${resampleMethod} is not yet supported. Using bilinear instead.`);
resampleMethod = 'bilinear';
}
case 'nearest':
case 'bilinear':
case 'bicubic':
// Perform resizing using affine transform.
// This matches how the python Pillow library does it.
img = img.affine([width / this.width, 0, 0, height / this.height], {
interpolator: resampleMethod
});
break;
case 'lanczos':
// https://github.com/python-pillow/Pillow/discussions/5519
// https://github.com/lovell/sharp/blob/main/docs/api-resize.md
img = img.resize({
width, height,
fit: 'fill',
kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3
});
break;
default:
throw new Error(`Resampling method ${resampleMethod} is not supported.`);
}
return await loadImageFunction(img); return await loadImageFunction(img);
} }

View File

@ -43,16 +43,6 @@ class FeatureExtractor extends Callable {
*/ */
class ImageFeatureExtractor extends FeatureExtractor { class ImageFeatureExtractor extends FeatureExtractor {
// Defined here: https://github.com/python-pillow/Pillow/blob/a405e8406b83f8bfb8916e93971edc7407b8b1ff/src/libImaging/Imaging.h#L262-L268
RESAMPLING_MAPPING = {
0: 'nearest',
1: 'lanczos',
2: 'bilinear',
3: 'bicubic',
4: 'box',
5: 'hamming',
}
/** /**
* Constructs a new ViTFeatureExtractor instance. * Constructs a new ViTFeatureExtractor instance.
* *

View File

@ -594,19 +594,19 @@ describe('Pipelines', () => {
]); ]);
expect(output2).toEqual([ expect(output2).toEqual([
{ "generated_text": "a herd of giraffes and zebras grazing in a field" }, { "generated_text": "a herd of giraffes and zebras grazing in a field" },
{ "generated_text": "a herd of giraffes and zebras grazing in a grassy field" } { "generated_text": "a herd of giraffes and zebras in a grassy field" }
]); ]);
expect(output3).toEqual([ expect(output3).toEqual([
[{ "generated_text": "two men are playing soccer on a field" }], [{ "generated_text": "two men are kicking a soccer ball in a soccer game" }],
[{ "generated_text": "a plane is parked on the tarmac at an airport" }] [{ "generated_text": "a plane on the tarmac with a passenger bus" }]
]); ]);
expect(output4).toEqual([ expect(output4).toEqual([
[ [
{ "generated_text": "a man kicking a soccer ball on a field" }, { "generated_text": "two men are kicking a soccer ball on a field" },
{ "generated_text": "a man kicking a soccer ball in a soccer game" } { "generated_text": "two men are kicking a soccer ball in a soccer game" }
], [ ], [
{ "generated_text": "a large jetliner sitting on top of an airport tarmac" }, { "generated_text": "a plane on a tarmac with a group of buses" },
{ "generated_text": "a large jetliner sitting on top of a tarmac" } { "generated_text": "a plane on a tarmac with a group of people on the ground" }
] ]
]); ]);
}, MAX_TEST_EXECUTION_TIME); }, MAX_TEST_EXECUTION_TIME);
@ -636,25 +636,25 @@ describe('Pipelines', () => {
await classifier.dispose(); await classifier.dispose();
expect(output1).toEqual([ expect(output1).toEqual([
{ "label": "tiger, Panthera tigris", "score": 0.8053199648857117 } { "label": "tiger, Panthera tigris", "score": 0.607988178730011 }
]); ]);
expect(output2).toEqual([ expect(output2).toEqual([
{ "label": "tiger, Panthera tigris", "score": 0.8053199648857117 }, { "label": "tiger, Panthera tigris", "score": 0.607988178730011 },
{ "label": "tiger cat", "score": 0.1903550773859024 } { "label": "tiger cat", "score": 0.3877776563167572 }
]); ]);
expect(output3).toEqual([ expect(output3).toEqual([
{ "label": "palace", "score": 0.9973987340927124 }, { "label": "palace", "score": 0.9986862540245056 },
{ "label": "teapot", "score": 0.9834653735160828 } { "label": "teapot", "score": 0.987880527973175 }
]); ]);
expect(output4).toEqual([ expect(output4).toEqual([
[ [
{ "label": "palace", "score": 0.9973987340927124 }, { "label": "palace", "score": 0.9986862540245056 },
{ "label": "castle", "score": 0.0007704473682679236 } { "label": "castle", "score": 0.00037879671435803175 }
], ],
[ [
{ "label": "teapot", "score": 0.9834653735160828 }, { "label": "teapot", "score": 0.987880527973175 },
{ "label": "coffeepot", "score": 0.009658231399953365 } { "label": "coffeepot", "score": 0.006591461598873138 }
] ]
]); ]);
}, MAX_TEST_EXECUTION_TIME); }, MAX_TEST_EXECUTION_TIME);
@ -683,11 +683,11 @@ describe('Pipelines', () => {
await segmenter.dispose(); await segmenter.dispose();
expect(outputs).toEqual([ expect(outputs).toEqual([
{ score: 0.9967542886734009, label: 'cat', mask: 58920 }, { score: 0.9916538596153259, label: 'cat', mask: 58998 },
{ score: 0.9985942840576172, label: 'remote', mask: 4239 }, { score: 0.9987397789955139, label: 'remote', mask: 4164 },
{ score: 0.9994107484817505, label: 'remote', mask: 2271 }, { score: 0.9994599223136902, label: 'remote', mask: 2275 },
{ score: 0.9628002643585205, label: 'couch', mask: 171963 }, { score: 0.9730215072631836, label: 'couch', mask: 176980 },
{ score: 0.9995519518852234, label: 'cat', mask: 52358 } { score: 0.9993911385536194, label: 'cat', mask: 52670 }
]); ]);
}, MAX_TEST_EXECUTION_TIME); }, MAX_TEST_EXECUTION_TIME);
}); });
@ -713,24 +713,24 @@ describe('Pipelines', () => {
await classifier.dispose(); await classifier.dispose();
expect(output1).toEqual([ expect(output1).toEqual([
{ "score": 0.9859885573387146, "label": "football" }, { "score": 0.992206871509552, "label": "football" },
{ "score": 0.0027391533367335796, "label": "airport" }, { "score": 0.0013248942559584975, "label": "airport" },
{ "score": 0.011272300966084003, "label": "animals" } { "score": 0.006468251813203096, "label": "animals" }
]); ]);
expect(output2).toEqual([ expect(output2).toEqual([
[ [
{ "score": 0.9864809513092041, "label": "football" }, { "score": 0.9919875860214233, "label": "football" },
{ "score": 0.002001031069085002, "label": "airport" }, { "score": 0.0012227334082126617, "label": "airport" },
{ "score": 0.011517995037138462, "label": "animals" } { "score": 0.006789708975702524, "label": "animals" }
], [ ], [
{ "score": 0.0006775481160730124, "label": "football" }, { "score": 0.0003043194592464715, "label": "football" },
{ "score": 0.9980292320251465, "label": "airport" }, { "score": 0.998708188533783, "label": "airport" },
{ "score": 0.0012932150857523084, "label": "animals" } { "score": 0.0009874969255179167, "label": "animals" }
], [ ], [
{ "score": 0.016665885224938393, "label": "football" }, { "score": 0.015163016505539417, "label": "football" },
{ "score": 0.018305325880646706, "label": "airport" }, { "score": 0.016037866473197937, "label": "airport" },
{ "score": 0.9650287628173828, "label": "animals" } { "score": 0.9687991142272949, "label": "animals" }
] ]
]); ]);
}, MAX_TEST_EXECUTION_TIME); }, MAX_TEST_EXECUTION_TIME);
@ -759,29 +759,30 @@ describe('Pipelines', () => {
expect(output1).toEqual({ expect(output1).toEqual({
"boxes": [ "boxes": [
[353.46766233444214, 246.3719218969345, 390.4700446128845, 317.3739963769913], [352.8210112452507, 247.36732184886932, 390.5271676182747, 318.09066116809845],
[13.694103956222534, 147.21754789352417, 209.85854387283325, 256.78999185562134], [111.15852802991867, 235.34255504608154, 224.96717244386673, 325.21119117736816],
[114.16613191366196, 235.96070230007172, 225.00602155923843, 324.78122770786285], [13.524770736694336, 146.81672930717468, 207.97560095787048, 278.6452639102936],
[191.83536261320114, 229.7595316171646, 314.76392537355423, 304.3607658147812], [187.396682202816, 227.97491312026978, 313.05202156305313, 300.26460886001587],
[365.9121733903885, 95.36925673484802, 526.8269795179367, 314.0325701236725] [201.60082161426544, 230.86223602294922, 312.1393972635269, 306.5505266189575],
[365.85242718458176, 95.3144109249115, 526.5485098958015, 313.17670941352844]
], ],
"classes": [24, 25, 24, 24, 25], "classes": [24, 24, 25, 24, 24, 25],
"scores": [0.9990849494934082, 0.9275544285774231, 0.9987317323684692, 0.9977309703826904, 0.9986600875854492], "scores": [0.9989480376243591, 0.9990893006324768, 0.9690554738044739, 0.9274907112121582, 0.9714975953102112, 0.9989491105079651],
"labels": ["zebra", "giraffe", "zebra", "zebra", "giraffe"] "labels": ["zebra", "zebra", "giraffe", "zebra", "zebra", "giraffe"]
}); });
expect(output2).toEqual([{ expect(output2).toEqual([{
"boxes": [ "boxes": [
[0.7251099348068237, 0.3263818323612213, 0.9811199903488159, 0.9992953240871429], [0.7231650948524475, 0.32641804218292236, 0.981127917766571, 0.9918863773345947],
[0.7574957609176636, 0.5250559747219086, 0.8291805982589722, 0.64716437458992], [0.7529061436653137, 0.52558633685112, 0.8229959607124329, 0.6482008993625641],
[0.5086115896701813, 0.5169457495212555, 0.5510353744029999, 0.5451572835445404], [0.5080368518829346, 0.5156279355287552, 0.5494132041931152, 0.5434067696332932],
[0.3393232971429825, 0.5198696702718735, 0.35605834424495697, 0.6132815033197403], [0.33636586368083954, 0.5217841267585754, 0.3535611182451248, 0.6151944994926453],
[0.42163804173469543, 0.4444425106048584, 0.5550684630870819, 0.521735429763794], [0.42090220749378204, 0.4482414871454239, 0.5515891760587692, 0.5207531303167343],
[0.20032313466072083, 0.4107527732849121, 0.4521646201610565, 0.5207630395889282], [0.1988394856452942, 0.41224047541618347, 0.45213085412979126, 0.5206181704998016],
[0.5072861611843109, 0.5172932296991348, 0.5501181185245514, 0.5445606559514999], [0.5063001662492752, 0.5170856416225433, 0.5478668659925461, 0.54373899102211],
[0.5722146779298782, 0.452553853392601, 0.7078577727079391, 0.6239819675683975] [0.5734506398439407, 0.4508090913295746, 0.7049560993909836, 0.6252130568027496],
], ],
"classes": [6, 1, 8, 1, 5, 5, 3, 6], "classes": [6, 1, 8, 1, 5, 5, 3, 6],
"scores": [0.9965458512306213, 0.9970073699951172, 0.9373661279678345, 0.9982954859733582, 0.9951448440551758, 0.9982654452323914, 0.963291347026825, 0.9981732368469238], "scores": [0.9970788359642029, 0.996989905834198, 0.9505048990249634, 0.9984546899795532, 0.9942372441291809, 0.9989550709724426, 0.938920259475708, 0.9992448091506958],
"labels": ["bus", "person", "truck", "person", "airplane", "airplane", "car", "bus"] "labels": ["bus", "person", "truck", "person", "airplane", "airplane", "car", "bus"]
}]); }]);
}, MAX_TEST_EXECUTION_TIME); }, MAX_TEST_EXECUTION_TIME);