Do not use spread operator to concatenate large arrays (Closes #153) (#154)

* Do not use spread operator for merging large arrays (Fix #153)

* Add unit test for encoding long strings
This commit is contained in:
Joshua Lochner 2023-06-21 01:21:14 +02:00 committed by GitHub
parent 573012b434
commit 4804171180
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 7 deletions

View File

@ -25,13 +25,14 @@ import {
reverseDictionary,
escapeRegExp,
isIntegralNumber,
mergeArrays,
} from './utils/core.js';
import {
getModelJSON,
} from './utils/hub.js';
import { min } from './utils/maths.js';
import { max, min } from './utils/maths.js';
import { Tensor } from './utils/tensor.js';
/**
@ -1188,12 +1189,12 @@ class RobertaProcessing extends PostProcessor {
* @returns {string[]} The input tokens with the special tokens added to the beginning and end.
*/
post_process(tokens, tokens_pair = null) {
tokens = [this.cls, ...tokens, this.sep]
tokens = mergeArrays([this.cls], tokens, [this.sep]);
// NOTE: It is intended to add 2 EOS tokens after the first set of tokens
// https://github.com/huggingface/tokenizers/issues/983
if (tokens_pair !== null) {
tokens = [...tokens, this.sep, ...tokens_pair, this.sep]
tokens = mergeArrays(tokens, [this.sep], tokens_pair, [this.sep]);
}
return tokens;
}
@ -1233,10 +1234,10 @@ class TemplateProcessing extends PostProcessor {
} else if ('Sequence' in item) {
if (item.Sequence.id === 'A') {
toReturn.push(...tokens);
toReturn = mergeArrays(toReturn, tokens);
} else if (item.Sequence.id === 'B') {
toReturn.push(...tokens_pair);
toReturn = mergeArrays(toReturn, tokens_pair);
}
}
}
@ -1961,7 +1962,7 @@ export class PreTrainedTokenizer extends Callable {
// At this point, tokens is batched: [batch_size, tokens]
// However, array may be jagged. So, we pad to max_length
let maxLengthOfBatch = Math.max(...tokens.map(x => x.length));
let maxLengthOfBatch = max(tokens.map(x => x.length))[0];
// If null, we calculate max length from sequences
if (max_length === null) {
@ -2917,7 +2918,7 @@ export class MarianTokenizer extends PreTrainedTokenizer {
if (!this.supported_language_codes.includes(language)) {
console.warn(`Unsupported language code "${language}" detected, which may lead to unexpected behavior. Should be one of: ${JSON.stringify(this.supported_language_codes)}`)
}
return [language, ...super._encode_text(text)]
return mergeArrays([language], super._encode_text(text));
}
}

View File

@ -152,3 +152,13 @@ export function pop(obj, key, defaultValue = undefined) {
}
return defaultValue;
}
/**
 * Concatenate any number of arrays into one newly created array.
 * The input arrays themselves are not modified.
 * Adapted from https://stackoverflow.com/a/6768642/13989043
 * @param {...any} arrs Arrays to merge, in the order they should appear.
 * @returns {any[]} A new array containing the elements of every input, in order.
 */
export function mergeArrays(...arrs) {
    // Spreading `arrs` hands `concat` one argument per input array (not one per
    // element), so the length of each individual array does not affect the
    // number of call arguments.
    const merged = [];
    return merged.concat(...arrs);
}

View File

@ -33,3 +33,13 @@ describe('Tokenizers', () => {
});
}
});
// Guards against crashes when the tokenizer is given a very long input string.
describe('Edge cases', () => {
    it('should not crash when encoding a very long string', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('t5-small');

        // 50,000 repetitions of a short phrase produce a single very long input.
        const longText = 'Hello world! '.repeat(50000);

        const { input_ids } = await tokenizer(longText);
        expect(input_ids.data.length).toBeGreaterThan(100000);
    });
});