import { pipeline } from '../src/transformers.js';
import { init, m, MAX_TEST_EXECUTION_TIME } from './init.js';

// Initialise the testing environment
init();

describe('Generation parameters', () => {

    // List all models which will be tested
    const models = [
        'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder
        'MBZUAI/LaMini-GPT-124M', // decoder-only
    ];

    // encoder-decoder model
    it(models[0], async () => {
        const text = 'how can I become more healthy?';

        const generator = await pipeline('text2text-generation', m(models[0]));
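        // `m()` (imported from ./init.js) presumably maps the model ID to its
        // test-ready equivalent; see init.js for the exact mapping.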

        // default
        // NOTE: Since `max_length` defaults to 20, this case also tests that.
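        // For `text2text-generation`, `generated_text` contains only the model's
        // output (no prompt), so re-encoding it should recover the decoder length.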
|
|
{
|
|
const outputs = await generator(text);
|
|
|
|
const tokens = generator.tokenizer.encode(outputs[0].generated_text)
|
|
expect(tokens.length).toEqual(20);
|
|
}

        // max_new_tokens
        {
            const MAX_NEW_TOKENS = 20;
            const outputs = await generator(text, {
                max_new_tokens: MAX_NEW_TOKENS,
            });

            const tokens = generator.tokenizer.encode(outputs[0].generated_text);
            expect(tokens.length).toEqual(MAX_NEW_TOKENS + 1); // + 1 due to forced BOS token
        }

        // min_length
        {
            // NOTE: Without setting `min_length` (but setting `max_new_tokens`),
            // 64 tokens are generated. So, the following tests are valid.
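            // For encoder-decoder models, `min_length` applies to the decoder
            // sequence only, so it can be compared directly against the
            // re-encoded output length.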
            const MAX_NEW_TOKENS = 128;
            const MIN_LENGTH = 65;
            const outputs = await generator(text, {
                max_new_tokens: MAX_NEW_TOKENS,
                min_length: MIN_LENGTH,
            });

            const tokens = generator.tokenizer.encode(outputs[0].generated_text);
            expect(tokens.length).toBeGreaterThanOrEqual(MIN_LENGTH);
        }

        // min_new_tokens
        {
            // NOTE: Without setting `min_new_tokens` (but setting `max_new_tokens`),
            // 64 tokens are generated. So, the following tests are valid.
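            // `min_new_tokens` forces generation past the model's natural stopping
            // point by suppressing the end-of-sequence token until at least this
            // many new tokens have been produced.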
            const MAX_NEW_TOKENS = 128;
            const MIN_NEW_TOKENS = 65;
            const outputs = await generator(text, {
                max_new_tokens: MAX_NEW_TOKENS,
                min_new_tokens: MIN_NEW_TOKENS,
            });

            const tokens = generator.tokenizer.encode(outputs[0].generated_text);
            expect(tokens.length).toBeGreaterThanOrEqual(MIN_NEW_TOKENS);
        }

        await generator.dispose();

    }, MAX_TEST_EXECUTION_TIME);

    // decoder-only model
    it(models[1], async () => {
        const text = "### Instruction:\nTrue or False: The earth is flat?\n\n### Response: ";

        const generator = await pipeline('text-generation', m(models[1]));

        // default
        // NOTE: Since `max_length` defaults to 20, this case also tests that.
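        // For `text-generation`, `generated_text` includes the prompt, so for this
        // decoder-only model the 20-token cap covers prompt + completion combined.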
        {
            const outputs = await generator(text);
            const tokens = generator.tokenizer.encode(outputs[0].generated_text);
            expect(tokens.length).toEqual(20);
        }

        // max_new_tokens
        {
            const MAX_NEW_TOKENS = 20;
            const outputs = await generator(text, {
                max_new_tokens: MAX_NEW_TOKENS,
            });
            const promptTokens = generator.tokenizer.encode(text);
            const tokens = generator.tokenizer.encode(outputs[0].generated_text);
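            // Only a loose lower bound is asserted here: re-tokenizing
            // prompt + completion is not guaranteed to be length-additive,
            // so the exact new-token count cannot be checked reliably.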
            expect(tokens.length).toBeGreaterThan(promptTokens.length);
        }

        // min_length
        {
            // NOTE: Without setting `min_length` (but setting `max_new_tokens`),
            // 22 tokens are generated. So, the following tests are valid.
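            // For decoder-only models, `min_length` counts the full sequence
            // (prompt + completion), so MIN_LENGTH below only needs to exceed
            // the 22 tokens produced by default.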
            const MAX_NEW_TOKENS = 10;
            const MIN_LENGTH = 25;
            const outputs = await generator(text, {
                max_new_tokens: MAX_NEW_TOKENS,
                min_length: MIN_LENGTH,
            });

            const tokens = generator.tokenizer.encode(outputs[0].generated_text);
            expect(tokens.length).toBeGreaterThanOrEqual(MIN_LENGTH);
        }

        // min_new_tokens
        {
            // NOTE: Without setting `min_new_tokens` (but setting `max_new_tokens`),
            // 22 tokens are generated. So, the following tests are valid.
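            // The assertion below assumes prompt + completion re-tokenizes to at
            // least promptTokens.length + MIN_NEW_TOKENS tokens, which holds when
            // tokenization is (near-)additive across the prompt boundary.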
            const MAX_NEW_TOKENS = 32;
            const MIN_NEW_TOKENS = 10;
            const outputs = await generator(text, {
                max_new_tokens: MAX_NEW_TOKENS,
                min_new_tokens: MIN_NEW_TOKENS,
            });

            const tokens = generator.tokenizer.encode(outputs[0].generated_text);
            const promptTokens = generator.tokenizer.encode(text);
            expect(tokens.length).toBeGreaterThanOrEqual(promptTokens.length + MIN_NEW_TOKENS);
        }

        await generator.dispose();

    }, MAX_TEST_EXECUTION_TIME);

});