Reorganize generate method

Joshua Lochner 2023-02-16 20:20:17 +02:00
parent d2b4fadb8e
commit 6e71ee125a
1 changed file with 114 additions and 87 deletions


@@ -1,6 +1,14 @@
import { Callable, fetchJSON, pathJoin, indexOfMax, softmax } from "./utils.js";
import {
    Callable,
    fetchJSON,
    pathJoin,
    indexOfMax,
    softmax
} from "./utils.js";
//////////////////////////////////////////////////
// Helper functions
async function constructSession(path) {
    let response = await fetch(path, {
        cache: 'force-cache'
@@ -11,6 +19,8 @@ async function constructSession(path) {
    });
}
//////////////////////////////////////////////////
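The constructSession helper above (its middle is elided by the hunk) fetches the ONNX weights with a 'force-cache' hint and wraps them in an inference session. A minimal sketch of how such a helper plausibly looks, assuming onnxruntime-web is in use; the library and the session call are assumptions, only the fetch options are shown in this diff:

// Sketch only: onnxruntime-web ("ort") and InferenceSession.create are
// assumptions here; the diff only shows the fetch call and its cache option.
import * as ort from 'onnxruntime-web';

async function constructSessionSketch(path) {
    let response = await fetch(path, {
        cache: 'force-cache'
    });
    let buffer = await response.arrayBuffer();
    return await ort.InferenceSession.create(buffer);
}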
//////////////////////////////////////////////////
// AutoModels, used to simplify construction of PreTrainedModels
// (uses config to instantiate correct class)
@@ -83,6 +93,7 @@ class AutoModelForSeq2SeqLM {
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// Base class
class PreTrainedModel extends Callable {
    constructor(config, encoder_session) {
@@ -126,6 +137,83 @@ class PreTrainedModel extends Callable {
        throw Error("forward should be implemented in subclasses.")
    }
    async generate(inputTokenIds, maxLength = 100, topK = 0, topP = 0, numBeams = 0) {
        let attentionMask = new Array(inputTokenIds.length).fill(1);
        let encoderOutputs = null;
        let pastKeyValues = null;
        let outputTokenIds = [this.config.decoder_start_token_id];
        let numOutputTokens = 1;
        const maxOutputTokens = numOutputTokens + maxLength;
        let sampler = x => this.sampleLogitsGreedily(x);
        if (topK > 0) {
            sampler = x => this.sampleLogitsTopK(x, topK);
        }
        while (numOutputTokens < maxOutputTokens) {
            let output = await this.forward({
                input_ids: inputTokenIds,
                attention_mask: attentionMask,
                decoder_input_ids: outputTokenIds,
                encoder_outputs: encoderOutputs,
                past_key_values: pastKeyValues,
            });
            pastKeyValues = output.pastKeyValues;
            encoderOutputs = output.encoderOutputs;
            let newTokenId = sampler(output.logits);
            outputTokenIds.push(newTokenId);
            ++numOutputTokens;
            if (newTokenId === this.config.eos_token_id) {
                break;
            }
        }
        return outputTokenIds;
    }
    sampleLogitsGreedily(logits) {
        let shape = logits.dims;
        let [batchSize, seqLength, vocabSize] = shape;
        let n = batchSize * seqLength * vocabSize;
        let startIndex = n - vocabSize;
        let argmaxi = 0;
        let argmax = logits.data[startIndex + argmaxi];
        for (let i = 1; i < vocabSize; i++) {
            let l = logits.data[startIndex + i];
            if (l > argmax) {
                argmaxi = i;
                argmax = l;
            }
        }
        return argmaxi;
    }
    sampleLogitsTopK(logits, k) {
        let shape = logits.dims;
        let [batchSize, seqLength, vocabSize] = shape;
        let n = batchSize * seqLength * vocabSize;
        let startIndex = n - vocabSize;
        let logs = logits.data.slice(startIndex);
        k = Math.min(k, vocabSize);
        let logitAndId = Array.from(logs).map((x, i) => [x, i])
            .sort((a, b) => b[0] - a[0]);
        const sMin = Math.exp(-100.0);
        let sumS = 0.0;
        for (let i = 0; i < logitAndId.length; i++) {
            const s = i < k ? Math.exp(logitAndId[i][0]) : sMin;
            sumS += s;
            logitAndId[i][0] = s;
        }
        let r = Math.random() * sumS;
        for (let i = 0; i < logitAndId.length; i++) {
            r -= logitAndId[i][0];
            if (r <= 0) {
                return logitAndId[i][1];
            }
        }
        return logitAndId[0][1];
    }
}
//////////////////////////////////////////////////
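With this commit, generate() and the two samplers live on the PreTrainedModel base class, so any subclass that implements forward() inherits greedy decoding (topK = 0, the default) and top-k sampling (topK > 0), where sampleLogitsTopK softmaxes only the k largest logits of the final position and samples from that distribution. A hedged usage sketch follows; the module path, the from_pretrained loader, and the token IDs are placeholders or assumptions, since only generate()'s signature appears in this diff:

// Usage sketch (assumptions marked): './models.js' and from_pretrained are
// not shown in this diff; the token IDs are placeholders, not real encodings.
import { AutoModelForSeq2SeqLM } from './models.js';

let model = await AutoModelForSeq2SeqLM.from_pretrained('/models/t5-small');

let inputTokenIds = [13959, 1566, 12, 2968, 10, 1]; // placeholder IDs from a tokenizer

// Greedy decoding: topK = 0 (default) takes the argmax token at each step.
let greedyIds = await model.generate(inputTokenIds, 50);

// Top-k sampling: restrict each step to the 5 highest-probability tokens.
let sampledIds = await model.generate(inputTokenIds, 50, 5);

Note that topP and numBeams are accepted by this version of generate() but are not yet used in its body.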
@@ -140,16 +228,10 @@ class DistilBertModel extends DistilBertPreTrainedModel { }
class DistilBertForSequenceClassification extends DistilBertPreTrainedModel {
    async _call(model_inputs) {
        let logits = (await super._call(model_inputs)).logits.data;
        let predictionIndex = indexOfMax(logits);
        let score = softmax(logits)[predictionIndex];
        return {
            logits: logits,
            prediction: this.config.id2label[predictionIndex],
            score: score
        };
        return new ClassificationOutput(this.config, logits)
    }
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@@ -237,88 +319,24 @@ class T5ForConditionalGeneration extends T5PreTrainedModel {
        }
        return pkvs;
    }
}
//////////////////////////////////////////////////
    async generate(inputTokenIds, maxLength = 100, topK = 0, topP = 0, numBeams = 0) {
        let attentionMask = new Array(inputTokenIds.length).fill(1);
        let encoderOutputs = null;
        let pastKeyValues = null;
        let outputTokenIds = [this.config.decoder_start_token_id];
        let numOutputTokens = 1;
        const maxOutputTokens = numOutputTokens + maxLength;
        let sampler = x => this.sampleLogitsGreedily(x);
        if (topK > 0) {
            sampler = x => this.sampleLogitsTopK(x, topK);
        }
        while (numOutputTokens < maxOutputTokens) {
            let output = await this.forward({
                input_ids: inputTokenIds,
                attention_mask: attentionMask,
                decoder_input_ids: outputTokenIds,
                encoder_outputs: encoderOutputs,
                past_key_values: pastKeyValues,
            });
            pastKeyValues = output.pastKeyValues;
            encoderOutputs = output.encoderOutputs;
            let newTokenId = sampler(output.logits);
            outputTokenIds.push(newTokenId);
            ++numOutputTokens;
            if (newTokenId === this.config.eos_token_id) {
                break;
            }
        }
        return outputTokenIds;
    }
    sampleLogitsGreedily(logits) {
        let shape = logits.dims;
        let [batchSize, seqLength, vocabSize] = shape;
        let n = batchSize * seqLength * vocabSize;
        let startIndex = n - vocabSize;
        let argmaxi = 0;
        let argmax = logits.data[startIndex + argmaxi];
        for (let i = 1; i < vocabSize; i++) {
            let l = logits.data[startIndex + i];
            if (l > argmax) {
                argmaxi = i;
                argmax = l;
            }
        }
        return argmaxi;
    }
    sampleLogitsTopK(logits, k) {
        let shape = logits.dims;
        let [batchSize, seqLength, vocabSize] = shape;
        let n = batchSize * seqLength * vocabSize;
        let startIndex = n - vocabSize;
        let logs = logits.data.slice(startIndex);
        k = Math.min(k, vocabSize);
        let logitAndId = Array.from(logs).map((x, i) => [x, i])
            .sort((a, b) => b[0] - a[0]);
        const sMin = Math.exp(-100.0);
        let sumS = 0.0;
        for (let i = 0; i < logitAndId.length; i++) {
            const s = i < k ? Math.exp(logitAndId[i][0]) : sMin;
            sumS += s;
            logitAndId[i][0] = s;
        }
        let r = Math.random() * sumS;
        for (let i = 0; i < logitAndId.length; i++) {
            r -= logitAndId[i][0];
            if (r <= 0) {
                return logitAndId[i][1];
            }
        }
        return logitAndId[0][1];
//////////////////////////////////////////////////
// GPT2 models
class GPT2PreTrainedModel extends PreTrainedModel { }
class GPT2Model extends GPT2PreTrainedModel {
    async generate(...args) {
        throw Error(
            "The current model class (GPT2Model) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'T5ForConditionalGeneration'}"
        )
    }
}
class GPT2LMHeadModel extends GPT2PreTrainedModel { }
// class GPT2ForSequenceClassification extends GPT2PreTrainedModel {
// TODO
// }
//////////////////////////////////////////////////
class Seq2SeqLMOutput {
@@ -329,6 +347,15 @@ class Seq2SeqLMOutput {
    }
}
class ClassificationOutput {
    constructor(modelConfig, logits) {
        this.logits = logits;
        this.prediction = indexOfMax(logits);
        this.score = softmax(logits)[this.prediction];
        this.label = modelConfig.id2label[this.prediction];
    }
}
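ClassificationOutput centralizes the post-processing that DistilBertForSequenceClassification previously did inline; note that prediction is now the argmax index, while the human-readable string moves to the new label field. A small sketch with a stand-in config (the id2label values below are placeholders, not from the diff):

// Stand-in config; real id2label mappings come from the model's config.json.
let demoConfig = { id2label: { 0: 'NEGATIVE', 1: 'POSITIVE' } };
let out = new ClassificationOutput(demoConfig, [-1.2, 3.4]);
// out.prediction -> 1 (index of the largest logit, via indexOfMax)
// out.score      -> softmax([-1.2, 3.4])[1], roughly 0.99
// out.label      -> 'POSITIVE'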
export {
    AutoModel,
    AutoModelForSeq2SeqLM,