From c2c45cb577a15e0facdd220ec085f8528c0b6d9e Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 11 Apr 2024 00:58:50 +0200 Subject: [PATCH] Improve support of conversational models (#658) * Add `return_full_text` option for text-generation models * [wip] Support chat inputs in text-generation pipeline * Align return type with python version * Remove conversational task (moved to text-generation) * Fix typos --- README.md | 1 - docs/snippets/5_supported-tasks.snippet | 1 - src/pipelines.js | 68 ++++++++++++++++++++++--- src/tokenizers.js | 12 ++--- tests/generation.test.js | 35 +++++++++++++ 5 files changed, 101 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 755d0c5..f4b804c 100644 --- a/README.md +++ b/README.md @@ -198,7 +198,6 @@ You can refine your search by selecting the task you're interested in (e.g., [te | Task | ID | Description | Supported? | |--------------------------|----|-------------|------------| -| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ | | [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) | | [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) | | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) | diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet index ac71ee5..ee682ff 100644 --- a/docs/snippets/5_supported-tasks.snippet +++ b/docs/snippets/5_supported-tasks.snippet @@ -5,7 +5,6 @@ | Task | ID | Description | Supported? | |--------------------------|----|-------------|------------| -| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ | | [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) | | [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) | | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) | diff --git a/src/pipelines.js b/src/pipelines.js index 3b15af5..421053b 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -841,18 +841,24 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC } } +function isChat(x) { + return Array.isArray(x) && x.every(x => 'role' in x && 'content' in x); +} /** + * @typedef {import('./tokenizers.js').Message[]} Chat + * * @typedef {Object} TextGenerationSingle - * @property {string} generated_text The generated text. + * @property {string|Chat} generated_text The generated text. * @typedef {TextGenerationSingle[]} TextGenerationOutput * * @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines. * @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences. + * @property {boolean} [return_full_text=true] If set to `false` only added text is returned, otherwise the full text is returned. * @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig * * @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs. - * @param {string|string[]} texts One or several prompts (or one list of prompts) to complete. + * @param {string|string[]|Chat|Chat[]} texts One or several prompts (or one list of prompts) to complete. * @param {TextGenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model. * @returns {Promise} An array or object containing the generated texts. * @@ -921,17 +927,46 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli /** @type {TextGenerationPipelineCallback} */ async _call(texts, generate_kwargs = {}) { + let isBatched = false; + let isChatInput = false; - const isBatched = Array.isArray(texts); - if (!isBatched) { - texts = [/** @type {string}*/ (texts)]; + // Normalize inputs + /** @type {string[]} */ + let inputs; + if (typeof texts === 'string') { + inputs = texts = [texts]; + } else if (Array.isArray(texts) && texts.every(x => typeof x === 'string')) { + isBatched = true; + inputs = /** @type {string[]} */(texts); + } else { + if (isChat(texts)) { + texts = [/** @type {Chat} */(texts)]; + } else if (Array.isArray(texts) && texts.every(isChat)) { + isBatched = true; + } else { + throw new Error('Input must be a string, an array of strings, a Chat, or an array of Chats'); + } + isChatInput = true; + + // If the input is a chat, we need to apply the chat template + inputs = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map( + x => this.tokenizer.apply_chat_template(x, { + tokenize: false, + add_generation_prompt: true, + }) + )); } // By default, do not add special tokens const add_special_tokens = generate_kwargs.add_special_tokens ?? false; + // By default, return full text + const return_full_text = isChatInput + ? false + : generate_kwargs.return_full_text ?? true; + this.tokenizer.padding_side = 'left'; - const { input_ids, attention_mask } = this.tokenizer(texts, { + const { input_ids, attention_mask } = this.tokenizer(inputs, { add_special_tokens, padding: true, truncation: true, @@ -941,17 +976,34 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli inputs_attention_mask: attention_mask }); - const decoded = this.tokenizer.batch_decode(outputTokenIds, { + let decoded = this.tokenizer.batch_decode(outputTokenIds, { skip_special_tokens: true, }); + + let promptLengths; + if (!return_full_text && input_ids.dims.at(-1) > 0) { + promptLengths = this.tokenizer.batch_decode(input_ids, { + skip_special_tokens: true, + }).map(x => x.length); + } + /** @type {TextGenerationOutput[]} */ const toReturn = Array.from({ length: texts.length }, _ => []); for (let i = 0; i < decoded.length; ++i) { const textIndex = Math.floor(i / outputTokenIds.length * texts.length); + if (promptLengths) { + // Trim the decoded text to only include the generated part + decoded[i] = decoded[i].slice(promptLengths[textIndex]); + } toReturn[textIndex].push({ - generated_text: decoded[i] + generated_text: isChatInput + ? [ + ...((/** @type {Chat[]} */(texts)[textIndex])), + { role: 'assistant', content: decoded[i] }, + ] + : decoded[i] }); } return (!isBatched && toReturn.length === 1) ? toReturn[0] : toReturn; diff --git a/src/tokenizers.js b/src/tokenizers.js index e671c83..ca0c2ab 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -2429,6 +2429,12 @@ function truncateHelper(item, length) { } +/** + * @typedef {Object} Message + * @property {string} role The role of the message (e.g., "user" or "assistant" or "system"). + * @property {string} content The content of the message. + */ + export class PreTrainedTokenizer extends Callable { return_token_type_ids = false; @@ -2959,12 +2965,6 @@ export class PreTrainedTokenizer extends Callable { return this._default_chat_template; } - /** - * @typedef {Object} Message - * @property {string} role The role of the message (e.g., "user" or "assistant" or "system"). - * @property {string} content The content of the message. - */ - /** * Converts a list of message objects with `"role"` and `"content"` keys to a list of token * ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to diff --git a/tests/generation.test.js b/tests/generation.test.js index eb6b87f..da50388 100644 --- a/tests/generation.test.js +++ b/tests/generation.test.js @@ -11,6 +11,8 @@ describe('Generation parameters', () => { const models = [ 'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder 'MBZUAI/LaMini-GPT-124M', // decoder-only + + 'Xenova/llama2.c-stories15M', // decoder-only ]; // encoder-decoder model @@ -135,4 +137,37 @@ describe('Generation parameters', () => { }, MAX_TEST_EXECUTION_TIME); + // decoder-only model + it(models[2], async () => { + const MAX_NEW_TOKENS = 1; + + const text = [ + 'Once upon a time,', + 'Lily', + 'Suddenly,', + ]; + + const generator = await pipeline('text-generation', m(models[2])); + + { // return_full_text=false + const output = await generator(text, { + return_full_text: false, + max_new_tokens: MAX_NEW_TOKENS, + num_beams: 2, + num_return_sequences: 2, + }); + const lengths = output.flatMap( + x => x.flatMap( + y => generator.tokenizer.encode(y.generated_text.trim(), null, { + add_special_tokens: false, + }).length + ) + ).every(x => x === MAX_NEW_TOKENS); + + expect(lengths).toBe(true); + } + await generator.dispose(); + + }, MAX_TEST_EXECUTION_TIME); + }); \ No newline at end of file