Improve support of conversational models (#658)
* Add `return_full_text` option for text-generation models * [wip] Support chat inputs in text-generation pipeline * Align return type with python version * Remove conversational task (moved to text-generation) * Fix typos
This commit is contained in:
parent
aa542cf548
commit
c2c45cb577
|
@ -198,7 +198,6 @@ You can refine your search by selecting the task you're interested in (e.g., [te
|
|||
|
||||
| Task | ID | Description | Supported? |
|
||||
|--------------------------|----|-------------|------------|
|
||||
| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
|
||||
| [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
|
||||
| [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
|
||||
| [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
| Task | ID | Description | Supported? |
|
||||
|--------------------------|----|-------------|------------|
|
||||
| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
|
||||
| [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
|
||||
| [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
|
||||
| [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |
|
||||
|
|
|
@ -841,18 +841,24 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC
|
|||
}
|
||||
}
|
||||
|
||||
function isChat(x) {
|
||||
return Array.isArray(x) && x.every(x => 'role' in x && 'content' in x);
|
||||
}
|
||||
|
||||
/**
|
||||
* @typedef {import('./tokenizers.js').Message[]} Chat
|
||||
*
|
||||
* @typedef {Object} TextGenerationSingle
|
||||
* @property {string} generated_text The generated text.
|
||||
* @property {string|Chat} generated_text The generated text.
|
||||
* @typedef {TextGenerationSingle[]} TextGenerationOutput
|
||||
*
|
||||
* @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines.
|
||||
* @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences.
|
||||
* @property {boolean} [return_full_text=true] If set to `false` only added text is returned, otherwise the full text is returned.
|
||||
* @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig
|
||||
*
|
||||
* @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs.
|
||||
* @param {string|string[]} texts One or several prompts (or one list of prompts) to complete.
|
||||
* @param {string|string[]|Chat|Chat[]} texts One or several prompts (or one list of prompts) to complete.
|
||||
* @param {TextGenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
|
||||
* @returns {Promise<TextGenerationOutput|TextGenerationOutput[]>} An array or object containing the generated texts.
|
||||
*
|
||||
|
@ -921,17 +927,46 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
|
|||
|
||||
/** @type {TextGenerationPipelineCallback} */
|
||||
async _call(texts, generate_kwargs = {}) {
|
||||
let isBatched = false;
|
||||
let isChatInput = false;
|
||||
|
||||
const isBatched = Array.isArray(texts);
|
||||
if (!isBatched) {
|
||||
texts = [/** @type {string}*/ (texts)];
|
||||
// Normalize inputs
|
||||
/** @type {string[]} */
|
||||
let inputs;
|
||||
if (typeof texts === 'string') {
|
||||
inputs = texts = [texts];
|
||||
} else if (Array.isArray(texts) && texts.every(x => typeof x === 'string')) {
|
||||
isBatched = true;
|
||||
inputs = /** @type {string[]} */(texts);
|
||||
} else {
|
||||
if (isChat(texts)) {
|
||||
texts = [/** @type {Chat} */(texts)];
|
||||
} else if (Array.isArray(texts) && texts.every(isChat)) {
|
||||
isBatched = true;
|
||||
} else {
|
||||
throw new Error('Input must be a string, an array of strings, a Chat, or an array of Chats');
|
||||
}
|
||||
isChatInput = true;
|
||||
|
||||
// If the input is a chat, we need to apply the chat template
|
||||
inputs = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map(
|
||||
x => this.tokenizer.apply_chat_template(x, {
|
||||
tokenize: false,
|
||||
add_generation_prompt: true,
|
||||
})
|
||||
));
|
||||
}
|
||||
|
||||
// By default, do not add special tokens
|
||||
const add_special_tokens = generate_kwargs.add_special_tokens ?? false;
|
||||
|
||||
// By default, return full text
|
||||
const return_full_text = isChatInput
|
||||
? false
|
||||
: generate_kwargs.return_full_text ?? true;
|
||||
|
||||
this.tokenizer.padding_side = 'left';
|
||||
const { input_ids, attention_mask } = this.tokenizer(texts, {
|
||||
const { input_ids, attention_mask } = this.tokenizer(inputs, {
|
||||
add_special_tokens,
|
||||
padding: true,
|
||||
truncation: true,
|
||||
|
@ -941,17 +976,34 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
|
|||
inputs_attention_mask: attention_mask
|
||||
});
|
||||
|
||||
const decoded = this.tokenizer.batch_decode(outputTokenIds, {
|
||||
let decoded = this.tokenizer.batch_decode(outputTokenIds, {
|
||||
skip_special_tokens: true,
|
||||
});
|
||||
|
||||
|
||||
let promptLengths;
|
||||
if (!return_full_text && input_ids.dims.at(-1) > 0) {
|
||||
promptLengths = this.tokenizer.batch_decode(input_ids, {
|
||||
skip_special_tokens: true,
|
||||
}).map(x => x.length);
|
||||
}
|
||||
|
||||
/** @type {TextGenerationOutput[]} */
|
||||
const toReturn = Array.from({ length: texts.length }, _ => []);
|
||||
for (let i = 0; i < decoded.length; ++i) {
|
||||
const textIndex = Math.floor(i / outputTokenIds.length * texts.length);
|
||||
|
||||
if (promptLengths) {
|
||||
// Trim the decoded text to only include the generated part
|
||||
decoded[i] = decoded[i].slice(promptLengths[textIndex]);
|
||||
}
|
||||
toReturn[textIndex].push({
|
||||
generated_text: decoded[i]
|
||||
generated_text: isChatInput
|
||||
? [
|
||||
...((/** @type {Chat[]} */(texts)[textIndex])),
|
||||
{ role: 'assistant', content: decoded[i] },
|
||||
]
|
||||
: decoded[i]
|
||||
});
|
||||
}
|
||||
return (!isBatched && toReturn.length === 1) ? toReturn[0] : toReturn;
|
||||
|
|
|
@ -2429,6 +2429,12 @@ function truncateHelper(item, length) {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* @typedef {Object} Message
|
||||
* @property {string} role The role of the message (e.g., "user" or "assistant" or "system").
|
||||
* @property {string} content The content of the message.
|
||||
*/
|
||||
|
||||
export class PreTrainedTokenizer extends Callable {
|
||||
return_token_type_ids = false;
|
||||
|
||||
|
@ -2959,12 +2965,6 @@ export class PreTrainedTokenizer extends Callable {
|
|||
return this._default_chat_template;
|
||||
}
|
||||
|
||||
/**
|
||||
* @typedef {Object} Message
|
||||
* @property {string} role The role of the message (e.g., "user" or "assistant" or "system").
|
||||
* @property {string} content The content of the message.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Converts a list of message objects with `"role"` and `"content"` keys to a list of token
|
||||
* ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
|
||||
|
|
|
@ -11,6 +11,8 @@ describe('Generation parameters', () => {
|
|||
const models = [
|
||||
'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder
|
||||
'MBZUAI/LaMini-GPT-124M', // decoder-only
|
||||
|
||||
'Xenova/llama2.c-stories15M', // decoder-only
|
||||
];
|
||||
|
||||
// encoder-decoder model
|
||||
|
@ -135,4 +137,37 @@ describe('Generation parameters', () => {
|
|||
|
||||
}, MAX_TEST_EXECUTION_TIME);
|
||||
|
||||
// decoder-only model
|
||||
it(models[2], async () => {
|
||||
const MAX_NEW_TOKENS = 1;
|
||||
|
||||
const text = [
|
||||
'Once upon a time,',
|
||||
'Lily',
|
||||
'Suddenly,',
|
||||
];
|
||||
|
||||
const generator = await pipeline('text-generation', m(models[2]));
|
||||
|
||||
{ // return_full_text=false
|
||||
const output = await generator(text, {
|
||||
return_full_text: false,
|
||||
max_new_tokens: MAX_NEW_TOKENS,
|
||||
num_beams: 2,
|
||||
num_return_sequences: 2,
|
||||
});
|
||||
const lengths = output.flatMap(
|
||||
x => x.flatMap(
|
||||
y => generator.tokenizer.encode(y.generated_text.trim(), null, {
|
||||
add_special_tokens: false,
|
||||
}).length
|
||||
)
|
||||
).every(x => x === MAX_NEW_TOKENS);
|
||||
|
||||
expect(lengths).toBe(true);
|
||||
}
|
||||
await generator.dispose();
|
||||
|
||||
}, MAX_TEST_EXECUTION_TIME);
|
||||
|
||||
});
|
Loading…
Reference in New Issue