fixing weight initialization in the model and out-of-span clamping

thomwolf 2018-11-06 17:26:33 +01:00
parent 907d3569c1
commit 2a97fe220b
1 changed file with 7 additions and 7 deletions


@@ -388,10 +388,10 @@ class BertForSequenceClassification(nn.Module):
             if isinstance(module, (nn.Linear, nn.Embedding)):
                 # Slightly different from the TF version which uses truncated_normal for initialization
                 # cf https://github.com/pytorch/pytorch/pull/5617
-                module.weight.data.normal_(config.initializer_range)
+                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
             elif isinstance(module, BERTLayerNorm):
-                module.beta.data.normal_(config.initializer_range)
-                module.gamma.data.normal_(config.initializer_range)
+                module.beta.data.normal_(mean=0.0, std=config.initializer_range)
+                module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
             if isinstance(module, nn.Linear):
                 module.bias.data.zero_()
         self.apply(init_weights)
@@ -438,10 +438,10 @@ class BertForQuestionAnswering(nn.Module):
             if isinstance(module, (nn.Linear, nn.Embedding)):
                 # Slightly different from the TF version which uses truncated_normal for initialization
                 # cf https://github.com/pytorch/pytorch/pull/5617
-                module.weight.data.normal_(config.initializer_range)
+                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
             elif isinstance(module, BERTLayerNorm):
-                module.beta.data.normal_(config.initializer_range)
-                module.gamma.data.normal_(config.initializer_range)
+                module.beta.data.normal_(mean=0.0, std=config.initializer_range)
+                module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
             if isinstance(module, nn.Linear):
                 module.bias.data.zero_()
         self.apply(init_weights)
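
The two hunks above apply the same fix, once in BertForSequenceClassification and once in BertForQuestionAnswering. The root cause is easy to miss: the first positional parameter of torch.Tensor.normal_ is the mean, not the standard deviation, so the old calls drew weights from N(initializer_range, 1.0) instead of the intended N(0.0, initializer_range). A minimal sketch of the difference (0.02 is BERT's default initializer_range):

import torch

initializer_range = 0.02  # BERT's default initializer_range
w = torch.empty(768, 768)

# Old call: the lone positional argument binds to `mean`, so the
# weights come out as N(0.02, 1.0) -- a std roughly 50x too large.
w.normal_(initializer_range)
print(w.std().item())  # ~1.0

# Fixed call: keywords make the intent explicit, giving N(0.0, 0.02)
# to match the TF implementation's truncated_normal(stddev=0.02).
w.normal_(mean=0.0, std=initializer_range)
print(w.std().item())  # ~0.02
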
@@ -459,7 +459,7 @@ class BertForQuestionAnswering(nn.Module):
                 start_positions = start_positions.squeeze(-1)
                 end_positions = end_positions.squeeze(-1)
             # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1) + 1
+            ignored_index = start_logits.size(1)
             start_positions.clamp_(0, ignored_index)
             end_positions.clamp_(0, ignored_index)
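
The clamping fix pairs with a CrossEntropyLoss(ignore_index=ignored_index) call in the surrounding loss computation (an assumption here, since that line is not shown in the hunk). With the old ignored_index = seq_len + 1, a position label of exactly seq_len passed through the clamp untouched, and since it was neither a valid class index (0..seq_len-1) nor the ignore value, the loss indexed out of bounds. Clamping to seq_len and ignoring exactly that value routes every out-of-span label to the single ignored index. A minimal runnable sketch:

import torch
from torch.nn import CrossEntropyLoss

seq_len = 5
start_logits = torch.randn(2, seq_len)
start_positions = torch.tensor([3, 9])  # second label is out of span

# Fixed behaviour: clamp to seq_len (one past the last valid class
# index) and tell the loss to ignore exactly that value.
ignored_index = start_logits.size(1)                       # 5
start_positions = start_positions.clamp(0, ignored_index)  # tensor([3, 5])
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
loss = loss_fct(start_logits, start_positions)  # out-of-span term dropped

# With the old ignored_index = seq_len + 1 = 6, a label of exactly 5
# survived the clamp but matched neither a valid class (0..4) nor the
# ignore value (6), so CrossEntropyLoss raised an out-of-bounds error.
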