Support decoding of multiple lists of token ids

This commit is contained in:
Joshua Lochner 2023-02-19 01:37:59 +02:00
parent 807b9216fe
commit 8414b0f9d8
1 changed files with 14 additions and 0 deletions

View File

@ -692,6 +692,20 @@ class PreTrainedTokenizer extends Callable {
}
decode(token_ids, skip_special_tokens = false) {
if (!Array.isArray(token_ids) || token_ids.length === 0) {
throw Error("token_ids must be a non-empty array.");
}
if (Array.isArray(token_ids[0])) {
// array of array
return token_ids.map(x => this.decode_single(x, skip_special_tokens));
} else {
return this.decode_single(token_ids, skip_special_tokens)
}
}
decode_single(token_ids, skip_special_tokens = false) {
let tokens = this.model.convert_ids_to_tokens(token_ids);
if (skip_special_tokens) {