Improve support of conversational models (#658)

* Add `return_full_text` option for text-generation models

* [wip] Support chat inputs in text-generation pipeline

* Align return type with python version

* Remove conversational task (moved to text-generation)

* Fix typos
This commit is contained in:
Joshua Lochner 2024-04-11 00:58:50 +02:00 committed by GitHub
parent aa542cf548
commit c2c45cb577
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 101 additions and 16 deletions

View File

@ -198,7 +198,6 @@ You can refine your search by selecting the task you're interested in (e.g., [te
| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
| [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
| [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
| [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |

View File

@ -5,7 +5,6 @@
| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
| [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
| [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
| [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |

View File

@ -841,18 +841,24 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC
}
}
function isChat(x) {
return Array.isArray(x) && x.every(x => 'role' in x && 'content' in x);
}
/**
* @typedef {import('./tokenizers.js').Message[]} Chat
*
* @typedef {Object} TextGenerationSingle
* @property {string} generated_text The generated text.
* @property {string|Chat} generated_text The generated text.
* @typedef {TextGenerationSingle[]} TextGenerationOutput
*
* @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines.
* @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences.
* @property {boolean} [return_full_text=true] If set to `false` only added text is returned, otherwise the full text is returned.
* @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig
*
* @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs.
* @param {string|string[]} texts One or several prompts (or one list of prompts) to complete.
* @param {string|string[]|Chat|Chat[]} texts One or several prompts (or one list of prompts) to complete.
* @param {TextGenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise<TextGenerationOutput|TextGenerationOutput[]>} An array or object containing the generated texts.
*
@ -921,17 +927,46 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
/** @type {TextGenerationPipelineCallback} */
async _call(texts, generate_kwargs = {}) {
let isBatched = false;
let isChatInput = false;
const isBatched = Array.isArray(texts);
if (!isBatched) {
texts = [/** @type {string}*/ (texts)];
// Normalize inputs
/** @type {string[]} */
let inputs;
if (typeof texts === 'string') {
inputs = texts = [texts];
} else if (Array.isArray(texts) && texts.every(x => typeof x === 'string')) {
isBatched = true;
inputs = /** @type {string[]} */(texts);
} else {
if (isChat(texts)) {
texts = [/** @type {Chat} */(texts)];
} else if (Array.isArray(texts) && texts.every(isChat)) {
isBatched = true;
} else {
throw new Error('Input must be a string, an array of strings, a Chat, or an array of Chats');
}
isChatInput = true;
// If the input is a chat, we need to apply the chat template
inputs = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map(
x => this.tokenizer.apply_chat_template(x, {
tokenize: false,
add_generation_prompt: true,
})
));
}
// By default, do not add special tokens
const add_special_tokens = generate_kwargs.add_special_tokens ?? false;
// By default, return full text
const return_full_text = isChatInput
? false
: generate_kwargs.return_full_text ?? true;
this.tokenizer.padding_side = 'left';
const { input_ids, attention_mask } = this.tokenizer(texts, {
const { input_ids, attention_mask } = this.tokenizer(inputs, {
add_special_tokens,
padding: true,
truncation: true,
@ -941,17 +976,34 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
inputs_attention_mask: attention_mask
});
const decoded = this.tokenizer.batch_decode(outputTokenIds, {
let decoded = this.tokenizer.batch_decode(outputTokenIds, {
skip_special_tokens: true,
});
let promptLengths;
if (!return_full_text && input_ids.dims.at(-1) > 0) {
promptLengths = this.tokenizer.batch_decode(input_ids, {
skip_special_tokens: true,
}).map(x => x.length);
}
/** @type {TextGenerationOutput[]} */
const toReturn = Array.from({ length: texts.length }, _ => []);
for (let i = 0; i < decoded.length; ++i) {
const textIndex = Math.floor(i / outputTokenIds.length * texts.length);
if (promptLengths) {
// Trim the decoded text to only include the generated part
decoded[i] = decoded[i].slice(promptLengths[textIndex]);
}
toReturn[textIndex].push({
generated_text: decoded[i]
generated_text: isChatInput
? [
...((/** @type {Chat[]} */(texts)[textIndex])),
{ role: 'assistant', content: decoded[i] },
]
: decoded[i]
});
}
return (!isBatched && toReturn.length === 1) ? toReturn[0] : toReturn;

View File

@ -2429,6 +2429,12 @@ function truncateHelper(item, length) {
}
/**
* @typedef {Object} Message
* @property {string} role The role of the message (e.g., "user" or "assistant" or "system").
* @property {string} content The content of the message.
*/
export class PreTrainedTokenizer extends Callable {
return_token_type_ids = false;
@ -2959,12 +2965,6 @@ export class PreTrainedTokenizer extends Callable {
return this._default_chat_template;
}
/**
* @typedef {Object} Message
* @property {string} role The role of the message (e.g., "user" or "assistant" or "system").
* @property {string} content The content of the message.
*/
/**
* Converts a list of message objects with `"role"` and `"content"` keys to a list of token
* ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to

View File

@ -11,6 +11,8 @@ describe('Generation parameters', () => {
const models = [
'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder
'MBZUAI/LaMini-GPT-124M', // decoder-only
'Xenova/llama2.c-stories15M', // decoder-only
];
// encoder-decoder model
@ -135,4 +137,37 @@ describe('Generation parameters', () => {
}, MAX_TEST_EXECUTION_TIME);
// decoder-only model
it(models[2], async () => {
const MAX_NEW_TOKENS = 1;
const text = [
'Once upon a time,',
'Lily',
'Suddenly,',
];
const generator = await pipeline('text-generation', m(models[2]));
{ // return_full_text=false
const output = await generator(text, {
return_full_text: false,
max_new_tokens: MAX_NEW_TOKENS,
num_beams: 2,
num_return_sequences: 2,
});
const lengths = output.flatMap(
x => x.flatMap(
y => generator.tokenizer.encode(y.generated_text.trim(), null, {
add_special_tokens: false,
}).length
)
).every(x => x === MAX_NEW_TOKENS);
expect(lengths).toBe(true);
}
await generator.dispose();
}, MAX_TEST_EXECUTION_TIME);
});