Fix weight initialization in the model and out-of-span position clamping
This commit is contained in:
parent
907d3569c1
commit
2a97fe220b
14
modeling.py
14
modeling.py
|
@ -388,10 +388,10 @@ class BertForSequenceClassification(nn.Module):
|
|||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
elif isinstance(module, BERTLayerNorm):
|
||||
module.beta.data.normal_(config.initializer_range)
|
||||
module.gamma.data.normal_(config.initializer_range)
|
||||
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
if isinstance(module, nn.Linear):
|
||||
module.bias.data.zero_()
|
||||
self.apply(init_weights)
|
||||
|
@ -438,10 +438,10 @@ class BertForQuestionAnswering(nn.Module):
|
|||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
elif isinstance(module, BERTLayerNorm):
|
||||
module.beta.data.normal_(config.initializer_range)
|
||||
module.gamma.data.normal_(config.initializer_range)
|
||||
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
if isinstance(module, nn.Linear):
|
||||
module.bias.data.zero_()
|
||||
self.apply(init_weights)
|
||||
|
@ -459,7 +459,7 @@ class BertForQuestionAnswering(nn.Module):
|
|||
start_positions = start_positions.squeeze(-1)
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1) + 1
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
|
|
Loading…
Reference in New Issue