Add temperature and discount factor parameters to generate function

This commit is contained in:
Joshua Lochner 2023-02-19 17:15:30 +02:00
parent 6a17ba43bf
commit 2f16b4a595
2 changed files with 32 additions and 7 deletions

View File

@ -135,9 +135,11 @@ class PreTrainedModel extends Callable {
max_length: 100,
top_k: 0,
num_beams: 1,
temperature: 1,
num_return_sequences: 1,
early_stopping: false,
do_sample: false,
discount_factor: 1,
}
}
@ -209,6 +211,10 @@ class PreTrainedModel extends Callable {
// update new beam
this.updateBeam(newBeam, newTokenId)
if (options.discount_factor < 1) {
newBeam.score *= options.discount_factor;
}
newBeam.score += logProb;
if (newTokenId === this.config.eos_token_id) {
@ -243,10 +249,18 @@ class PreTrainedModel extends Callable {
// TODO add beam
if (options.num_beams > 1) {
sampler = new BeamSearchSampler(options.num_beams, options.do_sample, options.top_k)
sampler = new BeamSearchSampler(
options.num_beams,
options.do_sample,
options.top_k,
options.temperature
)
} else if (options.top_k > 0) {
sampler = new TopKSampler(options.top_k)
sampler = new TopKSampler(
options.top_k,
options.temperature
)
} else {
sampler = new GreedySampler()
}

View File

@ -27,6 +27,13 @@ class Sampler extends Callable {
return logs;
}
addTemperature(logits, temperature){
if (temperature >= 1) {
return logits;
}
return logits.map(x => x * temperature);
}
getTopLogits(logits, top_k = 0) {
// if top == 0, return all
@ -72,10 +79,11 @@ class GreedySampler extends Sampler {
}
class TopKSampler extends Sampler {
constructor(k) {
constructor(k, temperature) {
super();
this.k = k;
this.temperature = temperature;
}
sample(logits) {
@ -83,6 +91,7 @@ class TopKSampler extends Sampler {
let k = Math.min(this.k, vocabSize);
let logs = this.getLastLogits(logits);
logs = this.addTemperature(logs, this.temperature);
// Get top k tokens
let topLogits = this.getTopLogits(logs, k);
@ -101,21 +110,23 @@ class TopKSampler extends Sampler {
}
class BeamSearchSampler extends Sampler {
constructor(num_beams, do_sample, top_k) {
constructor(num_beams, do_sample, top_k, temperature) {
super();
this.num_beams = num_beams; // maximum number of beams
this.do_sample = do_sample; // if true, perform multinomial sampling
this.top_k = top_k; // if do_sample, sample from top k items
this.temperature = temperature;
}
sample(logits) {
const logs = this.getLastLogits(logits);
let logs = this.getLastLogits(logits);
logs = this.addTemperature(logs, this.temperature);
if (this.do_sample) {
if (this.do_sample || this.top_k > 0) {
const [batchSize, seqLength, vocabSize] = logits.dims;
const k = Math.min(this.k, vocabSize);
const k = Math.min(this.top_k, vocabSize);
const topLogits = this.getTopLogits(logs, k);
// Compute softmax over top k logits