Handles multi layer and multi groups
This commit is contained in:
parent
139affaa8d
commit
12290c0d5c
|
@ -136,7 +136,6 @@ class AlbertModel(BertModel):
|
|||
head_mask=head_mask)
|
||||
sequence_output = encoder_outputs[0]
|
||||
|
||||
print(sequence_output.shape, sequence_output[:, 0].shape, self.pooler(sequence_output[:, 0]).shape)
|
||||
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
|
@ -260,12 +259,11 @@ class AlbertLayer(nn.Module):
|
|||
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
for _ in range(self.config.inner_group_num):
|
||||
attention_output = self.attention(hidden_states, attention_mask)[0]
|
||||
ffn_output = self.ffn(attention_output)
|
||||
ffn_output = gelu_new(ffn_output)
|
||||
ffn_output = self.ffn_output(ffn_output)
|
||||
hidden_states = self.LayerNorm(ffn_output + attention_output)
|
||||
attention_output = self.attention(hidden_states, attention_mask)[0]
|
||||
ffn_output = self.ffn(attention_output)
|
||||
ffn_output = gelu_new(ffn_output)
|
||||
ffn_output = self.ffn_output(ffn_output)
|
||||
hidden_states = self.LayerNorm(ffn_output + attention_output)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
@ -303,16 +301,16 @@ class AlbertTransformer(nn.Module):
|
|||
return (hidden_states,)
|
||||
|
||||
|
||||
model_size = 'base'
|
||||
hidden_groups = 1
|
||||
inner_groups = 1
|
||||
config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config_{}-{}-hg-{}-ig.json".format(model_size, hidden_groups, inner_groups))
|
||||
model = AlbertModel(config)
|
||||
# model_size = 'base'
|
||||
# hidden_groups = 1
|
||||
# inner_groups = 2
|
||||
# config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config_{}-{}-hg-{}-ig.json".format(model_size, hidden_groups, inner_groups))
|
||||
# model = AlbertModel(config)
|
||||
|
||||
print(model)
|
||||
model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert-{}-{}-hg-{}-ig/albert-{}-{}-hg-{}-ig".format(model_size, hidden_groups, inner_groups, model_size, hidden_groups, inner_groups))
|
||||
model.eval()
|
||||
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
|
||||
# # print(model)
|
||||
# model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert-{}-{}-hg-{}-ig/albert-{}-{}-hg-{}-ig".format(model_size, hidden_groups, inner_groups, model_size, hidden_groups, inner_groups))
|
||||
# # model.eval()
|
||||
# # print(sum(p.numel() for p in model.parameters() if p.requires_grad))
|
||||
|
||||
|
||||
# input_ids = [[31, 51, 99, 88, 54, 34, 23, 23, 12], [15, 5, 0, 88, 54, 34, 23, 23, 12]]
|
||||
|
|
Loading…
Reference in New Issue