Add `T5ForConditionalGeneration` model

Joshua Lochner 2023-02-16 00:14:20 +02:00
parent e0a88f9453
commit 36ec1431f9
2 changed files with 256 additions and 21 deletions
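
In summary: `AutoModel.from_pretrained` now loads its config and ONNX session in parallel and recognizes `t5`, and a new `AutoModelForSeq2SeqLM` factory returns a `T5ForConditionalGeneration` backed by three ONNX Runtime sessions (encoder, first decoder step, and decoder-with-past). A minimal usage sketch of the new API follows; the model directory and the token ids are placeholders, not part of this commit, and `from_pretrained()` assumes the directory contains `config.json`, `encoder_model.onnx`, `decoder_model.onnx`, and `decoder_with_past_model.onnx`:

    import { AutoModelForSeq2SeqLM } from './models.js';

    // Hypothetical model directory and pre-tokenized input ids (placeholder values).
    async function demo() {
        const model = await AutoModelForSeq2SeqLM.from_pretrained('./models/t5-small');

        // generate() takes a flat array of input token ids; decoding starts from
        // config.decoder_start_token_id and stops at eos_token_id or maxLength.
        const outputTokenIds = await model.generate([13959, 1566, 12, 2379, 10, 1], 50);
        console.log(outputTokenIds);
    }
    demo();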

@@ -16,12 +16,18 @@ class AutoModel {
     static async from_pretrained(modelPath) {
-        let config = await fetchJSON(pathJoin(modelPath, 'config.json'));
-        let session = await constructSession(pathJoin(modelPath, 'model.onnx'));
+        let [config, session] = await Promise.all([
+            fetchJSON(pathJoin(modelPath, 'config.json')),
+            constructSession(pathJoin(modelPath, 'model.onnx'))
+        ]);
 
         switch (config.model_type) {
             case 'bert':
                 return new BertModel(config, session);
             case 'distilbert':
                 return new DistilBertModel(config, session);
+            case 't5':
+                return new T5Model(config, session);
             default:
                 console.warn(`Unknown model class "${config.model_type}", attempting to construct from base class.`);
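
The hunk above swaps two sequential `await`s for one `Promise.all`, so the `config.json` fetch and the ONNX session construction overlap instead of running back to back; the hunks below apply the same pattern to the four files a seq2seq model needs. Schematically, using the module's own `fetchJSON`/`constructSession` helpers (`configURL` and `modelURL` are placeholders):

    // Before: total latency ≈ t(config) + t(session)
    //   let config = await fetchJSON(configURL);
    //   let session = await constructSession(modelURL);

    // After: total latency ≈ max(t(config), t(session)), and a failure of
    // either load rejects the combined promise immediately.
    let [config, session] = await Promise.all([
        fetchJSON(configURL),
        constructSession(modelURL)
    ]);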
@@ -30,44 +36,256 @@ class AutoModel {
     }
 }
 
+class AutoModelForSeq2SeqLM {
+    static async from_pretrained(modelPath) {
+        let [config, encoder_session, decoder_session, init_decoder_session] = await Promise.all([
+            fetchJSON(pathJoin(modelPath, 'config.json')),
+            constructSession(pathJoin(modelPath, 'encoder_model.onnx')),
+            constructSession(pathJoin(modelPath, 'decoder_model.onnx')),
+            constructSession(pathJoin(modelPath, 'decoder_with_past_model.onnx'))
+        ])
+
+        switch (config.model_type) {
+            case 't5':
+                return new T5ForConditionalGeneration(
+                    config,
+                    encoder_session,
+                    decoder_session,
+                    init_decoder_session
+                );
+            default:
+                throw Error(`Unsupported model type: ${config.model_type}`)
+        }
+    }
+}
+
 class PreTrainedModel extends Callable {
-    constructor(config, session) {
+    constructor(config, encoder_session) {
         super();
         this.config = config;
-        this.session = session;
+        this.encoder_session = encoder_session;
     }
 
     static async from_pretrained(modelPath) {
         // Load model
-        let config = await fetchJSON(pathJoin(modelPath, 'config.json'));
-        let session = await constructSession(pathJoin(modelPath, 'model.onnx'));
+        let [config, session] = await Promise.all([
+            fetchJSON(pathJoin(modelPath, 'config.json')),
+            constructSession(pathJoin(modelPath, 'model.onnx'))
+        ]);
 
         return new this(config, session);
     }
 
-    async _call(model_input) {
-        // TODO allow batched inputs
-        for (let [key, value] of Object.entries(model_input)) {
-            if (Array.isArray(value)) {
-                // convert arrays to tensor
-                model_input[key] = new ort.Tensor('int64',
+    prepare_inputs(model_inputs) {
+        // TODO improve
+        for (let [key, value] of Object.entries(model_inputs)) {
+            if (Array.isArray(value) && value && Number.isInteger(value[0])) {
+                // convert integer arrays to tensor
+                model_inputs[key] = new ort.Tensor('int64',
                     BigInt64Array.from(value.map(x => BigInt(x))),
                     [1, value.length]
                 );
             }
         }
+        return model_inputs;
+    }
+
+    async _call(model_inputs) {
+        // TODO allow batched inputs
-        return await this.session.run(model_input);
+        model_inputs = this.prepare_inputs(model_inputs)
+        return await this.encoder_session.run(model_inputs);
     }
+
+    async forward(model_inputs) {
+        throw Error("forward should be implemented in subclasses.")
+    }
 }
-class DistilBertModel extends PreTrainedModel {
+class BertModel extends PreTrainedModel { }
+class DistilBertModel extends PreTrainedModel { }
+
+class T5PreTrainedModel extends PreTrainedModel {
+};
+
+class T5Model extends T5PreTrainedModel {
+    async generate(...args) {
+        throw Error(
+            "The current model class (T5Model) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'T5ForConditionalGeneration'}"
+        )
+    }
+}
 
-export { AutoModel };
+class T5ForConditionalGeneration extends T5PreTrainedModel {
+    constructor(config, encoder_session, decoder_session, init_decoder_session) {
+        super(config, encoder_session);
+        this.decoder_session = decoder_session;
+        this.init_decoder_session = init_decoder_session;
+    }
+
+    static async from_pretrained(modelPath) {
+        // TODO optimize? Lots of overlap between decoder and init_decoder
+        let [config, encoder_session, decoder_session, init_decoder_session] = await Promise.all([
+            fetchJSON(pathJoin(modelPath, 'config.json')),
+            constructSession(pathJoin(modelPath, 'encoder_model.onnx')),
+            constructSession(pathJoin(modelPath, 'decoder_model.onnx')),
+            constructSession(pathJoin(modelPath, 'decoder_with_past_model.onnx'))
+        ])
+
+        return new this(config, encoder_session, decoder_session, init_decoder_session);
+    }
+
+    async forward(model_inputs) {
+        model_inputs = this.prepare_inputs(model_inputs)
+
+        let inputIdsTensor = model_inputs.input_ids;
+        let encoderAttentionMaskTensor = model_inputs.attention_mask;
+        let decoderInputIdsTensor = model_inputs.decoder_input_ids;
+        let encoderOutputs = model_inputs.encoder_outputs;
+        let pastKeyValues = model_inputs.past_key_values;
+
+        if (encoderOutputs === null) {
+            const encoderFeeds = {
+                "input_ids": inputIdsTensor,
+                "attention_mask": encoderAttentionMaskTensor,
+            }
+            const encoderResults = await this.encoder_session.run(encoderFeeds);
+            const encoderHiddenStates = encoderResults.hidden_states;
+            encoderOutputs = encoderHiddenStates;
+        }
+
+        // const decoderAttentionMaskTensor = new ort.Tensor("int64", new BigInt64Array(decoderInputIds.length).fill(1n), [1, decoderInputIds.length]);
+        const decoderFeeds = {
+            "input_ids": decoderInputIdsTensor,
+            "encoder_attention_mask": encoderAttentionMaskTensor,
+            "encoder_hidden_states": encoderOutputs,
+        };
+
+        let logits = null;
+
+        if (pastKeyValues === null) {
+            const initDecoderResults = await this.init_decoder_session.run(decoderFeeds);
+            logits = initDecoderResults.logits;
+            pastKeyValues = this.getPastKeyValues(this.init_decoder_session.outputNames.slice(1), initDecoderResults);
+        } else {
+            for (const [k, v] of pastKeyValues) {
+                decoderFeeds[k] = v;
+            }
+            const decoderResults = await this.decoder_session.run(decoderFeeds);
+            logits = decoderResults.logits;
+            pastKeyValues = this.getPastKeyValues(this.decoder_session.outputNames.slice(1), decoderResults);
+        }
+
+        return new Seq2SeqLMOutput(logits, pastKeyValues, encoderOutputs);
+    }
+
+    getPastKeyValues(pkvNames, decoderResults) {
+        const pkvs = [];
+
+        for (const i in pkvNames) {
+            const k = pkvNames[i];
+            const v = decoderResults[k];
+            pkvs.push([`pkv_${i}`, v]);
+        }
+        return pkvs;
+    }
+
+    async generate(inputTokenIds, maxLength = 100, topK = 0, topP = 0, numBeams = 0) {
+        let attentionMask = new Array(inputTokenIds.length).fill(1);
+
+        let encoderOutputs = null;
+        let pastKeyValues = null;
+        let outputTokenIds = [this.config.decoder_start_token_id];
+        let numOutputTokens = 1;
+        const maxOutputTokens = numOutputTokens + maxLength;
+
+        let sampler = x => this.sampleLogitsGreedily(x);
+        if (topK > 0) {
+            sampler = x => this.sampleLogitsTopK(x, topK);
+        }
+
+        while (numOutputTokens < maxOutputTokens) {
+            let output = await this.forward({
+                input_ids: inputTokenIds,
+                attention_mask: attentionMask,
+                decoder_input_ids: outputTokenIds,
+                encoder_outputs: encoderOutputs,
+                past_key_values: pastKeyValues,
+            });
+
+            pastKeyValues = output.pastKeyValues;
+            encoderOutputs = output.encoderOutputs;
+
+            let newTokenId = sampler(output.logits);
+            outputTokenIds.push(newTokenId);
+            ++numOutputTokens;
+
+            if (newTokenId === this.config.eos_token_id) {
+                break;
+            }
+        }
+        return outputTokenIds;
+    }
+
+    sampleLogitsGreedily(logits) {
+        let shape = logits.dims;
+        let [batchSize, seqLength, vocabSize] = shape;
+        let n = batchSize * seqLength * vocabSize;
+        let startIndex = n - vocabSize;
+
+        let argmaxi = 0;
+        let argmax = logits.data[startIndex + argmaxi];
+        for (let i = 1; i < vocabSize; i++) {
+            let l = logits.data[startIndex + i];
+            if (l > argmax) {
+                argmaxi = i;
+                argmax = l;
+            }
+        }
+        return argmaxi;
+    }
+
+    sampleLogitsTopK(logits, k) {
+        let shape = logits.dims;
+        let [batchSize, seqLength, vocabSize] = shape;
+        let n = batchSize * seqLength * vocabSize;
+        let startIndex = n - vocabSize;
+        let logs = logits.data.slice(startIndex);
+        k = Math.min(k, vocabSize);
+        let logitAndId = Array.from(logs).map((x, i) => [x, i])
+            .sort((a, b) => b[0] - a[0]);
+        const sMin = Math.exp(-100.0);
+        let sumS = 0.0;
+        for (let i = 0; i < logitAndId.length; i++) {
+            const s = i < k ? Math.exp(logitAndId[i][0]) : sMin;
+            sumS += s;
+            logitAndId[i][0] = s;
+        }
+        let r = Math.random() * sumS;
+        for (let i = 0; i < logitAndId.length; i++) {
+            r -= logitAndId[i][0];
+            if (r <= 0) {
+                return logitAndId[i][1];
+            }
+        }
+        return logitAndId[0][1];
+    }
+}
+
+class Seq2SeqLMOutput {
+    constructor(logits, pastKeyValues, encoderOutputs) {
+        this.logits = logits;
+        this.pastKeyValues = pastKeyValues;
+        this.encoderOutputs = encoderOutputs;
+    }
+}
+
+export {
+    AutoModel,
+    AutoModelForSeq2SeqLM,
+    T5ForConditionalGeneration
+};
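
Two implementation notes on the class above. First, the decode loop: the initial `forward()` call sees `encoder_outputs === null` and `past_key_values === null`, so it runs the encoder once and scores the first decoder step with `init_decoder_session`; every subsequent step reuses the cached `encoderOutputs` and routes through `decoder_session`, feeding back the tensors that `getPastKeyValues` collected under the names `pkv_0`, `pkv_1`, and so on. Those names must match the input names of the exported `decoder_with_past_model.onnx` graph, which is an assumption the export process has to uphold.

Second, `sampleLogitsTopK` only ever looks at the final position of the logits tensor (`startIndex = n - vocabSize`). It keeps mass `exp(logit)` for the top `k` entries, clamps everything else to `exp(-100)` (effectively zero), then draws once uniformly and walks the unnormalized masses: inverse-CDF sampling without an explicit softmax normalization. A standalone toy version of the same scheme (a hypothetical helper, not part of this commit):

    // Toy re-implementation of the commit's top-k sampling over a plain array.
    function sampleTopK(logits, k) {
        const sorted = logits.map((x, i) => [x, i]).sort((a, b) => b[0] - a[0]);
        const sMin = Math.exp(-100.0); // near-zero mass outside the top k
        const kk = Math.min(k, logits.length);
        let sum = 0.0;
        for (let i = 0; i < sorted.length; i++) {
            sorted[i][0] = i < kk ? Math.exp(sorted[i][0]) : sMin;
            sum += sorted[i][0];
        }
        let r = Math.random() * sum; // single uniform draw over unnormalized masses
        for (const [mass, id] of sorted) {
            r -= mass;
            if (r <= 0) return id;
        }
        return sorted[0][1]; // floating-point edge case, mirroring the class method
    }

    console.log(sampleTopK([2.0, 1.0, 0.5, -3.0], 2)); // ~73% token 0, ~27% token 1

Because the tail mass is `exp(-100)` rather than exactly zero, a token outside the top `k` is still technically reachable, but with vanishing probability.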

@@ -1,13 +1,30 @@
-import { AutoTokenizer, T5Tokenizer } from "./tokenizers.js";
-import { AutoModel } from "./models.js";
+import {
+    AutoTokenizer,
+    T5Tokenizer
+} from "./tokenizers.js";
+
+import {
+    AutoModel,
+    AutoModelForSeq2SeqLM,
+    T5ForConditionalGeneration
+} from "./models.js";
 
 // Allow global access to these variables
 window.AutoTokenizer = AutoTokenizer
 window.T5Tokenizer = T5Tokenizer
 window.AutoModel = AutoModel
+window.AutoModelForSeq2SeqLM = AutoModelForSeq2SeqLM
+window.T5ForConditionalGeneration = T5ForConditionalGeneration
 
-export { AutoTokenizer, T5Tokenizer, AutoModel };
+export {
+    // Tokenizers
+    AutoTokenizer,
+    T5Tokenizer,
+
+    // Models
+    AutoModel,
+    AutoModelForSeq2SeqLM,
+    T5ForConditionalGeneration
+};
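
This second file (the library's browser entry point; its path is not shown in this view) both attaches the new classes to `window` and re-exports them, so they are reachable from plain-script contexts as well as ES module imports. A hypothetical page-level usage, assuming the entry file is served as `transformers.js` and a model directory exists at the path shown:

    // In a <script type="module"> on the page:
    import './transformers.js'; // side effect: runs the window.* assignments above

    const model = await window.AutoModelForSeq2SeqLM.from_pretrained('./models/t5-small');
    const ids = await model.generate([13959, 1566, 12, 2379, 10, 1]); // placeholder ids
    console.log(ids);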