Create `LongT5` classes
This commit is contained in:
parent
b3a2a5b00f
commit
d840fe6287
|
@ -1891,6 +1891,47 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel {
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////
|
//////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////
|
||||||
|
// LONGT5 models
|
||||||
|
/**
|
||||||
|
* An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
||||||
|
*/
|
||||||
|
export class LongT5PreTrainedModel extends PreTrainedModel { };
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.
|
||||||
|
*/
|
||||||
|
export class LongT5Model extends LongT5PreTrainedModel { }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LONGT5 Model with a `language modeling` head on top.
|
||||||
|
*/
|
||||||
|
export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel {
|
||||||
|
/**
|
||||||
|
* Creates a new instance of the `LongT5ForConditionalGeneration` class.
|
||||||
|
* @param {Object} config The model configuration.
|
||||||
|
* @param {any} session session for the model.
|
||||||
|
* @param {any} decoder_merged_session session for the decoder.
|
||||||
|
* @param {GenerationConfig} generation_config The generation configuration.
|
||||||
|
*/
|
||||||
|
constructor(config, session, decoder_merged_session, generation_config) {
|
||||||
|
super(config, session);
|
||||||
|
this.decoder_merged_session = decoder_merged_session;
|
||||||
|
this.generation_config = generation_config;
|
||||||
|
|
||||||
|
this.num_decoder_layers = this.config.num_decoder_layers;
|
||||||
|
this.num_decoder_heads = this.config.num_heads;
|
||||||
|
this.decoder_dim_kv = this.config.d_kv;
|
||||||
|
|
||||||
|
this.num_encoder_layers = this.config.num_layers;
|
||||||
|
this.num_encoder_heads = this.config.num_heads;
|
||||||
|
this.encoder_dim_kv = this.config.d_kv;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////
|
//////////////////////////////////////////////////
|
||||||
// MT5 models
|
// MT5 models
|
||||||
export class MT5PreTrainedModel extends PreTrainedModel { };
|
export class MT5PreTrainedModel extends PreTrainedModel { };
|
||||||
|
|
Loading…
Reference in New Issue