* Do not use spread operator for merging large arrays (Fix #153)
* Add unit test for encoding long strings
This commit is contained in:
parent
573012b434
commit
4804171180
|
@ -25,13 +25,14 @@ import {
|
|||
reverseDictionary,
|
||||
escapeRegExp,
|
||||
isIntegralNumber,
|
||||
mergeArrays,
|
||||
} from './utils/core.js';
|
||||
|
||||
import {
|
||||
getModelJSON,
|
||||
} from './utils/hub.js';
|
||||
|
||||
import { min } from './utils/maths.js';
|
||||
import { max, min } from './utils/maths.js';
|
||||
import { Tensor } from './utils/tensor.js';
|
||||
|
||||
/**
|
||||
|
@ -1188,12 +1189,12 @@ class RobertaProcessing extends PostProcessor {
|
|||
* @returns {string[]} The input tokens with the special tokens added to the beginning and end.
|
||||
*/
|
||||
post_process(tokens, tokens_pair = null) {
|
||||
tokens = [this.cls, ...tokens, this.sep]
|
||||
tokens = mergeArrays([this.cls], tokens, [this.sep]);
|
||||
|
||||
// NOTE: It is intended to add 2 EOS tokens after the first set of tokens
|
||||
// https://github.com/huggingface/tokenizers/issues/983
|
||||
if (tokens_pair !== null) {
|
||||
tokens = [...tokens, this.sep, ...tokens_pair, this.sep]
|
||||
tokens = mergeArrays(tokens, [this.sep], tokens_pair, [this.sep]);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
@ -1233,10 +1234,10 @@ class TemplateProcessing extends PostProcessor {
|
|||
|
||||
} else if ('Sequence' in item) {
|
||||
if (item.Sequence.id === 'A') {
|
||||
toReturn.push(...tokens);
|
||||
toReturn = mergeArrays(toReturn, tokens);
|
||||
|
||||
} else if (item.Sequence.id === 'B') {
|
||||
toReturn.push(...tokens_pair);
|
||||
toReturn = mergeArrays(toReturn, tokens_pair);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1961,7 +1962,7 @@ export class PreTrainedTokenizer extends Callable {
|
|||
// At this point, tokens is batched: [batch_size, tokens]
|
||||
// However, array may be jagged. So, we pad to max_length
|
||||
|
||||
let maxLengthOfBatch = Math.max(...tokens.map(x => x.length));
|
||||
let maxLengthOfBatch = max(tokens.map(x => x.length))[0];
|
||||
|
||||
// If null, we calculate max length from sequences
|
||||
if (max_length === null) {
|
||||
|
@ -2917,7 +2918,7 @@ export class MarianTokenizer extends PreTrainedTokenizer {
|
|||
if (!this.supported_language_codes.includes(language)) {
|
||||
console.warn(`Unsupported language code "${language}" detected, which may lead to unexpected behavior. Should be one of: ${JSON.stringify(this.supported_language_codes)}`)
|
||||
}
|
||||
return [language, ...super._encode_text(text)]
|
||||
return mergeArrays([language], super._encode_text(text));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -152,3 +152,13 @@ export function pop(obj, key, defaultValue = undefined) {
|
|||
}
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
/**
 * Efficiently merge arrays, creating a new copy.
 * Adapted from https://stackoverflow.com/a/6768642/13989043
 *
 * The inputs are joined with `Array.prototype.concat` instead of spread-based
 * merging (e.g. `target.push(...source)`), because spreading passes every
 * element as a separate call argument and can fail on very large arrays.
 * Here only the arrays themselves are passed as arguments, never their elements.
 *
 * @param {...any[]} arrs Arrays to merge, in the order they should appear in the result.
 * A value that is not an array is appended as a single element (standard `concat` behavior).
 * @returns {any[]} A new array holding the elements of every input; the inputs are not modified.
 */
export function mergeArrays(...arrs) {
    // `apply` expands `arrs` (a handful of arrays), not the arrays' contents.
    return Array.prototype.concat.apply([], arrs);
}
|
||||
|
|
|
@ -33,3 +33,13 @@ describe('Tokenizers', () => {
|
|||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe('Edge cases', () => {
    // Regression test for the "spread operator on large arrays" crash (Fix #153):
    // a 650,000-character input must be encoded without throwing.
    it('should not crash when encoding a very long string', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('t5-small');

        // 13 characters repeated 50,000 times.
        const text = 'Hello world! '.repeat(50000);
        const encoded = await tokenizer(text);

        // The encoding is expected to contain more than 100,000 token ids.
        expect(encoded.input_ids.data.length).toBeGreaterThan(100000);
    });
});
|
||||
|
|
Loading…
Reference in New Issue