Create `LongT5` classes

This commit is contained in:
Joshua Lochner 2023-09-19 01:12:14 +02:00
parent b3a2a5b00f
commit d840fe6287
1 changed files with 41 additions and 0 deletions

View File

@ -1891,6 +1891,47 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel {
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// LONGT5 models
/**
* An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
*/
export class LongT5PreTrainedModel extends PreTrainedModel { };
/**
* The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.
*/
export class LongT5Model extends LongT5PreTrainedModel { }
/**
* LONGT5 Model with a `language modeling` head on top.
*/
export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel {
/**
* Creates a new instance of the `LongT5ForConditionalGeneration` class.
* @param {Object} config The model configuration.
* @param {any} session session for the model.
* @param {any} decoder_merged_session session for the decoder.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session, decoder_merged_session, generation_config) {
super(config, session);
this.decoder_merged_session = decoder_merged_session;
this.generation_config = generation_config;
this.num_decoder_layers = this.config.num_decoder_layers;
this.num_decoder_heads = this.config.num_heads;
this.decoder_dim_kv = this.config.d_kv;
this.num_encoder_layers = this.config.num_layers;
this.num_encoder_heads = this.config.num_heads;
this.encoder_dim_kv = this.config.d_kv;
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// MT5 models
export class MT5PreTrainedModel extends PreTrainedModel { };