Support decoding of tensors (#416)
* Support decoding of tensors (Closes #362) * Remove debug line
This commit is contained in:
parent
768a2e26d7
commit
3da3841811
|
@ -122,6 +122,26 @@ function objectToMap(obj) {
|
|||
return new Map(Object.entries(obj));
|
||||
}
|
||||
|
||||
/**
 * Helper function to convert a tensor to a list before decoding.
 *
 * Only a single sequence can be decoded at a time, so the tensor must be
 * either 1-dimensional, or 2-dimensional with a leading dimension of 1.
 *
 * @param {Tensor} tensor The tensor to convert.
 * @returns {number[]} The tensor as a list (the single row for a 2D tensor).
 * @throws {Error} If the tensor is 2D with a first dimension other than 1.
 * @throws {Error} If the tensor does not have exactly 1 or 2 dimensions.
 */
function prepareTensorForDecode(tensor) {
    const dims = tensor.dims;
    switch (dims.length) {
        case 1:
            // Already a flat sequence.
            return tensor.tolist();
        case 2:
            if (dims[0] !== 1) {
                throw new Error('Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.');
            }
            // Drop the leading dimension of size 1 by taking the only row.
            return tensor.tolist()[0];
        default:
            throw new Error(`Expected tensor to have 1-2 dimensions, got ${dims.length}.`);
    }
}
|
||||
|
||||
/**
|
||||
* Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms
|
||||
* @param {string} text The text to clean up.
|
||||
|
@ -2556,18 +2576,21 @@ export class PreTrainedTokenizer extends Callable {
|
|||
|
||||
/**
|
||||
* Decode a batch of tokenized sequences.
|
||||
* @param {number[][]} batch List of tokenized input sequences.
|
||||
* @param {number[][]|Tensor} batch List/Tensor of tokenized input sequences.
|
||||
* @param {Object} decode_args (Optional) Object with decoding arguments.
|
||||
* @returns {string[]} List of decoded sequences.
|
||||
*/
|
||||
batch_decode(batch, decode_args = {}) {
|
||||
if (batch instanceof Tensor) {
|
||||
batch = batch.tolist();
|
||||
}
|
||||
return batch.map(x => this.decode(x, decode_args));
|
||||
}
|
||||
|
||||
/**
|
||||
* Decodes a sequence of token IDs back to a string.
|
||||
*
|
||||
* @param {number[]} token_ids List of token IDs to decode.
|
||||
* @param {number[]|Tensor} token_ids List/Tensor of token IDs to decode.
|
||||
* @param {Object} [decode_args={}]
|
||||
* @param {boolean} [decode_args.skip_special_tokens=false] If true, special tokens are removed from the output string.
|
||||
* @param {boolean} [decode_args.clean_up_tokenization_spaces=true] If true, spaces before punctuations and abbreviated forms are removed.
|
||||
|
@ -2579,6 +2602,10 @@ export class PreTrainedTokenizer extends Callable {
|
|||
token_ids,
|
||||
decode_args = {},
|
||||
) {
|
||||
if (token_ids instanceof Tensor) {
|
||||
token_ids = prepareTensorForDecode(token_ids);
|
||||
}
|
||||
|
||||
if (!Array.isArray(token_ids) || token_ids.length === 0 || !isIntegralNumber(token_ids[0])) {
|
||||
throw Error("token_ids must be a non-empty array of integers.");
|
||||
}
|
||||
|
@ -3458,6 +3485,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
|
|||
let text;
|
||||
// @ts-ignore
|
||||
if (decode_args && decode_args.decode_with_timestamps) {
|
||||
if (token_ids instanceof Tensor) {
|
||||
token_ids = prepareTensorForDecode(token_ids);
|
||||
}
|
||||
text = this.decodeWithTimestamps(token_ids, decode_args);
|
||||
} else {
|
||||
text = super.decode(token_ids, decode_args);
|
||||
|
|
|
@ -57,3 +57,27 @@ describe('Edge cases', () => {
|
|||
compare(token_ids, [101, 100, 102])
|
||||
}, 5000); // NOTE: 5 seconds
|
||||
});
|
||||
|
||||
describe('Extra decoding tests', () => {
    it('should be able to decode the output of encode', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
        const text = 'hello world!';

        // Every decode path below is expected to reproduce the input text.

        // Case 1: ids as returned by the default tokenizer call, passed to
        // both `decode` and `batch_decode`.
        const encodedTensor = tokenizer(text);
        const decoded1 = tokenizer.decode(encodedTensor.input_ids, { skip_special_tokens: true });
        const [decoded2] = tokenizer.batch_decode(encodedTensor.input_ids, { skip_special_tokens: true });
        expect(decoded1).toEqual(text);
        expect(decoded2).toEqual(text);

        // Case 2: ids requested with `return_tensor: false`; wrapped in an
        // outer array for `batch_decode`.
        const encodedList = tokenizer(text, { return_tensor: false });
        const decoded3 = tokenizer.decode(encodedList.input_ids, { skip_special_tokens: true });
        const [decoded4] = tokenizer.batch_decode([encodedList.input_ids], { skip_special_tokens: true });
        expect(decoded3).toEqual(text);
        expect(decoded4).toEqual(text);
    }, MAX_TEST_EXECUTION_TIME);
});
|
||||
|
|
Loading…
Reference in New Issue