Handles multiple layers and multiple groups

This commit is contained in:
Lysandre 2019-10-29 23:19:02 +00:00 committed by Lysandre Debut
parent 139affaa8d
commit 12290c0d5c
1 changed file with 14 additions and 16 deletions

View File

@@ -136,7 +136,6 @@ class AlbertModel(BertModel):
head_mask=head_mask)
sequence_output = encoder_outputs[0]
print(sequence_output.shape, sequence_output[:, 0].shape, self.pooler(sequence_output[:, 0]).shape)
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
@@ -260,12 +259,11 @@ class AlbertLayer(nn.Module):
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states, attention_mask=None, head_mask=None):
    """Run one ALBERT layer: self-attention followed by a feed-forward block.

    NOTE(review): the extracted diff hunk interleaved the removed
    `for _ in range(self.config.inner_group_num):` loop with the added
    unrolled statements, duplicating the whole sequence. Per the hunk
    header (-260,12 +259,11) the post-commit version has no inner loop;
    that version is reconstructed here.

    Args:
        hidden_states: input hidden states for this layer.
        attention_mask: optional mask forwarded to self-attention.
        head_mask: accepted for interface compatibility but unused in
            this body — TODO confirm against the attention module.

    Returns:
        The layer's output hidden states.
    """
    # [0] keeps the attention output tensor and drops any extras
    # (e.g. attention probabilities) returned by self.attention.
    attention_output = self.attention(hidden_states, attention_mask)[0]
    ffn_output = self.ffn(attention_output)
    ffn_output = gelu_new(ffn_output)
    ffn_output = self.ffn_output(ffn_output)
    # Residual connection around the FFN sub-block, then LayerNorm.
    hidden_states = self.LayerNorm(ffn_output + attention_output)
    return hidden_states
@@ -303,16 +301,16 @@ class AlbertTransformer(nn.Module):
return (hidden_states,)
# --- Debug/conversion scaffolding (diff hunk; removed and added lines are
# interleaved — the commented-out lines below are the hunk's counterpart of
# the active ones). Loads an ALBERT config and TF checkpoint from hard-coded
# local paths; presumably temporary developer scratch code — TODO confirm it
# was not meant to ship.
model_size = 'base'
hidden_groups = 1
inner_groups = 1
# Path template encodes model size / hidden-group count / inner-group count.
config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config_{}-{}-hg-{}-ig.json".format(model_size, hidden_groups, inner_groups))
model = AlbertModel(config)
# model_size = 'base'
# hidden_groups = 1
# inner_groups = 2
# config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config_{}-{}-hg-{}-ig.json".format(model_size, hidden_groups, inner_groups))
# model = AlbertModel(config)
print(model)
# Load TF weights into the PyTorch model from a matching local checkpoint.
model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert-{}-{}-hg-{}-ig/albert-{}-{}-hg-{}-ig".format(model_size, hidden_groups, inner_groups, model_size, hidden_groups, inner_groups))
model.eval()
# Count trainable parameters as a sanity check against the TF checkpoint.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
# # print(model)
# model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert-{}-{}-hg-{}-ig/albert-{}-{}-hg-{}-ig".format(model_size, hidden_groups, inner_groups, model_size, hidden_groups, inner_groups))
# # model.eval()
# # print(sum(p.numel() for p in model.parameters() if p.requires_grad))
# Sample token-id batch for a smoke-test forward pass (commented out;
# script appears to continue beyond this excerpt).
# input_ids = [[31, 51, 99, 88, 54, 34, 23, 23, 12], [15, 5, 0, 88, 54, 34, 23, 23, 12]]