This commit is contained in:
Patrick von Platen 2021-09-29 10:30:00 +02:00 committed by GitHub
parent a21ee1f990
commit aa018a795d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 0 deletions

View File

@ -964,6 +964,14 @@ class HubertForCTC(HubertPreTrainedModel):
self.hubert = HubertModel(config)
self.dropout = nn.Dropout(config.final_dropout)
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
"does not define the vocabulary size of the language model head. Please "
"instantiate the model as follows: `HubertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
"or define `vocab_size` of your model's configuration."
)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
self.init_weights()

View File

@ -1416,6 +1416,14 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
self.wav2vec2 = Wav2Vec2Model(config)
self.dropout = nn.Dropout(config.final_dropout)
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
"does not define the vocabulary size of the language model head. Please "
"instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`."
"or define `vocab_size` of your model's configuration."
)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
self.init_weights()