From d840fe628761e7cce4b345453c07982fb0d174c1 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 19 Sep 2023 01:12:14 +0200 Subject: [PATCH] Create `LongT5` classes --- src/models.js | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/models.js b/src/models.js index daf33fd..6c11eab 100644 --- a/src/models.js +++ b/src/models.js @@ -1891,6 +1891,47 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { } ////////////////////////////////////////////////// + +////////////////////////////////////////////////// +// LONGT5 models +/** + * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + */ +export class LongT5PreTrainedModel extends PreTrainedModel { }; + +/** + * The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top. + */ +export class LongT5Model extends LongT5PreTrainedModel { } + +/** + * LONGT5 Model with a `language modeling` head on top. + */ +export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel { + /** + * Creates a new instance of the `LongT5ForConditionalGeneration` class. + * @param {Object} config The model configuration. + * @param {any} session session for the model. + * @param {any} decoder_merged_session session for the decoder. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, session, decoder_merged_session, generation_config) { + super(config, session); + this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; + + this.num_decoder_layers = this.config.num_decoder_layers; + this.num_decoder_heads = this.config.num_heads; + this.decoder_dim_kv = this.config.d_kv; + + this.num_encoder_layers = this.config.num_layers; + this.num_encoder_heads = this.config.num_heads; + this.encoder_dim_kv = this.config.d_kv; + } +} +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // MT5 models export class MT5PreTrainedModel extends PreTrainedModel { };