Replace CommonJS imports/exports with ES6

Joshua Lochner 2023-04-13 21:01:44 +02:00
parent 1b9fe0f27c
commit 421aa331cd
17 changed files with 164 additions and 281 deletions

View File

@@ -1,5 +1,8 @@
{
"compilerOptions": {
"checkJs": true
"checkJs": true,
"target": "esnext",
"module": "esnext",
"moduleResolution": "nodenext",
}
}
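
One practical consequence of "moduleResolution": "nodenext" is that relative imports must spell out the file extension, which is why relative imports throughout this commit gain a .js suffix. A quick sketch:

// Node-style ESM resolution requires explicit extensions on relative paths:
import { Tensor } from './tensor_utils.js';   // resolves
// import { Tensor } from './tensor_utils';   // error: module not found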

View File

@@ -3,6 +3,7 @@
"version": "1.4.2",
"description": "Run 🤗 Transformers in your browser! We currently support BERT, ALBERT, DistilBERT, MobileBERT, SqueezeBERT, T5, T5v1.1, FLAN-T5, mT5, BART, MarianMT, GPT2, GPT Neo, CodeGen, Whisper, CLIP, Vision Transformer, VisionEncoderDecoder, and DETR models, for a variety of tasks including: masked language modelling, text classification, token classification, zero-shot classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, image segmentation, and object detection.",
"main": "./src/transformers.js",
"type": "module",
"directories": {
"test": "tests"
},
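
"type": "module" switches every .js file in the package to ES-module parsing, so static imports and top-level await become available throughout. A minimal consumer sketch (assumes the published @xenova/transformers package is installed):

import { pipeline } from '@xenova/transformers';

const classifier = await pipeline('sentiment-analysis');  // top-level await is legal in ESM
console.log(await classifier('Switching to ES modules!'));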

View File

@@ -1,13 +1,13 @@
let ONNX;
export let ONNX;
// TODO support more execution providers (e.g., webgpu)
const executionProviders = ['wasm'];
export const executionProviders = ['wasm'];
if (typeof process !== 'undefined') {
// Running in a node-like environment.
// Try to import onnxruntime-node, using onnxruntime-web as a fallback
try {
ONNX = require('onnxruntime-node');
ONNX = (await import('onnxruntime-node')).default;
} catch (err) {
console.warn(
"Node.js environment detected, but `onnxruntime-node` was not found. " +
@@ -20,7 +20,7 @@ if (typeof process !== 'undefined') {
// @ts-ignore
global.self = global;
ONNX = require('onnxruntime-web');
ONNX = (await import('onnxruntime-web')).default;
// Disable spawning worker threads for testing.
// This is done by setting numThreads to 1
@@ -33,10 +33,6 @@ if (typeof process !== 'undefined') {
} else {
// Running in a browser-environment, so we just import `onnxruntime-web`
ONNX = require('onnxruntime-web');
ONNX = (await import('onnxruntime-web')).default;
}
module.exports = {
ONNX,
executionProviders,
}
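
The .default access is the interop detail to watch: require() returned a CommonJS module's exports directly, while a dynamic import() resolves to a module-namespace object that wraps them. Top-level await, legal now that this file is an ES module, keeps the loading code straight-line. A sketch:

// require('onnxruntime-web')       -> the exports object itself
// await import('onnxruntime-web')  -> a namespace; the CJS exports sit on `.default`
const ns = await import('onnxruntime-web');
const ONNX = ns.default;   // equivalent to the old require() value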

View File

@@ -1,7 +1,11 @@
const fs = require('fs');
const path = require('path');
import fs from 'fs';
import path from 'path';
import { fileURLToPath } from 'url';
const { env: onnx_env } = require('./backends/onnx.js').ONNX;
import { ONNX } from './backends/onnx.js';
const { env: onnx_env } = ONNX;
const __dirname = path.dirname(path.dirname(fileURLToPath(import.meta.url)));
// check if various APIs are available (depends on environment)
const CACHE_AVAILABLE = typeof self !== 'undefined' && 'caches' in self;
@@ -13,7 +17,7 @@ const RUNNING_LOCALLY = FS_AVAILABLE && PATH_AVAILABLE;
// set local model path, based on available APIs
const DEFAULT_LOCAL_PATH = '/models/onnx/quantized/';
const localURL = RUNNING_LOCALLY
? path.join(path.dirname(__dirname), DEFAULT_LOCAL_PATH)
? path.join(__dirname, DEFAULT_LOCAL_PATH)
: DEFAULT_LOCAL_PATH;
// First, set path to wasm files. This is needed when running in a web worker.
@@ -21,12 +25,12 @@ const localURL = RUNNING_LOCALLY
// We use remote wasm files by default to make it easier for newer users.
// In practice, users should probably self-host the necessary .wasm files.
onnx_env.wasm.wasmPaths = RUNNING_LOCALLY
? path.join(path.dirname(__dirname), '/dist/')
? path.join(__dirname, '/dist/')
: 'https://cdn.jsdelivr.net/npm/@xenova/transformers/dist/';
// Global variable used to control execution, with suitable defaults
const env = {
export const env = {
// access onnxruntime-web's environment variables
onnx: onnx_env,
@@ -44,6 +48,9 @@ const env = {
// Whether to use the file system to load files. By default, it is true if the file system is available.
useFS: FS_AVAILABLE,
// Directory name of module. Useful for resolving local paths.
__dirname,
}
@@ -54,6 +61,3 @@ function isEmpty(obj) {
return Object.keys(obj).length === 0;
}
module.exports = {
env
}
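
ES modules define neither __dirname nor __filename, hence the fileURLToPath(import.meta.url) shim above (the double path.dirname() walks from src/ up to the package root). The general pattern:

import path from 'path';
import { fileURLToPath } from 'url';

// Recreate the CommonJS globals from the module's own URL:
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);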

View File

@@ -3,7 +3,7 @@
* FFT class provides functionality for performing Fast Fourier Transform on arrays
* Code adapted from https://www.npmjs.com/package/fft.js
*/
class FFT {
export default class FFT {
/**
* @param {number} size - The size of the input array. Must be a power of two and bigger than 1.
* @throws {Error} FFT size must be a power of two and bigger than 1.
@@ -494,5 +494,3 @@ class FFT {
out[outOff + 7] = T3r;
}
}
module.exports = FFT
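
A default export pairs with a brace-less import, under any local name the consumer chooses:

import FFT from './fft.js';
const fft = new FFT(8);   // size must be a power of two greater than 1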

View File

@@ -1,9 +1,9 @@
const { Tensor } = require("./tensor_utils.js");
const {
import { Tensor } from './tensor_utils.js';
import {
Callable,
exists,
log_softmax
} = require("./utils.js");
} from './utils.js';
/**
* A class representing a list of logits processors. A logits processor is a function that modifies the logits
@@ -12,7 +12,7 @@ const {
*
* @extends Callable
*/
class LogitsProcessorList extends Callable {
export class LogitsProcessorList extends Callable {
/**
* Constructs a new instance of `LogitsProcessorList`.
*/
@@ -66,7 +66,7 @@ class LogitsProcessorList extends Callable {
* Base class for processing logits.
* @extends Callable
*/
class LogitsProcessor extends Callable {
export class LogitsProcessor extends Callable {
/**
* Apply the processor to the input logits.
*
@@ -85,7 +85,7 @@
*
* @extends LogitsProcessor
*/
class ForceTokensLogitsProcessor extends LogitsProcessor {
export class ForceTokensLogitsProcessor extends LogitsProcessor {
/**
* Constructs a new instance of `ForceTokensLogitsProcessor`.
*
@@ -117,7 +117,7 @@ class ForceTokensLogitsProcessor extends LogitsProcessor {
* A LogitsProcessor that forces a BOS token at the beginning of the generated sequence.
* @extends LogitsProcessor
*/
class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
/**
* Create a ForcedBOSTokenLogitsProcessor.
* @param {number} bos_token_id - The ID of the beginning-of-sequence token to be forced.
@@ -146,7 +146,7 @@ class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
*
* @extends LogitsProcessor
*/
class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
/**
* Create a ForcedEOSTokenLogitsProcessor.
* @param {number} max_length - Max length of the sequence.
@@ -174,7 +174,7 @@ class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
* A LogitsProcessor that handles adding timestamps to generated text.
* @extends LogitsProcessor
*/
class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
/**
* Constructs a new WhisperTimeStampLogitsProcessor.
* @param {object} generate_config - The config object passed to the `generate()` method of a transformer model.
@@ -249,7 +249,7 @@ class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
*
* @extends LogitsProcessor
*/
class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
/**
* Create a NoRepeatNGramLogitsProcessor.
* @param {number} no_repeat_ngram_size - The no-repeat-ngram size. All ngrams of this size can only occur once.
@@ -340,7 +340,7 @@ class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
*
* @extends LogitsProcessor
*/
class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
/**
* Create a RepetitionPenaltyLogitsProcessor.
* @param {number} penalty - The penalty to apply for repeated tokens.
@@ -372,7 +372,7 @@ class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
}
class GenerationConfig {
export class GenerationConfig {
constructor(kwargs = {}) {
// Parameters that control the length of the output
// TODO: extend the configuration with correct types
@@ -465,15 +465,3 @@ class GenerationConfig {
this.generation_kwargs = kwargs.generation_kwargs ?? {};
}
}
module.exports = {
LogitsProcessor,
LogitsProcessorList,
GenerationConfig,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
WhisperTimeStampLogitsProcessor,
ForceTokensLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor
};
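
Exporting each class inline replaces the module.exports object that used to sit at the bottom of the file; consumers import exactly the names they need and may rename them. A sketch:

import { LogitsProcessorList, GenerationConfig as Config } from './generation.js';

const processors = new LogitsProcessorList();
const config = new Config();   // all generation parameters fall back to their defaults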

View File

@@ -1,10 +1,10 @@
const fs = require('fs');
const { getFile, isString } = require('./utils.js');
const { env } = require('./env.js');
import fs from 'fs';
import { getFile, isString } from './utils.js';
import { env } from './env.js';
let CanvasClass;
let ImageClass = typeof Image !== 'undefined' ? Image : null; // Only used for type-checking
let ImageClass = typeof Image !== 'undefined' ? Image : null;
let ImageDataClass;
let loadImageFunction;
@@ -14,7 +14,7 @@ if (typeof self !== 'undefined') {
ImageDataClass = ImageData;
} else {
const { Canvas, loadImage, ImageData, Image } = require('canvas');
const { Canvas, loadImage, ImageData, Image } = await import('canvas');
CanvasClass = Canvas;
loadImageFunction = async (/**@type {Blob}*/ b) => await loadImage(Buffer.from(await b.arrayBuffer()));
ImageDataClass = ImageData;
@@ -22,7 +22,7 @@ if (typeof self !== 'undefined') {
}
class CustomImage {
export class CustomImage {
/**
* Create a new CustomImage object.
@@ -277,7 +277,3 @@ class CustomImage {
fs.writeFileSync(path, buffer);
}
}
module.exports = {
CustomImage,
};
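
Because import() is asynchronous, the Node-only 'canvas' dependency stays behind the runtime check and never has to resolve in a browser build; top-level await makes the conditional load legal at module scope. The shape of the pattern, as a sketch:

// Lazily load an optional Node-only package; browsers never reach this branch.
// (Node synthesizes named exports for CommonJS packages, matching the diff above.)
let CanvasClass;
if (typeof self === 'undefined') {
  ({ Canvas: CanvasClass } = await import('canvas'));
}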

View File

@@ -8,7 +8,7 @@
/**
* @param {TypedArray} input
*/
function interpolate(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) {
export function interpolate(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) {
// TODO use mode and align_corners
// Output image dimensions
@@ -86,7 +86,7 @@ function interpolate(input, [in_channels, in_height, in_width], [out_height, out
* @param {number[]} axes
* @returns {[T, number[]]} The transposed array and the new shape.
*/
function transpose_data(array, dims, axes) {
export function transpose_data(array, dims, axes) {
// Calculate the new shape of the transposed array
// and the stride of the original array
const shape = new Array(axes.length);
@@ -117,8 +117,3 @@ function transpose_data(array, dims, axes) {
return [transposedData, shape];
}
module.exports = {
interpolate,
transpose: transpose_data,
}
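
Note the removed CommonJS export renamed the function on the way out (transpose: transpose_data); the ESM version exports it under its real name and moves any renaming to the import site:

// Import-side aliasing replaces the old export-side rename:
import { interpolate as interpolate_data, transpose_data } from './math_utils.js';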

View File

@@ -1,17 +1,16 @@
const {
import {
Callable,
getModelFile,
fetchJSON,
dispatchCallback,
isIntegralNumber,
} = require("./utils.js");
} from './utils.js';
const {
import {
Sampler,
} = require("./samplers.js");
} from './samplers.js';
const {
import {
LogitsProcessorList,
GenerationConfig,
ForceTokensLogitsProcessor,
@@ -20,13 +19,14 @@ const {
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor
} = require("./generation.js");
} from './generation.js';
const { executionProviders, ONNX } = require('./backends/onnx.js');
const {
import {
Tensor,
cat
} = require('./tensor_utils');
} from './tensor_utils.js';
import { executionProviders, ONNX } from './backends/onnx.js';
const { InferenceSession, Tensor: ONNXTensor } = ONNX;
//////////////////////////////////////////////////
@@ -2172,7 +2172,7 @@ class MarianMTModel extends MarianPreTrainedModel {
/**
* Helper class to determine model type from config
*/
class AutoModel {
export class AutoModel {
// Helper class to determine model type from config
static MODEL_CLASS_MAPPING = {
'bert': BertModel,
@@ -2223,7 +2223,7 @@ class AutoModel {
/**
* Helper class for loading sequence classification models from pretrained checkpoints
*/
class AutoModelForSequenceClassification {
export class AutoModelForSequenceClassification {
static MODEL_CLASS_MAPPING = {
'bert': BertForSequenceClassification,
@@ -2267,7 +2267,7 @@ class AutoModelForSequenceClassification {
/**
* Helper class for loading token classification models from pretrained checkpoints
*/
class AutoModelForTokenClassification {
export class AutoModelForTokenClassification {
static MODEL_CLASS_MAPPING = {
'bert': BertForTokenClassification,
@@ -2306,7 +2306,7 @@ class AutoModelForTokenClassification {
/**
* Class representing an automatic sequence-to-sequence language model.
*/
class AutoModelForSeq2SeqLM {
export class AutoModelForSeq2SeqLM {
static MODEL_CLASS_MAPPING = {
't5': T5ForConditionalGeneration,
'mt5': MT5ForConditionalGeneration,
@@ -2337,7 +2337,7 @@ class AutoModelForSeq2SeqLM {
/**
* A class for loading pre-trained models for causal language modeling tasks.
*/
class AutoModelForCausalLM {
export class AutoModelForCausalLM {
static MODEL_CLASS_MAPPING = {
'gpt2': GPT2LMHeadModel,
'gpt_neo': GPTNeoForCausalLM,
@@ -2376,7 +2376,7 @@ class AutoModelForCausalLM {
/**
* A class to automatically select the appropriate model for Masked Language Modeling (MLM) tasks.
*/
class AutoModelForMaskedLM {
export class AutoModelForMaskedLM {
static MODEL_CLASS_MAPPING = {
'bert': BertForMaskedLM,
'albert': AlbertForMaskedLM,
@@ -2419,7 +2419,7 @@ class AutoModelForMaskedLM {
/**
* Automatic model class for question answering tasks.
*/
class AutoModelForQuestionAnswering {
export class AutoModelForQuestionAnswering {
static MODEL_CLASS_MAPPING = {
'bert': BertForQuestionAnswering,
'albert': AlbertForQuestionAnswering,
@@ -2460,7 +2460,7 @@ class AutoModelForQuestionAnswering {
/**
* Class representing an autoencoder-decoder model for vision-to-sequence tasks.
*/
class AutoModelForVision2Seq {
export class AutoModelForVision2Seq {
static MODEL_CLASS_MAPPING = {
'vision-encoder-decoder': VisionEncoderDecoderModel
}
@@ -2498,7 +2498,7 @@ class AutoModelForVision2Seq {
/**
* AutoModelForImageClassification is a class for loading pre-trained image classification models from ONNX format.
*/
class AutoModelForImageClassification {
export class AutoModelForImageClassification {
static MODEL_CLASS_MAPPING = {
'vit': ViTForImageClassification,
}
@@ -2537,7 +2537,7 @@ class AutoModelForImageClassification {
/**
* AutoModelForImageSegmentation is a class for loading pre-trained image classification models from ONNX format.
*/
class AutoModelForImageSegmentation {
export class AutoModelForImageSegmentation {
static MODEL_CLASS_MAPPING = {
'detr': DetrForSegmentation,
}
@@ -2573,7 +2573,7 @@ class AutoModelForImageSegmentation {
//////////////////////////////////////////////////
class AutoModelForObjectDetection {
export class AutoModelForObjectDetection {
static MODEL_CLASS_MAPPING = {
'detr': DetrForObjectDetection,
}
@@ -2664,17 +2664,3 @@ class QuestionAnsweringModelOutput extends ModelOutput {
this.end_logits = end_logits;
}
}
module.exports = {
AutoModel,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
AutoModelForVision2Seq,
AutoModelForImageClassification,
AutoModelForObjectDetection,
AutoModelForImageSegmentation,
};
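
With each Auto* helper a named export, downstream modules pull in only the classes they use. A hypothetical usage sketch (the model path is illustrative):

import { AutoModelForSequenceClassification } from './models.js';

// from_pretrained picks the concrete class from the config and loads the ONNX weights:
const model = await AutoModelForSequenceClassification.from_pretrained('path/to/model');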

View File

@@ -1,4 +1,4 @@
const {
import {
Callable,
softmax,
indexOfMax,
@@ -8,12 +8,12 @@ const {
isString,
getFile,
dot
} = require("./utils.js");
} from './utils.js';
const {
import {
AutoTokenizer
} = require("./tokenizers.js");
const {
} from './tokenizers.js';
import {
AutoModel,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
@@ -25,19 +25,15 @@ const {
AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForObjectDetection
} = require("./models.js");
const {
} from './models.js';
import {
AutoProcessor,
Processor
} = require("./processors.js");
} from './processors.js';
const {
env
} = require('./env.js');
const { Tensor, transpose_data } = require("./tensor_utils.js");
const { CustomImage } = require("./image_utils.js");
import { env } from './env.js';
import { Tensor } from './tensor_utils.js';
import { CustomImage } from './image_utils.js';
/**
* Prepare images for further tasks.
@@ -1364,7 +1360,7 @@ const TASK_ALIASES = {
* @todo fix error below
* @throws {Error} If an unsupported pipeline is requested.
*/
async function pipeline(
export async function pipeline(
task,
model = null,
{
@@ -1454,7 +1450,3 @@ function product(...a) {
// Adapted from https://stackoverflow.com/a/43053803
return a.reduce((a, b) => a.flatMap(d => b.map(e => [d, e])));
}
module.exports = {
pipeline
};
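
pipeline() is now a named export too; a minimal end-to-end sketch, letting the task's default model load:

import { pipeline } from './pipelines.js';

let classifier = await pipeline('sentiment-analysis');   // model = null, so the task default is used
let output = await classifier('Migrating to ES modules!');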

View File

@@ -1,20 +1,20 @@
const {
import {
Callable,
fetchJSON,
indexOfMax,
softmax,
} = require("./utils.js");
} from './utils.js';
import FFT from './fft.js';
import { Tensor, transpose, cat, interpolate } from './tensor_utils.js';
const FFT = require('./fft.js');
const { Tensor, transpose, cat, interpolate } = require("./tensor_utils.js");
import { CustomImage } from './image_utils.js';
const { CustomImage } = require('./image_utils.js');
/**
* Helper class to determine model type from config
*/
class AutoProcessor {
export class AutoProcessor {
/**
* Returns a new instance of a Processor with a feature extractor
* based on the configuration file located at `modelPath`.
@@ -938,7 +938,7 @@ class WhisperFeatureExtractor extends FeatureExtractor {
* Represents a Processor that extracts features from an input.
* @extends Callable
*/
class Processor extends Callable {
export class Processor extends Callable {
/**
* Creates a new Processor with the given feature extractor.
* @param {function} feature_extractor - The function used to extract features from the input.
@@ -975,9 +975,3 @@ class WhisperProcessor extends Processor {
return await this.feature_extractor(audio)
}
}
module.exports = {
AutoProcessor,
Processor,
}

View File

@@ -1,15 +1,15 @@
const {
import {
Callable,
indexOfMax,
softmax,
log_softmax,
getTopItems
} = require("./utils.js");
} from './utils.js';
/**
* Sampler is a base class for all sampling methods used for text generation.
*/
class Sampler extends Callable {
export class Sampler extends Callable {
/**
* Creates a new Sampler object with the specified temperature.
* @param {number} temperature - The temperature to use when sampling. Higher values result in more random samples.
@@ -240,10 +240,3 @@ class BeamSearchSampler extends Sampler {
}
}
}
module.exports = {
Sampler,
GreedySampler,
TopKSampler,
BeamSearchSampler
}

View File

@@ -1,6 +1,9 @@
const { ONNX } = require('./backends/onnx.js');
import { ONNX } from './backends/onnx.js';
const { interpolate: interpolate_data, transpose: transpose_data } = require('./math_utils.js');
import {
interpolate as interpolate_data,
transpose_data
} from './math_utils.js';
/**
@@ -10,7 +13,7 @@ const { interpolate: interpolate_data, transpose: transpose_data } = require('./
const ONNXTensor = ONNX.Tensor;
// TODO: fix error below
class Tensor extends ONNXTensor {
export class Tensor extends ONNXTensor {
/**
* Create a new Tensor or copy an existing Tensor.
* @param {[string, Array|AnyTypedArray, number[]]|[ONNXTensor]} args
@@ -190,7 +193,7 @@ function reshape(data, dimensions) {
* @param {Array} axes - The axes to transpose the tensor along.
* @returns {Tensor} The transposed tensor.
*/
function transpose(tensor, axes) {
export function transpose(tensor, axes) {
const [transposedData, shape] = transpose_data(tensor.data, tensor.dims, axes);
return new Tensor(tensor.type, transposedData, shape);
}
@@ -202,7 +205,7 @@ function transpose(tensor, axes) {
* @param {any} tensors - The array of tensors to concatenate.
* @returns {Tensor} - The concatenated tensor.
*/
function cat(tensors) {
export function cat(tensors) {
if (tensors.length === 0) {
return tensors[0];
}
@@ -241,7 +244,7 @@ function cat(tensors) {
* @param {boolean} align_corners - Whether to align corners.
* @returns {Tensor} - The interpolated tensor.
*/
function interpolate(input, [out_height, out_width], mode = 'bilinear', align_corners = false) {
export function interpolate(input, [out_height, out_width], mode = 'bilinear', align_corners = false) {
// Input image dimensions
const in_channels = input.dims.at(-3) ?? 1;
@@ -257,11 +260,3 @@ function interpolate(input, [out_height, out_width], mode = 'bilinear', align_co
);
return new Tensor(input.type, output, [in_channels, out_height, out_width]);
}
module.exports = {
Tensor,
transpose,
cat,
interpolate,
transpose_data,
}

View File

@@ -1,13 +1,13 @@
const {
import {
Callable,
fetchJSON,
reverseDictionary,
escapeRegExp,
isIntegralNumber,
min,
} = require('./utils.js');
} from './utils.js';
const { Tensor } = require('./tensor_utils.js')
import { Tensor } from './tensor_utils.js';
/**
* Abstract base class for tokenizer models.
@@ -1837,7 +1837,7 @@ function bert_prepare_model_inputs(inputs) {
* BertTokenizer is a class used to tokenize text for BERT models.
* @extends PreTrainedTokenizer
*/
class BertTokenizer extends PreTrainedTokenizer {
export class BertTokenizer extends PreTrainedTokenizer {
/**
* @see {@link bert_prepare_model_inputs}
*/
@@ -1849,7 +1849,7 @@ class BertTokenizer extends PreTrainedTokenizer {
* Albert tokenizer
* @extends PreTrainedTokenizer
*/
class AlbertTokenizer extends PreTrainedTokenizer {
export class AlbertTokenizer extends PreTrainedTokenizer {
/**
* @see {@link bert_prepare_model_inputs}
*/
@@ -1857,7 +1857,7 @@ class AlbertTokenizer extends PreTrainedTokenizer {
return bert_prepare_model_inputs(inputs);
}
}
class MobileBertTokenizer extends PreTrainedTokenizer {
export class MobileBertTokenizer extends PreTrainedTokenizer {
/**
* @see {@link bert_prepare_model_inputs}
*/
@@ -1865,7 +1865,7 @@ class MobileBertTokenizer extends PreTrainedTokenizer {
return bert_prepare_model_inputs(inputs);
}
}
class SqueezeBertTokenizer extends PreTrainedTokenizer {
export class SqueezeBertTokenizer extends PreTrainedTokenizer {
/**
* @see {@link bert_prepare_model_inputs}
*/
@@ -1873,18 +1873,18 @@ class SqueezeBertTokenizer extends PreTrainedTokenizer {
return bert_prepare_model_inputs(inputs);
}
}
class DistilBertTokenizer extends PreTrainedTokenizer { }
class T5Tokenizer extends PreTrainedTokenizer { }
class GPT2Tokenizer extends PreTrainedTokenizer { }
class BartTokenizer extends PreTrainedTokenizer { }
class RobertaTokenizer extends PreTrainedTokenizer { }
export class DistilBertTokenizer extends PreTrainedTokenizer { }
export class T5Tokenizer extends PreTrainedTokenizer { }
export class GPT2Tokenizer extends PreTrainedTokenizer { }
export class BartTokenizer extends PreTrainedTokenizer { }
export class RobertaTokenizer extends PreTrainedTokenizer { }
/**
* WhisperTokenizer tokenizer
* @extends PreTrainedTokenizer
*/
class WhisperTokenizer extends PreTrainedTokenizer {
export class WhisperTokenizer extends PreTrainedTokenizer {
static LANGUAGES = {
"en": "english",
"zh": "chinese",
@@ -2297,9 +2297,9 @@ class WhisperTokenizer extends PreTrainedTokenizer {
return totalSequence;
}
}
class CodeGenTokenizer extends PreTrainedTokenizer { }
class CLIPTokenizer extends PreTrainedTokenizer { }
class MarianTokenizer extends PreTrainedTokenizer {
export class CodeGenTokenizer extends PreTrainedTokenizer { }
export class CLIPTokenizer extends PreTrainedTokenizer { }
export class MarianTokenizer extends PreTrainedTokenizer {
/**
* Create a new MarianTokenizer instance.
* @param {Object} tokenizerJSON - The JSON of the tokenizer.
@@ -2568,7 +2568,7 @@ class TokenLatticeNode {
}
}
class AutoTokenizer {
export class AutoTokenizer {
// Helper class to determine tokenizer type from tokenizer.json
static TOKENIZER_CLASS_MAPPING = {
'T5Tokenizer': T5Tokenizer,
@@ -2601,11 +2601,3 @@ class AutoTokenizer {
return new cls(tokenizerJSON, tokenizerConfig);
}
}
module.exports = {
AutoTokenizer,
BertTokenizer,
DistilBertTokenizer,
T5Tokenizer,
GPT2Tokenizer
};

View File

@@ -1,12 +1,15 @@
const {
// Tokenizers
export {
AutoTokenizer,
BertTokenizer,
DistilBertTokenizer,
T5Tokenizer,
GPT2Tokenizer
} = require("./tokenizers.js");
const {
} from './tokenizers.js';
// Models
export {
AutoModel,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
@@ -17,54 +20,18 @@ const {
AutoModelForVision2Seq,
AutoModelForImageClassification,
AutoModelForObjectDetection,
} = require("./models.js");
} from './models.js';
const {
// Processors
export {
AutoProcessor
} = require("./processors.js");
const {
} from './processors.js';
// environment variables
export { env } from './env.js';
// other
export {
pipeline
} = require("./pipelines.js");
const { env } = require('./env.js');
const { Tensor } = require('./tensor_utils.js');
const moduleExports = {
// Tokenizers
AutoTokenizer,
BertTokenizer,
DistilBertTokenizer,
T5Tokenizer,
GPT2Tokenizer,
// Models
AutoModel,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
AutoModelForVision2Seq,
AutoModelForImageClassification,
AutoModelForObjectDetection,
// Processors
AutoProcessor,
// other
pipeline,
Tensor,
// environment variables
env
};
// Allow global access to these variables
if (typeof self !== 'undefined') {
// Used by web workers
Object.assign(self, moduleExports);
}
// Used by other modules
module.exports = moduleExports
} from './pipelines.js';
export { Tensor } from './tensor_utils.js';
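
The aggregate module.exports object and the Object.assign(self, ...) shim are both gone: export { ... } from re-exports without creating local bindings, and web workers can import the library directly instead of reading globals off self. A sketch of the worker side, assuming the environment supports module workers:

// main thread: const worker = new Worker('worker.js', { type: 'module' });

// worker.js:
import { pipeline } from './transformers.js';
const summarizer = await pipeline('summarization');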

View File

@@ -1,12 +1,12 @@
const fs = require('fs');
import { existsSync, statSync, promises } from 'fs';
const { env } = require('./env.js');
import { env } from './env.js';
if (global.ReadableStream === undefined && typeof process !== 'undefined') {
try {
// @ts-ignore
global.ReadableStream = require('node:stream/web').ReadableStream; // ReadableStream is not a global with Node 16
global.ReadableStream = (await import('node:stream/web')).ReadableStream; // ReadableStream is not a global with Node 16
} catch (err) {
console.warn("ReadableStream not defined and unable to import from node:stream/web");
}
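
Importing existsSync, statSync, and promises as named bindings replaces the old fs namespace object; the call sites below drop their fs. prefix accordingly. Both styles are valid under Node's ESM:

import { existsSync, statSync, promises } from 'fs';  // named bindings (used here)
import * as fs from 'fs';                             // namespace: keeps the fs. prefix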
@@ -22,12 +22,12 @@ class FileResponse {
this.headers = {};
this.headers.get = (x) => this.headers[x]
this.exists = fs.existsSync(filePath);
this.exists = existsSync(filePath);
if (this.exists) {
this.status = 200;
this.statusText = 'OK';
let stats = fs.statSync(filePath);
let stats = statSync(filePath);
this.headers['content-length'] = stats.size;
this.updateContentType();
@@ -111,7 +111,7 @@ class FileResponse {
* @throws {Error} - If the file cannot be read.
*/
async arrayBuffer() {
const data = await fs.promises.readFile(this.filePath);
const data = await promises.readFile(this.filePath);
return data.buffer;
}
@@ -124,7 +124,7 @@ class FileResponse {
* @throws {Error} - If the file cannot be read.
*/
async blob() {
const data = await fs.promises.readFile(this.filePath);
const data = await promises.readFile(this.filePath);
return new Blob([data], { type: this.headers['content-type'] });
}
@@ -137,7 +137,7 @@ class FileResponse {
* @throws {Error} - If the file cannot be read.
*/
async text() {
const data = await fs.promises.readFile(this.filePath, 'utf8');
const data = await promises.readFile(this.filePath, 'utf8');
return data;
}
@@ -179,7 +179,7 @@ function isValidHttpUrl(string) {
* @param {string|URL} url - The URL of the file to get.
* @returns {Promise<FileResponse|Response>} A promise that resolves to a FileResponse object (if the file is retrieved using the FileSystem API), or a Response object (if the file is retrieved using the Fetch API).
*/
async function getFile(url) {
export async function getFile(url) {
// Helper function to get a file, using either the Fetch API or FileSystem API
if (env.useFS && !isValidHttpUrl(url)) {
@@ -198,7 +198,7 @@ async function getFile(url) {
* @param {any} data - The data to pass to the progress callback function.
* @returns {void}
*/
function dispatchCallback(progressCallback, data) {
export function dispatchCallback(progressCallback, data) {
if (progressCallback !== null) progressCallback(data);
}
@@ -293,7 +293,7 @@ async function getModelFile(modelPath, fileName, progressCallback = null, fatal
* @param {function} progressCallback - A callback function to receive progress updates. Optional.
* @returns {Promise<object>} - The JSON data parsed into a JavaScript object.
*/
async function fetchJSON(modelPath, fileName, progressCallback = null, fatal = true) {
export async function fetchJSON(modelPath, fileName, progressCallback = null, fatal = true) {
let buffer = await getModelFile(modelPath, fileName, progressCallback, fatal);
if (buffer === null) {
// Return empty object
@@ -369,7 +369,7 @@ async function readResponse(response, progressCallback) {
* @param {...string} parts - Multiple parts of a path.
* @returns {string} A string representing the joined path.
*/
function pathJoin(...parts) {
export function pathJoin(...parts) {
// https://stackoverflow.com/a/55142565
parts = parts.map((part, index) => {
if (index) {
@@ -390,7 +390,7 @@ function pathJoin(...parts) {
* @returns {object} The reversed object.
* @see https://ultimatecourses.com/blog/reverse-object-keys-and-values-in-javascript
*/
function reverseDictionary(data) {
export function reverseDictionary(data) {
// https://ultimatecourses.com/blog/reverse-object-keys-and-values-in-javascript
return Object.fromEntries(Object.entries(data).map(([key, value]) => [value, key]));
}
@@ -401,7 +401,7 @@ function reverseDictionary(data) {
* @see https://stackoverflow.com/a/11301464
* @returns {number} - The index of the maximum value in the array.
*/
function indexOfMax(arr) {
export function indexOfMax(arr) {
// https://stackoverflow.com/a/11301464
if (arr.length === 0) {
@@ -427,7 +427,7 @@ function indexOfMax(arr) {
* @param {number[]} arr - The array of numbers to compute the softmax of.
* @returns {number[]} The softmax array.
*/
function softmax(arr) {
export function softmax(arr) {
// Compute the maximum value in the array
const maxVal = max(arr);
@@ -448,7 +448,7 @@ function softmax(arr) {
* @param {number[]} arr - The input array to calculate the log_softmax function for.
* @returns {any} - The resulting log_softmax array.
*/
function log_softmax(arr) {
export function log_softmax(arr) {
// Compute the softmax values
const softmaxArr = softmax(arr);
@@ -464,7 +464,7 @@ function log_softmax(arr) {
* @param {string} string - The string to escape.
* @returns {string} - The escaped string.
*/
function escapeRegExp(string) {
export function escapeRegExp(string) {
return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
}
@@ -475,7 +475,7 @@ function escapeRegExp(string) {
* @param {number} [top_k=0] - The number of top items to return (default: 0 = return all)
* @returns {Array} - The top k items, sorted by descending order
*/
function getTopItems(items, top_k = 0) {
export function getTopItems(items, top_k = 0) {
// if top == 0, return all
items = Array.from(items)
@@ -495,7 +495,7 @@ function getTopItems(items, top_k = 0) {
* @param {number[]} arr2 - The second array.
* @returns {number} - The dot product of arr1 and arr2.
*/
function dot(arr1, arr2) {
export function dot(arr1, arr2) {
return arr1.reduce((acc, val, i) => acc + val * arr2[i], 0);
}
@@ -506,7 +506,7 @@ function dot(arr1, arr2) {
* @param {number[]} arr2 - The second array.
* @returns {number} The cosine similarity between the two arrays.
*/
function cos_sim(arr1, arr2) {
export function cos_sim(arr1, arr2) {
// Calculate dot product of the two arrays
const dotProduct = dot(arr1, arr2);
@@ -527,7 +527,7 @@ function cos_sim(arr1, arr2) {
* @param {number[]} arr - The array to calculate the magnitude of.
* @returns {number} The magnitude of the array.
*/
function magnitude(arr) {
export function magnitude(arr) {
return Math.sqrt(arr.reduce((acc, val) => acc + val * val, 0));
}
@@ -536,7 +536,7 @@ function magnitude(arr) {
*
* @extends Function
*/
class Callable extends Function {
export class Callable extends Function {
/**
* Creates a new instance of the Callable class.
*/
@@ -573,7 +573,7 @@ class Callable extends Function {
* @returns {number} - the minimum number.
* @throws {Error} If array is empty.
*/
function min(arr) {
export function min(arr) {
if (arr.length === 0) throw Error('Array must not be empty');
let min = arr[0];
for (let i = 1; i < arr.length; ++i) {
@@ -591,7 +591,7 @@ function min(arr) {
* @returns {number} - the maximum number.
* @throws {Error} If array is empty.
*/
function max(arr) {
export function max(arr) {
if (arr.length === 0) throw Error('Array must not be empty');
let max = arr[0];
for (let i = 1; i < arr.length; ++i) {
@@ -607,7 +607,7 @@ function max(arr) {
* @param {*} text - The value to check.
* @returns {boolean} - True if the value is a string, false otherwise.
*/
function isString(text) {
export function isString(text) {
return typeof text === 'string' || text instanceof String
}
@@ -616,7 +616,7 @@ function isString(text) {
* @param {*} x - The value to check.
* @returns {boolean} - True if the value is an integral number, false otherwise.
*/
function isIntegralNumber(x) {
export function isIntegralNumber(x) {
return Number.isInteger(x) || typeof x === 'bigint'
}
@@ -625,29 +625,10 @@ function exists(x) {
* @param {*} x - The value to check.
* @returns {boolean} - True if the value exists, false otherwise.
*/
function exists(x) {
export function exists(x) {
return x !== undefined && x !== null;
}
module.exports = {
Callable,
export {
getModelFile,
dispatchCallback,
fetchJSON,
pathJoin,
reverseDictionary,
indexOfMax,
softmax,
log_softmax,
escapeRegExp,
getTopItems,
dot,
cos_sim,
magnitude,
getFile,
isIntegralNumber,
isString,
exists,
min,
max,
};

View File

@@ -1,6 +1,8 @@
const path = require('path');
const { pipeline, env } = require('..');
import path from 'path';
import { pipeline, env } from '../src/transformers.js';
const __dirname = env.__dirname;
// Only use local models
env.remoteModels = false;
@@ -757,7 +759,7 @@ async function image_classification() {
async function image_segmentation() {
let segmenter = await pipeline('image-segmentation', 'facebook/detr-resnet-50-panoptic')
let img = path.join(__dirname, '../assets/images/cats.jpg')
let img = path.join(__dirname, './assets/images/cats.jpg')
let start = performance.now();
let outputs = await segmenter(img);