Add temperature and discount factor parameters to generate function
This commit is contained in:
parent
6a17ba43bf
commit
2f16b4a595
|
@ -135,9 +135,11 @@ class PreTrainedModel extends Callable {
|
|||
max_length: 100,
|
||||
top_k: 0,
|
||||
num_beams: 1,
|
||||
temperature: 1,
|
||||
num_return_sequences: 1,
|
||||
early_stopping: false,
|
||||
do_sample: false,
|
||||
discount_factor: 1,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -209,6 +211,10 @@ class PreTrainedModel extends Callable {
|
|||
|
||||
// update new beam
|
||||
this.updateBeam(newBeam, newTokenId)
|
||||
|
||||
if (options.discount_factor < 1) {
|
||||
newBeam.score *= options.discount_factor;
|
||||
}
|
||||
newBeam.score += logProb;
|
||||
|
||||
if (newTokenId === this.config.eos_token_id) {
|
||||
|
@ -243,10 +249,18 @@ class PreTrainedModel extends Callable {
|
|||
|
||||
// TODO add beam
|
||||
if (options.num_beams > 1) {
|
||||
sampler = new BeamSearchSampler(options.num_beams, options.do_sample, options.top_k)
|
||||
sampler = new BeamSearchSampler(
|
||||
options.num_beams,
|
||||
options.do_sample,
|
||||
options.top_k,
|
||||
options.temperature
|
||||
)
|
||||
|
||||
} else if (options.top_k > 0) {
|
||||
sampler = new TopKSampler(options.top_k)
|
||||
sampler = new TopKSampler(
|
||||
options.top_k,
|
||||
options.temperature
|
||||
)
|
||||
} else {
|
||||
sampler = new GreedySampler()
|
||||
}
|
||||
|
|
|
@ -27,6 +27,13 @@ class Sampler extends Callable {
|
|||
return logs;
|
||||
}
|
||||
|
||||
addTemperature(logits, temperature){
|
||||
if (temperature >= 1) {
|
||||
return logits;
|
||||
}
|
||||
return logits.map(x => x * temperature);
|
||||
}
|
||||
|
||||
getTopLogits(logits, top_k = 0) {
|
||||
// if top == 0, return all
|
||||
|
||||
|
@ -72,10 +79,11 @@ class GreedySampler extends Sampler {
|
|||
}
|
||||
|
||||
class TopKSampler extends Sampler {
|
||||
constructor(k) {
|
||||
constructor(k, temperature) {
|
||||
super();
|
||||
|
||||
this.k = k;
|
||||
this.temperature = temperature;
|
||||
}
|
||||
|
||||
sample(logits) {
|
||||
|
@ -83,6 +91,7 @@ class TopKSampler extends Sampler {
|
|||
let k = Math.min(this.k, vocabSize);
|
||||
|
||||
let logs = this.getLastLogits(logits);
|
||||
logs = this.addTemperature(logs, this.temperature);
|
||||
|
||||
// Get top k tokens
|
||||
let topLogits = this.getTopLogits(logs, k);
|
||||
|
@ -101,21 +110,23 @@ class TopKSampler extends Sampler {
|
|||
}
|
||||
|
||||
class BeamSearchSampler extends Sampler {
|
||||
constructor(num_beams, do_sample, top_k) {
|
||||
constructor(num_beams, do_sample, top_k, temperature) {
|
||||
super();
|
||||
this.num_beams = num_beams; // maximum number of beams
|
||||
this.do_sample = do_sample; // if true, perform multinomial sampling
|
||||
|
||||
this.top_k = top_k; // if do_sample, sample from top k items
|
||||
this.temperature = temperature;
|
||||
}
|
||||
|
||||
sample(logits) {
|
||||
|
||||
const logs = this.getLastLogits(logits);
|
||||
let logs = this.getLastLogits(logits);
|
||||
logs = this.addTemperature(logs, this.temperature);
|
||||
|
||||
if (this.do_sample) {
|
||||
if (this.do_sample || this.top_k > 0) {
|
||||
const [batchSize, seqLength, vocabSize] = logits.dims;
|
||||
const k = Math.min(this.k, vocabSize);
|
||||
const k = Math.min(this.top_k, vocabSize);
|
||||
const topLogits = this.getTopLogits(logs, k);
|
||||
|
||||
// Compute softmax over top k logits
|
||||
|
|
Loading…
Reference in New Issue