Reorganize generate method
This commit is contained in:
parent
d2b4fadb8e
commit
6e71ee125a
201
src/models.js
201
src/models.js
|
@ -1,6 +1,14 @@
|
|||
import { Callable, fetchJSON, pathJoin, indexOfMax, softmax } from "./utils.js";
|
||||
import {
|
||||
Callable,
|
||||
fetchJSON,
|
||||
pathJoin,
|
||||
indexOfMax,
|
||||
softmax
|
||||
} from "./utils.js";
|
||||
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
// Helper functions
|
||||
async function constructSession(path) {
|
||||
let response = await fetch(path, {
|
||||
cache: 'force-cache'
|
||||
|
@ -11,6 +19,8 @@ async function constructSession(path) {
|
|||
});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
// AutoModels, used to simplify construction of PreTrainedModels
|
||||
// (uses config to instantiate correct class)
|
||||
|
@ -83,6 +93,7 @@ class AutoModelForSeq2SeqLM {
|
|||
}
|
||||
//////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
// Base class
|
||||
class PreTrainedModel extends Callable {
|
||||
constructor(config, encoder_session) {
|
||||
|
@ -126,6 +137,83 @@ class PreTrainedModel extends Callable {
|
|||
throw Error("forward should be implemented in subclasses.")
|
||||
}
|
||||
|
||||
async generate(inputTokenIds, maxLength = 100, topK = 0, topP = 0, numBeams = 0) {
|
||||
|
||||
let attentionMask = new Array(inputTokenIds.length).fill(1);
|
||||
|
||||
let encoderOutputs = null;
|
||||
let pastKeyValues = null;
|
||||
let outputTokenIds = [this.config.decoder_start_token_id];
|
||||
let numOutputTokens = 1;
|
||||
const maxOutputTokens = numOutputTokens + maxLength;
|
||||
|
||||
let sampler = x => this.sampleLogitsGreedily(x);
|
||||
if (topK > 0) {
|
||||
sampler = x => this.sampleLogitsTopK(x, topK);
|
||||
}
|
||||
|
||||
while (numOutputTokens < maxOutputTokens) {
|
||||
|
||||
let output = await this.forward({
|
||||
input_ids: inputTokenIds,
|
||||
attention_mask: attentionMask,
|
||||
decoder_input_ids: outputTokenIds,
|
||||
encoder_outputs: encoderOutputs,
|
||||
past_key_values: pastKeyValues,
|
||||
});
|
||||
pastKeyValues = output.pastKeyValues;
|
||||
encoderOutputs = output.encoderOutputs;
|
||||
|
||||
let newTokenId = sampler(output.logits);
|
||||
outputTokenIds.push(newTokenId);
|
||||
++numOutputTokens;
|
||||
if (newTokenId === this.config.eos_token_id) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return outputTokenIds;
|
||||
}
|
||||
sampleLogitsGreedily(logits) {
|
||||
let shape = logits.dims;
|
||||
let [batchSize, seqLength, vocabSize] = shape;
|
||||
let n = batchSize * seqLength * vocabSize;
|
||||
let startIndex = n - vocabSize;
|
||||
let argmaxi = 0;
|
||||
let argmax = logits.data[startIndex + argmaxi];
|
||||
for (let i = 1; i < vocabSize; i++) {
|
||||
let l = logits.data[startIndex + i];
|
||||
if (l > argmax) {
|
||||
argmaxi = i;
|
||||
argmax = l;
|
||||
}
|
||||
}
|
||||
return argmaxi;
|
||||
}
|
||||
sampleLogitsTopK(logits, k) {
|
||||
let shape = logits.dims;
|
||||
let [batchSize, seqLength, vocabSize] = shape;
|
||||
let n = batchSize * seqLength * vocabSize;
|
||||
let startIndex = n - vocabSize;
|
||||
let logs = logits.data.slice(startIndex);
|
||||
k = Math.min(k, vocabSize);
|
||||
let logitAndId = Array.from(logs).map((x, i) => [x, i])
|
||||
.sort((a, b) => b[0] - a[0]);
|
||||
const sMin = Math.exp(-100.0);
|
||||
let sumS = 0.0;
|
||||
for (let i = 0; i < logitAndId.length; i++) {
|
||||
const s = i < k ? Math.exp(logitAndId[i][0]) : sMin;
|
||||
sumS += s;
|
||||
logitAndId[i][0] = s;
|
||||
}
|
||||
let r = Math.random() * sumS;
|
||||
for (let i = 0; i < logitAndId.length; i++) {
|
||||
r -= logitAndId[i][0];
|
||||
if (r <= 0) {
|
||||
return logitAndId[i][1];
|
||||
}
|
||||
}
|
||||
return logitAndId[0][1];
|
||||
}
|
||||
}
|
||||
//////////////////////////////////////////////////
|
||||
|
||||
|
@ -140,16 +228,10 @@ class DistilBertModel extends DistilBertPreTrainedModel { }
|
|||
class DistilBertForSequenceClassification extends DistilBertPreTrainedModel {
|
||||
async _call(model_inputs) {
|
||||
let logits = (await super._call(model_inputs)).logits.data;
|
||||
let predictionIndex = indexOfMax(logits);
|
||||
let score = softmax(logits)[predictionIndex];
|
||||
|
||||
return {
|
||||
logits: logits,
|
||||
prediction: this.config.id2label[predictionIndex],
|
||||
score: score
|
||||
};
|
||||
return new ClassificationOutput(this.config, logits)
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
|
@ -237,88 +319,24 @@ class T5ForConditionalGeneration extends T5PreTrainedModel {
|
|||
}
|
||||
return pkvs;
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
|
||||
async generate(inputTokenIds, maxLength = 100, topK = 0, topP = 0, numBeams = 0) {
|
||||
|
||||
let attentionMask = new Array(inputTokenIds.length).fill(1);
|
||||
|
||||
let encoderOutputs = null;
|
||||
let pastKeyValues = null;
|
||||
let outputTokenIds = [this.config.decoder_start_token_id];
|
||||
let numOutputTokens = 1;
|
||||
const maxOutputTokens = numOutputTokens + maxLength;
|
||||
|
||||
let sampler = x => this.sampleLogitsGreedily(x);
|
||||
if (topK > 0) {
|
||||
sampler = x => this.sampleLogitsTopK(x, topK);
|
||||
}
|
||||
|
||||
while (numOutputTokens < maxOutputTokens) {
|
||||
|
||||
let output = await this.forward({
|
||||
input_ids: inputTokenIds,
|
||||
attention_mask: attentionMask,
|
||||
decoder_input_ids: outputTokenIds,
|
||||
encoder_outputs: encoderOutputs,
|
||||
past_key_values: pastKeyValues,
|
||||
});
|
||||
pastKeyValues = output.pastKeyValues;
|
||||
encoderOutputs = output.encoderOutputs;
|
||||
|
||||
let newTokenId = sampler(output.logits);
|
||||
outputTokenIds.push(newTokenId);
|
||||
++numOutputTokens;
|
||||
if (newTokenId === this.config.eos_token_id) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return outputTokenIds;
|
||||
}
|
||||
|
||||
sampleLogitsGreedily(logits) {
|
||||
let shape = logits.dims;
|
||||
let [batchSize, seqLength, vocabSize] = shape;
|
||||
let n = batchSize * seqLength * vocabSize;
|
||||
let startIndex = n - vocabSize;
|
||||
let argmaxi = 0;
|
||||
let argmax = logits.data[startIndex + argmaxi];
|
||||
for (let i = 1; i < vocabSize; i++) {
|
||||
let l = logits.data[startIndex + i];
|
||||
if (l > argmax) {
|
||||
argmaxi = i;
|
||||
argmax = l;
|
||||
}
|
||||
}
|
||||
return argmaxi;
|
||||
}
|
||||
|
||||
sampleLogitsTopK(logits, k) {
|
||||
let shape = logits.dims;
|
||||
let [batchSize, seqLength, vocabSize] = shape;
|
||||
let n = batchSize * seqLength * vocabSize;
|
||||
let startIndex = n - vocabSize;
|
||||
let logs = logits.data.slice(startIndex);
|
||||
k = Math.min(k, vocabSize);
|
||||
let logitAndId = Array.from(logs).map((x, i) => [x, i])
|
||||
.sort((a, b) => b[0] - a[0]);
|
||||
const sMin = Math.exp(-100.0);
|
||||
let sumS = 0.0;
|
||||
for (let i = 0; i < logitAndId.length; i++) {
|
||||
const s = i < k ? Math.exp(logitAndId[i][0]) : sMin;
|
||||
sumS += s;
|
||||
logitAndId[i][0] = s;
|
||||
}
|
||||
let r = Math.random() * sumS;
|
||||
for (let i = 0; i < logitAndId.length; i++) {
|
||||
r -= logitAndId[i][0];
|
||||
if (r <= 0) {
|
||||
return logitAndId[i][1];
|
||||
}
|
||||
}
|
||||
return logitAndId[0][1];
|
||||
//////////////////////////////////////////////////
|
||||
// GPT2 models
|
||||
// Shared base for the GPT2 classes in this file; its body is empty, so it adds nothing to PreTrainedModel.
class GPT2PreTrainedModel extends PreTrainedModel { }
|
||||
/**
 * GPT2 model class whose `generate` always fails, because (per its own error
 * message) it has no language model head.
 */
class GPT2Model extends GPT2PreTrainedModel {
    /**
     * Always fails. Because the method is `async`, the error reaches the
     * caller as a rejected promise rather than a synchronous throw.
     *
     * @param {...*} args Ignored.
     * @throws {Error} Always, naming the GPT2 class to use instead.
     */
    async generate(...args) {
        throw new Error(
            "The current model class (GPT2Model) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'GPT2LMHeadModel'}"
        );
    }
}
|
||||
// Inherits everything from GPT2PreTrainedModel unchanged (empty body); presumably the GPT2 variant with a language model head — confirm.
class GPT2LMHeadModel extends GPT2PreTrainedModel { }
|
||||
// class GPT2ForSequenceClassification extends GPT2PreTrainedModel {
|
||||
// TODO
|
||||
// }
|
||||
//////////////////////////////////////////////////
|
||||
|
||||
class Seq2SeqLMOutput {
|
||||
|
@ -329,6 +347,15 @@ class Seq2SeqLMOutput {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Result of a classification call.
 *
 * Fields, in assignment order:
 *   - `logits`: the logits passed to the constructor, stored as given.
 *   - `prediction`: `indexOfMax(logits)`.
 *   - `score`: the `softmax(logits)` entry at `prediction`.
 *   - `label`: `modelConfig.id2label[prediction]`.
 */
class ClassificationOutput {
    /**
     * @param modelConfig Configuration object; only `id2label` is read.
     * @param logits Logits handed to `indexOfMax` and `softmax`.
     */
    constructor(modelConfig, logits) {
        const bestIndex = indexOfMax(logits);
        const probabilities = softmax(logits);

        Object.assign(this, {
            logits,
            prediction: bestIndex,
            score: probabilities[bestIndex],
            label: modelConfig.id2label[bestIndex],
        });
    }
}
|
||||
|
||||
export {
|
||||
AutoModel,
|
||||
AutoModelForSeq2SeqLM,
|
||||
|
|
Loading…
Reference in New Issue