Support decoding of tensors (#416)

* Support decoding of tensors (Closes #362)

* Remove debug line
Joshua Lochner 2023-12-02 16:17:57 +02:00 committed by GitHub
parent 768a2e26d7
commit 3da3841811
2 changed files with 56 additions and 2 deletions

@@ -122,6 +122,26 @@ function objectToMap(obj) {
    return new Map(Object.entries(obj));
}

+/**
+ * Helper function to convert a tensor to a list before decoding.
+ * @param {Tensor} tensor The tensor to convert.
+ * @returns {number[]} The tensor as a list.
+ */
+function prepareTensorForDecode(tensor) {
+    const dims = tensor.dims;
+    switch (dims.length) {
+        case 1:
+            return tensor.tolist();
+        case 2:
+            if (dims[0] !== 1) {
+                throw new Error('Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.');
+            }
+            return tensor.tolist()[0];
+        default:
+            throw new Error(`Expected tensor to have 1-2 dimensions, got ${dims.length}.`);
+    }
+}
+
/**
 * Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms
 * @param {string} text The text to clean up.
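
For illustration, a minimal sketch of what the new helper does, written against the `Tensor` type used above (which exposes `dims` and `tolist()`). The import and the specific ids are assumptions for the example; `prepareTensorForDecode` itself is module-private, so its results are shown as comments rather than calls:

    import { Tensor } from '@xenova/transformers';

    // 1D tensor of token ids: decoded directly via tolist().
    const ids1d = new Tensor('int64', new BigInt64Array([101n, 7592n, 102n]), [3]);
    // prepareTensorForDecode(ids1d) -> [101n, 7592n, 102n]

    // 2D tensor with batch size 1: unwrapped to its single row.
    const ids2d = new Tensor('int64', new BigInt64Array([101n, 7592n, 102n]), [1, 3]);
    // prepareTensorForDecode(ids2d) -> [101n, 7592n, 102n]

    // 2D tensor with batch size > 1: throws, pointing callers to batch_decode.
    // 3 or more dimensions: throws 'Expected tensor to have 1-2 dimensions, ...'.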
@@ -2556,18 +2576,21 @@ export class PreTrainedTokenizer extends Callable {

    /**
     * Decode a batch of tokenized sequences.
-     * @param {number[][]} batch List of tokenized input sequences.
+     * @param {number[][]|Tensor} batch List/Tensor of tokenized input sequences.
     * @param {Object} decode_args (Optional) Object with decoding arguments.
     * @returns {string[]} List of decoded sequences.
     */
    batch_decode(batch, decode_args = {}) {
+        if (batch instanceof Tensor) {
+            batch = batch.tolist();
+        }
        return batch.map(x => this.decode(x, decode_args));
    }

    /**
     * Decodes a sequence of token IDs back to a string.
     *
-     * @param {number[]} token_ids List of token IDs to decode.
+     * @param {number[]|Tensor} token_ids List/Tensor of token IDs to decode.
     * @param {Object} [decode_args={}]
     * @param {boolean} [decode_args.skip_special_tokens=false] If true, special tokens are removed from the output string.
     * @param {boolean} [decode_args.clean_up_tokenization_spaces=true] If true, spaces before punctuations and abbreviated forms are removed.
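
With this change, `batch_decode` can be fed a 2D tensor of ids directly. A hedged usage sketch — the model name and ids are illustrative (101/102 are BERT's [CLS]/[SEP]), and the decoded strings are examples, not guaranteed output:

    import { AutoTokenizer, Tensor } from '@xenova/transformers';

    const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

    // A [2, 4] tensor of ids, e.g. as returned by model.generate(...).
    const ids = new Tensor('int64', new BigInt64Array([
        101n, 7592n, 999n, 102n,   // "hello !" wrapped in [CLS]/[SEP]
        101n, 2088n, 999n, 102n,   // "world !"
    ]), [2, 4]);

    const texts = tokenizer.batch_decode(ids, { skip_special_tokens: true });
    // -> one decoded string per row, e.g. ['hello!', 'world!']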
@@ -2579,6 +2602,10 @@ export class PreTrainedTokenizer extends Callable {
        token_ids,
        decode_args = {},
    ) {
+        if (token_ids instanceof Tensor) {
+            token_ids = prepareTensorForDecode(token_ids);
+        }
+
        if (!Array.isArray(token_ids) || token_ids.length === 0 || !isIntegralNumber(token_ids[0])) {
            throw Error("token_ids must be a non-empty array of integers.");
        }
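
Correspondingly, `decode` now unwraps a 1D tensor, or a 2D tensor with batch size 1, before the existing validation runs. A short sketch, mirroring the new test below:

    import { AutoTokenizer } from '@xenova/transformers';

    const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

    const { input_ids } = tokenizer('hello world!');   // a Tensor with dims [1, n]
    const text = tokenizer.decode(input_ids, { skip_special_tokens: true });
    // -> 'hello world!'

    // A tensor with dims [batch, n] and batch > 1 still throws,
    // directing callers to batch_decode instead.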
@@ -3458,6 +3485,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
        let text;
        // @ts-ignore
        if (decode_args && decode_args.decode_with_timestamps) {
+            if (token_ids instanceof Tensor) {
+                token_ids = prepareTensorForDecode(token_ids);
+            }
            text = this.decodeWithTimestamps(token_ids, decode_args);
        } else {
            text = super.decode(token_ids, decode_args);
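
The timestamp branch calls `decodeWithTimestamps` instead of `super.decode`, so it needs its own unwrapping. A sketch of the call this enables — the ids below are placeholders, not real Whisper vocabulary:

    import { AutoTokenizer, Tensor } from '@xenova/transformers';

    const whisperTokenizer = await AutoTokenizer.from_pretrained('Xenova/whisper-tiny.en');

    // Stand-in for the [1, n] Tensor produced by model.generate(...).
    const output_ids = new Tensor('int64', new BigInt64Array([50363n, 2425n, 50413n]), [1, 3]);

    const text = whisperTokenizer.decode(output_ids, { decode_with_timestamps: true });
    // Timestamp tokens are rendered inline as <|x.xx|> markers.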

@@ -57,3 +57,27 @@ describe('Edge cases', () => {
        compare(token_ids, [101, 100, 102])
    }, 5000); // NOTE: 5 seconds
});
+
+describe('Extra decoding tests', () => {
+    it('should be able to decode the output of encode', async () => {
+        let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
+
+        let text = 'hello world!';
+
+        // Ensure all the following outputs are the same:
+        // 1. Tensor of ids: allow decoding of 1D or 2D tensors.
+        let encodedTensor = tokenizer(text);
+        let decoded1 = tokenizer.decode(encodedTensor.input_ids, { skip_special_tokens: true });
+        let decoded2 = tokenizer.batch_decode(encodedTensor.input_ids, { skip_special_tokens: true })[0];
+        expect(decoded1).toEqual(text);
+        expect(decoded2).toEqual(text);
+
+        // 2. List of ids
+        let encodedList = tokenizer(text, { return_tensor: false });
+        let decoded3 = tokenizer.decode(encodedList.input_ids, { skip_special_tokens: true });
+        let decoded4 = tokenizer.batch_decode([encodedList.input_ids], { skip_special_tokens: true })[0];
+        expect(decoded3).toEqual(text);
+        expect(decoded4).toEqual(text);
+    }, MAX_TEST_EXECUTION_TIME);
+});