Replace CommonJS imports/exports with ES6
This commit is contained in:
parent
1b9fe0f27c
commit
421aa331cd
|
@ -1,5 +1,8 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"checkJs": true
|
||||
"checkJs": true,
|
||||
"target": "esnext",
|
||||
"module": "esnext",
|
||||
"moduleResolution": "nodenext",
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
"version": "1.4.2",
|
||||
"description": "Run 🤗 Transformers in your browser! We currently support BERT, ALBERT, DistilBERT, MobileBERT, SqueezeBERT, T5, T5v1.1, FLAN-T5, mT5, BART, MarianMT, GPT2, GPT Neo, CodeGen, Whisper, CLIP, Vision Transformer, VisionEncoderDecoder, and DETR models, for a variety of tasks including: masked language modelling, text classification, token classification, zero-shot classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, image segmentation, and object detection.",
|
||||
"main": "./src/transformers.js",
|
||||
"type": "module",
|
||||
"directories": {
|
||||
"test": "tests"
|
||||
},
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
let ONNX;
|
||||
export let ONNX;
|
||||
|
||||
// TODO support more execution providers (e.g., webgpu)
|
||||
const executionProviders = ['wasm'];
|
||||
export const executionProviders = ['wasm'];
|
||||
|
||||
if (typeof process !== 'undefined') {
|
||||
// Running in a node-like environment.
|
||||
// Try to import onnxruntime-node, using onnxruntime-web as a fallback
|
||||
try {
|
||||
ONNX = require('onnxruntime-node');
|
||||
ONNX = (await import('onnxruntime-node')).default;
|
||||
} catch (err) {
|
||||
console.warn(
|
||||
"Node.js environment detected, but `onnxruntime-node` was not found. " +
|
||||
|
@ -20,7 +20,7 @@ if (typeof process !== 'undefined') {
|
|||
// @ts-ignore
|
||||
global.self = global;
|
||||
|
||||
ONNX = require('onnxruntime-web');
|
||||
ONNX = (await import('onnxruntime-web')).default;
|
||||
|
||||
// Disable spawning worker threads for testing.
|
||||
// This is done by setting numThreads to 1
|
||||
|
@ -33,10 +33,6 @@ if (typeof process !== 'undefined') {
|
|||
|
||||
} else {
|
||||
// Running in a browser-environment, so we just import `onnxruntime-web`
|
||||
ONNX = require('onnxruntime-web');
|
||||
ONNX = (await import('onnxruntime-web')).default;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
ONNX,
|
||||
executionProviders,
|
||||
}
|
||||
|
|
22
src/env.js
22
src/env.js
|
@ -1,7 +1,11 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { fileURLToPath } from 'url';
|
||||
|
||||
const { env: onnx_env } = require('./backends/onnx.js').ONNX;
|
||||
import { ONNX } from './backends/onnx.js';
|
||||
const { env: onnx_env } = ONNX;
|
||||
|
||||
const __dirname = path.dirname(path.dirname(fileURLToPath(import.meta.url)));
|
||||
|
||||
// check if various APIs are available (depends on environment)
|
||||
const CACHE_AVAILABLE = typeof self !== 'undefined' && 'caches' in self;
|
||||
|
@ -13,7 +17,7 @@ const RUNNING_LOCALLY = FS_AVAILABLE && PATH_AVAILABLE;
|
|||
// set local model path, based on available APIs
|
||||
const DEFAULT_LOCAL_PATH = '/models/onnx/quantized/';
|
||||
const localURL = RUNNING_LOCALLY
|
||||
? path.join(path.dirname(__dirname), DEFAULT_LOCAL_PATH)
|
||||
? path.join(__dirname, DEFAULT_LOCAL_PATH)
|
||||
: DEFAULT_LOCAL_PATH;
|
||||
|
||||
// First, set path to wasm files. This is needed when running in a web worker.
|
||||
|
@ -21,12 +25,12 @@ const localURL = RUNNING_LOCALLY
|
|||
// We use remote wasm files by default to make it easier for newer users.
|
||||
// In practice, users should probably self-host the necessary .wasm files.
|
||||
onnx_env.wasm.wasmPaths = RUNNING_LOCALLY
|
||||
? path.join(path.dirname(__dirname), '/dist/')
|
||||
? path.join(__dirname, '/dist/')
|
||||
: 'https://cdn.jsdelivr.net/npm/@xenova/transformers/dist/';
|
||||
|
||||
|
||||
// Global variable used to control exection, with suitable defaults
|
||||
const env = {
|
||||
export const env = {
|
||||
// access onnxruntime-web's environment variables
|
||||
onnx: onnx_env,
|
||||
|
||||
|
@ -44,6 +48,9 @@ const env = {
|
|||
|
||||
// Whether to use the file system to load files. By default, it is true available.
|
||||
useFS: FS_AVAILABLE,
|
||||
|
||||
// Directory name of module. Useful for resolving local paths.
|
||||
__dirname,
|
||||
}
|
||||
|
||||
|
||||
|
@ -54,6 +61,3 @@ function isEmpty(obj) {
|
|||
return Object.keys(obj).length === 0;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
env
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
* FFT class provides functionality for performing Fast Fourier Transform on arrays
|
||||
* Code adapted from https://www.npmjs.com/package/fft.js
|
||||
*/
|
||||
class FFT {
|
||||
export default class FFT {
|
||||
/**
|
||||
* @param {number} size - The size of the input array. Must be a power of two and bigger than 1.
|
||||
* @throws {Error} FFT size must be a power of two and bigger than 1.
|
||||
|
@ -494,5 +494,3 @@ class FFT {
|
|||
out[outOff + 7] = T3r;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = FFT
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
const { Tensor } = require("./tensor_utils.js");
|
||||
const {
|
||||
import { Tensor } from './tensor_utils.js';
|
||||
import {
|
||||
Callable,
|
||||
exists,
|
||||
log_softmax
|
||||
} = require("./utils.js");
|
||||
} from './utils.js';
|
||||
|
||||
/**
|
||||
* A class representing a list of logits processors. A logits processor is a function that modifies the logits
|
||||
|
@ -12,7 +12,7 @@ const {
|
|||
*
|
||||
* @extends Callable
|
||||
*/
|
||||
class LogitsProcessorList extends Callable {
|
||||
export class LogitsProcessorList extends Callable {
|
||||
/**
|
||||
* Constructs a new instance of `LogitsProcessorList`.
|
||||
*/
|
||||
|
@ -66,7 +66,7 @@ class LogitsProcessorList extends Callable {
|
|||
* Base class for processing logits.
|
||||
* @extends Callable
|
||||
*/
|
||||
class LogitsProcessor extends Callable {
|
||||
export class LogitsProcessor extends Callable {
|
||||
/**
|
||||
* Apply the processor to the input logits.
|
||||
*
|
||||
|
@ -85,7 +85,7 @@ class LogitsProcessor extends Callable {
|
|||
*
|
||||
* @extends LogitsProcessor
|
||||
*/
|
||||
class ForceTokensLogitsProcessor extends LogitsProcessor {
|
||||
export class ForceTokensLogitsProcessor extends LogitsProcessor {
|
||||
/**
|
||||
* Constructs a new instance of `ForceTokensLogitsProcessor`.
|
||||
*
|
||||
|
@ -117,7 +117,7 @@ class ForceTokensLogitsProcessor extends LogitsProcessor {
|
|||
* A LogitsProcessor that forces a BOS token at the beginning of the generated sequence.
|
||||
* @extends LogitsProcessor
|
||||
*/
|
||||
class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
|
||||
export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
|
||||
/**
|
||||
* Create a ForcedBOSTokenLogitsProcessor.
|
||||
* @param {number} bos_token_id - The ID of the beginning-of-sequence token to be forced.
|
||||
|
@ -146,7 +146,7 @@ class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
|
|||
*
|
||||
* @extends LogitsProcessor
|
||||
*/
|
||||
class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
|
||||
export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
|
||||
/**
|
||||
* Create a ForcedEOSTokenLogitsProcessor.
|
||||
* @param {number} max_length - Max length of the sequence.
|
||||
|
@ -174,7 +174,7 @@ class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
|
|||
* A LogitsProcessor that handles adding timestamps to generated text.
|
||||
* @extends LogitsProcessor
|
||||
*/
|
||||
class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
||||
export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
||||
/**
|
||||
* Constructs a new WhisperTimeStampLogitsProcessor.
|
||||
* @param {object} generate_config - The config object passed to the `generate()` method of a transformer model.
|
||||
|
@ -249,7 +249,7 @@ class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
|||
*
|
||||
* @extends LogitsProcessor
|
||||
*/
|
||||
class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
|
||||
export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
|
||||
/**
|
||||
* Create a NoRepeatNGramLogitsProcessor.
|
||||
* @param {number} no_repeat_ngram_size - The no-repeat-ngram size. All ngrams of this size can only occur once.
|
||||
|
@ -340,7 +340,7 @@ class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
|
|||
*
|
||||
* @extends LogitsProcessor
|
||||
*/
|
||||
class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
|
||||
export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
|
||||
/**
|
||||
* Create a RepetitionPenaltyLogitsProcessor.
|
||||
* @param {number} penalty - The penalty to apply for repeated tokens.
|
||||
|
@ -372,7 +372,7 @@ class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
|
|||
}
|
||||
|
||||
|
||||
class GenerationConfig {
|
||||
export class GenerationConfig {
|
||||
constructor(kwargs = {}) {
|
||||
// Parameters that control the length of the output
|
||||
// TODO: extend the configuration with correct types
|
||||
|
@ -465,15 +465,3 @@ class GenerationConfig {
|
|||
this.generation_kwargs = kwargs.generation_kwargs ?? {};
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
GenerationConfig,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
ForceTokensLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor
|
||||
};
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
|
||||
const fs = require('fs');
|
||||
const { getFile, isString } = require('./utils.js');
|
||||
const { env } = require('./env.js');
|
||||
import fs from 'fs';
|
||||
import { getFile, isString } from './utils.js';
|
||||
import { env } from './env.js';
|
||||
|
||||
let CanvasClass;
|
||||
let ImageClass = typeof Image !== 'undefined' ? Image : null; // Only used for type-checking
|
||||
let ImageClass = typeof Image !== 'undefined' ? Image : null;
|
||||
|
||||
let ImageDataClass;
|
||||
let loadImageFunction;
|
||||
|
@ -14,7 +14,7 @@ if (typeof self !== 'undefined') {
|
|||
ImageDataClass = ImageData;
|
||||
|
||||
} else {
|
||||
const { Canvas, loadImage, ImageData, Image } = require('canvas');
|
||||
const { Canvas, loadImage, ImageData, Image } = await import('canvas');
|
||||
CanvasClass = Canvas;
|
||||
loadImageFunction = async (/**@type {Blob}*/ b) => await loadImage(Buffer.from(await b.arrayBuffer()));
|
||||
ImageDataClass = ImageData;
|
||||
|
@ -22,7 +22,7 @@ if (typeof self !== 'undefined') {
|
|||
}
|
||||
|
||||
|
||||
class CustomImage {
|
||||
export class CustomImage {
|
||||
|
||||
/**
|
||||
* Create a new CustomImage object.
|
||||
|
@ -277,7 +277,3 @@ class CustomImage {
|
|||
fs.writeFileSync(path, buffer);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
CustomImage,
|
||||
};
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
/**
|
||||
* @param {TypedArray} input
|
||||
*/
|
||||
function interpolate(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) {
|
||||
export function interpolate(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) {
|
||||
// TODO use mode and align_corners
|
||||
|
||||
// Output image dimensions
|
||||
|
@ -86,7 +86,7 @@ function interpolate(input, [in_channels, in_height, in_width], [out_height, out
|
|||
* @param {number[]} axes
|
||||
* @returns {[T, number[]]} The transposed array and the new shape.
|
||||
*/
|
||||
function transpose_data(array, dims, axes) {
|
||||
export function transpose_data(array, dims, axes) {
|
||||
// Calculate the new shape of the transposed array
|
||||
// and the stride of the original array
|
||||
const shape = new Array(axes.length);
|
||||
|
@ -117,8 +117,3 @@ function transpose_data(array, dims, axes) {
|
|||
|
||||
return [transposedData, shape];
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
interpolate,
|
||||
transpose: transpose_data,
|
||||
}
|
||||
|
|
|
@ -1,17 +1,16 @@
|
|||
const {
|
||||
import {
|
||||
Callable,
|
||||
getModelFile,
|
||||
fetchJSON,
|
||||
dispatchCallback,
|
||||
isIntegralNumber,
|
||||
} = require("./utils.js");
|
||||
} from './utils.js';
|
||||
|
||||
const {
|
||||
import {
|
||||
Sampler,
|
||||
} = require("./samplers.js");
|
||||
} from './samplers.js';
|
||||
|
||||
|
||||
const {
|
||||
import {
|
||||
LogitsProcessorList,
|
||||
GenerationConfig,
|
||||
ForceTokensLogitsProcessor,
|
||||
|
@ -20,13 +19,14 @@ const {
|
|||
WhisperTimeStampLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor
|
||||
} = require("./generation.js");
|
||||
} from './generation.js';
|
||||
|
||||
const { executionProviders, ONNX } = require('./backends/onnx.js');
|
||||
const {
|
||||
import {
|
||||
Tensor,
|
||||
cat
|
||||
} = require('./tensor_utils');
|
||||
} from './tensor_utils.js';
|
||||
|
||||
import { executionProviders, ONNX } from './backends/onnx.js';
|
||||
const { InferenceSession, Tensor: ONNXTensor } = ONNX;
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
|
@ -2172,7 +2172,7 @@ class MarianMTModel extends MarianPreTrainedModel {
|
|||
/**
|
||||
* Helper class to determine model type from config
|
||||
*/
|
||||
class AutoModel {
|
||||
export class AutoModel {
|
||||
// Helper class to determine model type from config
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'bert': BertModel,
|
||||
|
@ -2223,7 +2223,7 @@ class AutoModel {
|
|||
/**
|
||||
* Helper class for loading sequence classification models from pretrained checkpoints
|
||||
*/
|
||||
class AutoModelForSequenceClassification {
|
||||
export class AutoModelForSequenceClassification {
|
||||
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'bert': BertForSequenceClassification,
|
||||
|
@ -2267,7 +2267,7 @@ class AutoModelForSequenceClassification {
|
|||
/**
|
||||
* Helper class for loading token classification models from pretrained checkpoints
|
||||
*/
|
||||
class AutoModelForTokenClassification {
|
||||
export class AutoModelForTokenClassification {
|
||||
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'bert': BertForTokenClassification,
|
||||
|
@ -2306,7 +2306,7 @@ class AutoModelForTokenClassification {
|
|||
/**
|
||||
* Class representing an automatic sequence-to-sequence language model.
|
||||
*/
|
||||
class AutoModelForSeq2SeqLM {
|
||||
export class AutoModelForSeq2SeqLM {
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
't5': T5ForConditionalGeneration,
|
||||
'mt5': MT5ForConditionalGeneration,
|
||||
|
@ -2337,7 +2337,7 @@ class AutoModelForSeq2SeqLM {
|
|||
/**
|
||||
* A class for loading pre-trained models for causal language modeling tasks.
|
||||
*/
|
||||
class AutoModelForCausalLM {
|
||||
export class AutoModelForCausalLM {
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'gpt2': GPT2LMHeadModel,
|
||||
'gpt_neo': GPTNeoForCausalLM,
|
||||
|
@ -2376,7 +2376,7 @@ class AutoModelForCausalLM {
|
|||
/**
|
||||
* A class to automatically select the appropriate model for Masked Language Modeling (MLM) tasks.
|
||||
*/
|
||||
class AutoModelForMaskedLM {
|
||||
export class AutoModelForMaskedLM {
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'bert': BertForMaskedLM,
|
||||
'albert': AlbertForMaskedLM,
|
||||
|
@ -2419,7 +2419,7 @@ class AutoModelForMaskedLM {
|
|||
/**
|
||||
* Automatic model class for question answering tasks.
|
||||
*/
|
||||
class AutoModelForQuestionAnswering {
|
||||
export class AutoModelForQuestionAnswering {
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'bert': BertForQuestionAnswering,
|
||||
'albert': AlbertForQuestionAnswering,
|
||||
|
@ -2460,7 +2460,7 @@ class AutoModelForQuestionAnswering {
|
|||
/**
|
||||
* Class representing an autoencoder-decoder model for vision-to-sequence tasks.
|
||||
*/
|
||||
class AutoModelForVision2Seq {
|
||||
export class AutoModelForVision2Seq {
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'vision-encoder-decoder': VisionEncoderDecoderModel
|
||||
}
|
||||
|
@ -2498,7 +2498,7 @@ class AutoModelForVision2Seq {
|
|||
/**
|
||||
* AutoModelForImageClassification is a class for loading pre-trained image classification models from ONNX format.
|
||||
*/
|
||||
class AutoModelForImageClassification {
|
||||
export class AutoModelForImageClassification {
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'vit': ViTForImageClassification,
|
||||
}
|
||||
|
@ -2537,7 +2537,7 @@ class AutoModelForImageClassification {
|
|||
/**
|
||||
* AutoModelForImageSegmentation is a class for loading pre-trained image classification models from ONNX format.
|
||||
*/
|
||||
class AutoModelForImageSegmentation {
|
||||
export class AutoModelForImageSegmentation {
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'detr': DetrForSegmentation,
|
||||
}
|
||||
|
@ -2573,7 +2573,7 @@ class AutoModelForImageSegmentation {
|
|||
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
class AutoModelForObjectDetection {
|
||||
export class AutoModelForObjectDetection {
|
||||
static MODEL_CLASS_MAPPING = {
|
||||
'detr': DetrForObjectDetection,
|
||||
}
|
||||
|
@ -2664,17 +2664,3 @@ class QuestionAnsweringModelOutput extends ModelOutput {
|
|||
this.end_logits = end_logits;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
AutoModel,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForObjectDetection,
|
||||
AutoModelForImageSegmentation,
|
||||
};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
const {
|
||||
import {
|
||||
Callable,
|
||||
softmax,
|
||||
indexOfMax,
|
||||
|
@ -8,12 +8,12 @@ const {
|
|||
isString,
|
||||
getFile,
|
||||
dot
|
||||
} = require("./utils.js");
|
||||
} from './utils.js';
|
||||
|
||||
const {
|
||||
import {
|
||||
AutoTokenizer
|
||||
} = require("./tokenizers.js");
|
||||
const {
|
||||
} from './tokenizers.js';
|
||||
import {
|
||||
AutoModel,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
|
@ -25,19 +25,15 @@ const {
|
|||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForObjectDetection
|
||||
} = require("./models.js");
|
||||
const {
|
||||
} from './models.js';
|
||||
import {
|
||||
AutoProcessor,
|
||||
Processor
|
||||
} = require("./processors.js");
|
||||
} from './processors.js';
|
||||
|
||||
|
||||
const {
|
||||
env
|
||||
} = require('./env.js');
|
||||
|
||||
const { Tensor, transpose_data } = require("./tensor_utils.js");
|
||||
const { CustomImage } = require("./image_utils.js");
|
||||
import { env } from './env.js';
|
||||
import { Tensor } from './tensor_utils.js';
|
||||
import { CustomImage } from './image_utils.js';
|
||||
|
||||
/**
|
||||
* Prepare images for further tasks.
|
||||
|
@ -1364,7 +1360,7 @@ const TASK_ALIASES = {
|
|||
* @todo fix error below
|
||||
* @throws {Error} If an unsupported pipeline is requested.
|
||||
*/
|
||||
async function pipeline(
|
||||
export async function pipeline(
|
||||
task,
|
||||
model = null,
|
||||
{
|
||||
|
@ -1454,7 +1450,3 @@ function product(...a) {
|
|||
// Adapted from https://stackoverflow.com/a/43053803
|
||||
return a.reduce((a, b) => a.flatMap(d => b.map(e => [d, e])));
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
pipeline
|
||||
};
|
||||
|
|
|
@ -1,20 +1,20 @@
|
|||
|
||||
const {
|
||||
import {
|
||||
Callable,
|
||||
fetchJSON,
|
||||
indexOfMax,
|
||||
softmax,
|
||||
} = require("./utils.js");
|
||||
} from './utils.js';
|
||||
|
||||
import FFT from './fft.js';
|
||||
import { Tensor, transpose, cat, interpolate } from './tensor_utils.js';
|
||||
|
||||
const FFT = require('./fft.js');
|
||||
const { Tensor, transpose, cat, interpolate } = require("./tensor_utils.js");
|
||||
import { CustomImage } from './image_utils.js';
|
||||
|
||||
const { CustomImage } = require('./image_utils.js');
|
||||
/**
|
||||
* Helper class to determine model type from config
|
||||
*/
|
||||
class AutoProcessor {
|
||||
export class AutoProcessor {
|
||||
/**
|
||||
* Returns a new instance of a Processor with a feature extractor
|
||||
* based on the configuration file located at `modelPath`.
|
||||
|
@ -938,7 +938,7 @@ class WhisperFeatureExtractor extends FeatureExtractor {
|
|||
* Represents a Processor that extracts features from an input.
|
||||
* @extends Callable
|
||||
*/
|
||||
class Processor extends Callable {
|
||||
export class Processor extends Callable {
|
||||
/**
|
||||
* Creates a new Processor with the given feature extractor.
|
||||
* @param {function} feature_extractor - The function used to extract features from the input.
|
||||
|
@ -975,9 +975,3 @@ class WhisperProcessor extends Processor {
|
|||
return await this.feature_extractor(audio)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
module.exports = {
|
||||
AutoProcessor,
|
||||
Processor,
|
||||
}
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
const {
|
||||
import {
|
||||
Callable,
|
||||
indexOfMax,
|
||||
softmax,
|
||||
log_softmax,
|
||||
getTopItems
|
||||
} = require("./utils.js");
|
||||
} from './utils.js';
|
||||
|
||||
/**
|
||||
* Sampler is a base class for all sampling methods used for text generation.
|
||||
*/
|
||||
class Sampler extends Callable {
|
||||
export class Sampler extends Callable {
|
||||
/**
|
||||
* Creates a new Sampler object with the specified temperature.
|
||||
* @param {number} temperature - The temperature to use when sampling. Higher values result in more random samples.
|
||||
|
@ -240,10 +240,3 @@ class BeamSearchSampler extends Sampler {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
Sampler,
|
||||
GreedySampler,
|
||||
TopKSampler,
|
||||
BeamSearchSampler
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
const { ONNX } = require('./backends/onnx.js');
|
||||
import { ONNX } from './backends/onnx.js';
|
||||
|
||||
const { interpolate: interpolate_data, transpose: transpose_data } = require('./math_utils.js');
|
||||
import {
|
||||
interpolate as interpolate_data,
|
||||
transpose_data
|
||||
} from './math_utils.js';
|
||||
|
||||
|
||||
/**
|
||||
|
@ -10,7 +13,7 @@ const { interpolate: interpolate_data, transpose: transpose_data } = require('./
|
|||
const ONNXTensor = ONNX.Tensor;
|
||||
|
||||
// TODO: fix error below
|
||||
class Tensor extends ONNXTensor {
|
||||
export class Tensor extends ONNXTensor {
|
||||
/**
|
||||
* Create a new Tensor or copy an existing Tensor.
|
||||
* @param {[string, Array|AnyTypedArray, number[]]|[ONNXTensor]} args
|
||||
|
@ -190,7 +193,7 @@ function reshape(data, dimensions) {
|
|||
* @param {Array} axes - The axes to transpose the tensor along.
|
||||
* @returns {Tensor} The transposed tensor.
|
||||
*/
|
||||
function transpose(tensor, axes) {
|
||||
export function transpose(tensor, axes) {
|
||||
const [transposedData, shape] = transpose_data(tensor.data, tensor.dims, axes);
|
||||
return new Tensor(tensor.type, transposedData, shape);
|
||||
}
|
||||
|
@ -202,7 +205,7 @@ function transpose(tensor, axes) {
|
|||
* @param {any} tensors - The array of tensors to concatenate.
|
||||
* @returns {Tensor} - The concatenated tensor.
|
||||
*/
|
||||
function cat(tensors) {
|
||||
export function cat(tensors) {
|
||||
if (tensors.length === 0) {
|
||||
return tensors[0];
|
||||
}
|
||||
|
@ -241,7 +244,7 @@ function cat(tensors) {
|
|||
* @param {boolean} align_corners - Whether to align corners.
|
||||
* @returns {Tensor} - The interpolated tensor.
|
||||
*/
|
||||
function interpolate(input, [out_height, out_width], mode = 'bilinear', align_corners = false) {
|
||||
export function interpolate(input, [out_height, out_width], mode = 'bilinear', align_corners = false) {
|
||||
|
||||
// Input image dimensions
|
||||
const in_channels = input.dims.at(-3) ?? 1;
|
||||
|
@ -257,11 +260,3 @@ function interpolate(input, [out_height, out_width], mode = 'bilinear', align_co
|
|||
);
|
||||
return new Tensor(input.type, output, [in_channels, out_height, out_width]);
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
Tensor,
|
||||
transpose,
|
||||
cat,
|
||||
interpolate,
|
||||
transpose_data,
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
const {
|
||||
import {
|
||||
Callable,
|
||||
fetchJSON,
|
||||
reverseDictionary,
|
||||
escapeRegExp,
|
||||
isIntegralNumber,
|
||||
min,
|
||||
} = require('./utils.js');
|
||||
} from './utils.js';
|
||||
|
||||
const { Tensor } = require('./tensor_utils.js')
|
||||
import { Tensor } from './tensor_utils.js';
|
||||
|
||||
/**
|
||||
* Abstract base class for tokenizer models.
|
||||
|
@ -1837,7 +1837,7 @@ function bert_prepare_model_inputs(inputs) {
|
|||
* BertTokenizer is a class used to tokenize text for BERT models.
|
||||
* @extends PreTrainedTokenizer
|
||||
*/
|
||||
class BertTokenizer extends PreTrainedTokenizer {
|
||||
export class BertTokenizer extends PreTrainedTokenizer {
|
||||
/**
|
||||
* @see {@link bert_prepare_model_inputs}
|
||||
*/
|
||||
|
@ -1849,7 +1849,7 @@ class BertTokenizer extends PreTrainedTokenizer {
|
|||
* Albert tokenizer
|
||||
* @extends PreTrainedTokenizer
|
||||
*/
|
||||
class AlbertTokenizer extends PreTrainedTokenizer {
|
||||
export class AlbertTokenizer extends PreTrainedTokenizer {
|
||||
/**
|
||||
* @see {@link bert_prepare_model_inputs}
|
||||
*/
|
||||
|
@ -1857,7 +1857,7 @@ class AlbertTokenizer extends PreTrainedTokenizer {
|
|||
return bert_prepare_model_inputs(inputs);
|
||||
}
|
||||
}
|
||||
class MobileBertTokenizer extends PreTrainedTokenizer {
|
||||
export class MobileBertTokenizer extends PreTrainedTokenizer {
|
||||
/**
|
||||
* @see {@link bert_prepare_model_inputs}
|
||||
*/
|
||||
|
@ -1865,7 +1865,7 @@ class MobileBertTokenizer extends PreTrainedTokenizer {
|
|||
return bert_prepare_model_inputs(inputs);
|
||||
}
|
||||
}
|
||||
class SqueezeBertTokenizer extends PreTrainedTokenizer {
|
||||
export class SqueezeBertTokenizer extends PreTrainedTokenizer {
|
||||
/**
|
||||
* @see {@link bert_prepare_model_inputs}
|
||||
*/
|
||||
|
@ -1873,18 +1873,18 @@ class SqueezeBertTokenizer extends PreTrainedTokenizer {
|
|||
return bert_prepare_model_inputs(inputs);
|
||||
}
|
||||
}
|
||||
class DistilBertTokenizer extends PreTrainedTokenizer { }
|
||||
class T5Tokenizer extends PreTrainedTokenizer { }
|
||||
class GPT2Tokenizer extends PreTrainedTokenizer { }
|
||||
class BartTokenizer extends PreTrainedTokenizer { }
|
||||
class RobertaTokenizer extends PreTrainedTokenizer { }
|
||||
export class DistilBertTokenizer extends PreTrainedTokenizer { }
|
||||
export class T5Tokenizer extends PreTrainedTokenizer { }
|
||||
export class GPT2Tokenizer extends PreTrainedTokenizer { }
|
||||
export class BartTokenizer extends PreTrainedTokenizer { }
|
||||
export class RobertaTokenizer extends PreTrainedTokenizer { }
|
||||
|
||||
|
||||
/**
|
||||
* WhisperTokenizer tokenizer
|
||||
* @extends PreTrainedTokenizer
|
||||
*/
|
||||
class WhisperTokenizer extends PreTrainedTokenizer {
|
||||
export class WhisperTokenizer extends PreTrainedTokenizer {
|
||||
static LANGUAGES = {
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
|
@ -2297,9 +2297,9 @@ class WhisperTokenizer extends PreTrainedTokenizer {
|
|||
return totalSequence;
|
||||
}
|
||||
}
|
||||
class CodeGenTokenizer extends PreTrainedTokenizer { }
|
||||
class CLIPTokenizer extends PreTrainedTokenizer { }
|
||||
class MarianTokenizer extends PreTrainedTokenizer {
|
||||
export class CodeGenTokenizer extends PreTrainedTokenizer { }
|
||||
export class CLIPTokenizer extends PreTrainedTokenizer { }
|
||||
export class MarianTokenizer extends PreTrainedTokenizer {
|
||||
/**
|
||||
* Create a new MarianTokenizer instance.
|
||||
* @param {Object} tokenizerJSON - The JSON of the tokenizer.
|
||||
|
@ -2568,7 +2568,7 @@ class TokenLatticeNode {
|
|||
}
|
||||
}
|
||||
|
||||
class AutoTokenizer {
|
||||
export class AutoTokenizer {
|
||||
// Helper class to determine tokenizer type from tokenizer.json
|
||||
static TOKENIZER_CLASS_MAPPING = {
|
||||
'T5Tokenizer': T5Tokenizer,
|
||||
|
@ -2601,11 +2601,3 @@ class AutoTokenizer {
|
|||
return new cls(tokenizerJSON, tokenizerConfig);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
DistilBertTokenizer,
|
||||
T5Tokenizer,
|
||||
GPT2Tokenizer
|
||||
};
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
|
||||
const {
|
||||
// Tokenizers
|
||||
export {
|
||||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
DistilBertTokenizer,
|
||||
T5Tokenizer,
|
||||
GPT2Tokenizer
|
||||
} = require("./tokenizers.js");
|
||||
const {
|
||||
} from './tokenizers.js';
|
||||
|
||||
// Models
|
||||
export {
|
||||
AutoModel,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
|
@ -17,54 +20,18 @@ const {
|
|||
AutoModelForVision2Seq,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForObjectDetection,
|
||||
} = require("./models.js");
|
||||
} from './models.js';
|
||||
|
||||
const {
|
||||
// Processors
|
||||
export {
|
||||
AutoProcessor
|
||||
} = require("./processors.js");
|
||||
const {
|
||||
} from './processors.js';
|
||||
|
||||
// environment variables
|
||||
export { env } from './env.js';
|
||||
|
||||
// other
|
||||
export {
|
||||
pipeline
|
||||
} = require("./pipelines.js");
|
||||
const { env } = require('./env.js');
|
||||
|
||||
const { Tensor } = require('./tensor_utils.js');
|
||||
|
||||
const moduleExports = {
|
||||
// Tokenizers
|
||||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
DistilBertTokenizer,
|
||||
T5Tokenizer,
|
||||
GPT2Tokenizer,
|
||||
|
||||
// Models
|
||||
AutoModel,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForObjectDetection,
|
||||
|
||||
// Processors
|
||||
AutoProcessor,
|
||||
|
||||
// other
|
||||
pipeline,
|
||||
Tensor,
|
||||
|
||||
// environment variables
|
||||
env
|
||||
};
|
||||
|
||||
// Allow global access to these variables
|
||||
if (typeof self !== 'undefined') {
|
||||
// Used by web workers
|
||||
Object.assign(self, moduleExports);
|
||||
}
|
||||
|
||||
// Used by other modules
|
||||
module.exports = moduleExports
|
||||
} from './pipelines.js';
|
||||
export { Tensor } from './tensor_utils.js';
|
||||
|
|
75
src/utils.js
75
src/utils.js
|
@ -1,12 +1,12 @@
|
|||
|
||||
const fs = require('fs');
|
||||
import { existsSync, statSync, promises } from 'fs';
|
||||
|
||||
const { env } = require('./env.js');
|
||||
import { env } from './env.js';
|
||||
|
||||
if (global.ReadableStream === undefined && typeof process !== 'undefined') {
|
||||
try {
|
||||
// @ts-ignore
|
||||
global.ReadableStream = require('node:stream/web').ReadableStream; // ReadableStream is not a global with Node 16
|
||||
global.ReadableStream = (await import('node:stream/web')).ReadableStream; // ReadableStream is not a global with Node 16
|
||||
} catch (err) {
|
||||
console.warn("ReadableStream not defined and unable to import from node:stream/web");
|
||||
}
|
||||
|
@ -22,12 +22,12 @@ class FileResponse {
|
|||
this.headers = {};
|
||||
this.headers.get = (x) => this.headers[x]
|
||||
|
||||
this.exists = fs.existsSync(filePath);
|
||||
this.exists = existsSync(filePath);
|
||||
if (this.exists) {
|
||||
this.status = 200;
|
||||
this.statusText = 'OK';
|
||||
|
||||
let stats = fs.statSync(filePath);
|
||||
let stats = statSync(filePath);
|
||||
this.headers['content-length'] = stats.size;
|
||||
|
||||
this.updateContentType();
|
||||
|
@ -111,7 +111,7 @@ class FileResponse {
|
|||
* @throws {Error} - If the file cannot be read.
|
||||
*/
|
||||
async arrayBuffer() {
|
||||
const data = await fs.promises.readFile(this.filePath);
|
||||
const data = await promises.readFile(this.filePath);
|
||||
return data.buffer;
|
||||
}
|
||||
|
||||
|
@ -124,7 +124,7 @@ class FileResponse {
|
|||
* @throws {Error} - If the file cannot be read.
|
||||
*/
|
||||
async blob() {
|
||||
const data = await fs.promises.readFile(this.filePath);
|
||||
const data = await promises.readFile(this.filePath);
|
||||
return new Blob([data], { type: this.headers['content-type'] });
|
||||
}
|
||||
|
||||
|
@ -137,7 +137,7 @@ class FileResponse {
|
|||
* @throws {Error} - If the file cannot be read.
|
||||
*/
|
||||
async text() {
|
||||
const data = await fs.promises.readFile(this.filePath, 'utf8');
|
||||
const data = await promises.readFile(this.filePath, 'utf8');
|
||||
return data;
|
||||
}
|
||||
|
||||
|
@ -179,7 +179,7 @@ function isValidHttpUrl(string) {
|
|||
* @param {string|URL} url - The URL of the file to get.
|
||||
* @returns {Promise<FileResponse|Response>} A promise that resolves to a FileResponse object (if the file is retrieved using the FileSystem API), or a Response object (if the file is retrieved using the Fetch API).
|
||||
*/
|
||||
async function getFile(url) {
|
||||
export async function getFile(url) {
|
||||
// Helper function to get a file, using either the Fetch API or FileSystem API
|
||||
|
||||
if (env.useFS && !isValidHttpUrl(url)) {
|
||||
|
@ -198,7 +198,7 @@ async function getFile(url) {
|
|||
* @param {any} data - The data to pass to the progress callback function.
|
||||
* @returns {void}
|
||||
*/
|
||||
function dispatchCallback(progressCallback, data) {
|
||||
export function dispatchCallback(progressCallback, data) {
|
||||
if (progressCallback !== null) progressCallback(data);
|
||||
}
|
||||
|
||||
|
@ -293,7 +293,7 @@ async function getModelFile(modelPath, fileName, progressCallback = null, fatal
|
|||
* @param {function} progressCallback - A callback function to receive progress updates. Optional.
|
||||
* @returns {Promise<object>} - The JSON data parsed into a JavaScript object.
|
||||
*/
|
||||
async function fetchJSON(modelPath, fileName, progressCallback = null, fatal = true) {
|
||||
export async function fetchJSON(modelPath, fileName, progressCallback = null, fatal = true) {
|
||||
let buffer = await getModelFile(modelPath, fileName, progressCallback, fatal);
|
||||
if (buffer === null) {
|
||||
// Return empty object
|
||||
|
@ -369,7 +369,7 @@ async function readResponse(response, progressCallback) {
|
|||
* @param {...string} parts - Multiple parts of a path.
|
||||
* @returns {string} A string representing the joined path.
|
||||
*/
|
||||
function pathJoin(...parts) {
|
||||
export function pathJoin(...parts) {
|
||||
// https://stackoverflow.com/a/55142565
|
||||
parts = parts.map((part, index) => {
|
||||
if (index) {
|
||||
|
@ -390,7 +390,7 @@ function pathJoin(...parts) {
|
|||
* @returns {object} The reversed object.
|
||||
* @see https://ultimatecourses.com/blog/reverse-object-keys-and-values-in-javascript
|
||||
*/
|
||||
function reverseDictionary(data) {
|
||||
export function reverseDictionary(data) {
|
||||
// https://ultimatecourses.com/blog/reverse-object-keys-and-values-in-javascript
|
||||
return Object.fromEntries(Object.entries(data).map(([key, value]) => [value, key]));
|
||||
}
|
||||
|
@ -401,7 +401,7 @@ function reverseDictionary(data) {
|
|||
* @see https://stackoverflow.com/a/11301464
|
||||
* @returns {number} - The index of the maximum value in the array.
|
||||
*/
|
||||
function indexOfMax(arr) {
|
||||
export function indexOfMax(arr) {
|
||||
// https://stackoverflow.com/a/11301464
|
||||
|
||||
if (arr.length === 0) {
|
||||
|
@ -427,7 +427,7 @@ function indexOfMax(arr) {
|
|||
* @param {number[]} arr - The array of numbers to compute the softmax of.
|
||||
* @returns {number[]} The softmax array.
|
||||
*/
|
||||
function softmax(arr) {
|
||||
export function softmax(arr) {
|
||||
// Compute the maximum value in the array
|
||||
const maxVal = max(arr);
|
||||
|
||||
|
@ -448,7 +448,7 @@ function softmax(arr) {
|
|||
* @param {number[]} arr - The input array to calculate the log_softmax function for.
|
||||
* @returns {any} - The resulting log_softmax array.
|
||||
*/
|
||||
function log_softmax(arr) {
|
||||
export function log_softmax(arr) {
|
||||
// Compute the softmax values
|
||||
const softmaxArr = softmax(arr);
|
||||
|
||||
|
@ -464,7 +464,7 @@ function log_softmax(arr) {
|
|||
* @param {string} string - The string to escape.
|
||||
* @returns {string} - The escaped string.
|
||||
*/
|
||||
function escapeRegExp(string) {
|
||||
export function escapeRegExp(string) {
|
||||
return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
|
||||
}
|
||||
|
||||
|
@ -475,7 +475,7 @@ function escapeRegExp(string) {
|
|||
* @param {number} [top_k=0] - The number of top items to return (default: 0 = return all)
|
||||
* @returns {Array} - The top k items, sorted by descending order
|
||||
*/
|
||||
function getTopItems(items, top_k = 0) {
|
||||
export function getTopItems(items, top_k = 0) {
|
||||
// if top == 0, return all
|
||||
|
||||
items = Array.from(items)
|
||||
|
@ -495,7 +495,7 @@ function getTopItems(items, top_k = 0) {
|
|||
* @param {number[]} arr2 - The second array.
|
||||
* @returns {number} - The dot product of arr1 and arr2.
|
||||
*/
|
||||
function dot(arr1, arr2) {
|
||||
export function dot(arr1, arr2) {
|
||||
return arr1.reduce((acc, val, i) => acc + val * arr2[i], 0);
|
||||
}
|
||||
|
||||
|
@ -506,7 +506,7 @@ function dot(arr1, arr2) {
|
|||
* @param {number[]} arr2 - The second array.
|
||||
* @returns {number} The cosine similarity between the two arrays.
|
||||
*/
|
||||
function cos_sim(arr1, arr2) {
|
||||
export function cos_sim(arr1, arr2) {
|
||||
// Calculate dot product of the two arrays
|
||||
const dotProduct = dot(arr1, arr2);
|
||||
|
||||
|
@ -527,7 +527,7 @@ function cos_sim(arr1, arr2) {
|
|||
* @param {number[]} arr - The array to calculate the magnitude of.
|
||||
* @returns {number} The magnitude of the array.
|
||||
*/
|
||||
function magnitude(arr) {
|
||||
export function magnitude(arr) {
|
||||
return Math.sqrt(arr.reduce((acc, val) => acc + val * val, 0));
|
||||
}
|
||||
|
||||
|
@ -536,7 +536,7 @@ function magnitude(arr) {
|
|||
*
|
||||
* @extends Function
|
||||
*/
|
||||
class Callable extends Function {
|
||||
export class Callable extends Function {
|
||||
/**
|
||||
* Creates a new instance of the Callable class.
|
||||
*/
|
||||
|
@ -573,7 +573,7 @@ class Callable extends Function {
|
|||
* @returns {number} - the minimum number.
|
||||
* @throws {Error} If array is empty.
|
||||
*/
|
||||
function min(arr) {
|
||||
export function min(arr) {
|
||||
if (arr.length === 0) throw Error('Array must not be empty');
|
||||
let min = arr[0];
|
||||
for (let i = 1; i < arr.length; ++i) {
|
||||
|
@ -591,7 +591,7 @@ function min(arr) {
|
|||
* @returns {number} - the maximum number.
|
||||
* @throws {Error} If array is empty.
|
||||
*/
|
||||
function max(arr) {
|
||||
export function max(arr) {
|
||||
if (arr.length === 0) throw Error('Array must not be empty');
|
||||
let max = arr[0];
|
||||
for (let i = 1; i < arr.length; ++i) {
|
||||
|
@ -607,7 +607,7 @@ function max(arr) {
|
|||
* @param {*} text - The value to check.
|
||||
* @returns {boolean} - True if the value is a string, false otherwise.
|
||||
*/
|
||||
function isString(text) {
|
||||
export function isString(text) {
|
||||
return typeof text === 'string' || text instanceof String
|
||||
}
|
||||
|
||||
|
@ -616,7 +616,7 @@ function isString(text) {
|
|||
* @param {*} x - The value to check.
|
||||
* @returns {boolean} - True if the value is a string, false otherwise.
|
||||
*/
|
||||
function isIntegralNumber(x) {
|
||||
export function isIntegralNumber(x) {
|
||||
return Number.isInteger(x) || typeof x === 'bigint'
|
||||
}
|
||||
|
||||
|
@ -625,29 +625,10 @@ function isIntegralNumber(x) {
|
|||
* @param {*} x - The value to check.
|
||||
* @returns {boolean} - True if the value exists, false otherwise.
|
||||
*/
|
||||
function exists(x) {
|
||||
export function exists(x) {
|
||||
return x !== undefined && x !== null;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
Callable,
|
||||
export {
|
||||
getModelFile,
|
||||
dispatchCallback,
|
||||
fetchJSON,
|
||||
pathJoin,
|
||||
reverseDictionary,
|
||||
indexOfMax,
|
||||
softmax,
|
||||
log_softmax,
|
||||
escapeRegExp,
|
||||
getTopItems,
|
||||
dot,
|
||||
cos_sim,
|
||||
magnitude,
|
||||
getFile,
|
||||
isIntegralNumber,
|
||||
isString,
|
||||
exists,
|
||||
min,
|
||||
max,
|
||||
};
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
|
||||
const path = require('path');
|
||||
const { pipeline, env } = require('..');
|
||||
import path from 'path';
|
||||
import { pipeline, env } from '../src/transformers.js';
|
||||
|
||||
const __dirname = env.__dirname;
|
||||
|
||||
// Only use local models
|
||||
env.remoteModels = false;
|
||||
|
@ -757,7 +759,7 @@ async function image_classification() {
|
|||
async function image_segmentation() {
|
||||
let segmenter = await pipeline('image-segmentation', 'facebook/detr-resnet-50-panoptic')
|
||||
|
||||
let img = path.join(__dirname, '../assets/images/cats.jpg')
|
||||
let img = path.join(__dirname, './assets/images/cats.jpg')
|
||||
|
||||
let start = performance.now();
|
||||
let outputs = await segmenter(img);
|
||||
|
|
Loading…
Reference in New Issue