[GPTJ] enable common tests and few fixes (#14190)

* enable common tests, small fixes

* don't tie word embeds

* don't ignore lm_head
This commit is contained in:
Suraj Patil 2021-11-01 22:38:52 +05:30 committed by GitHub
parent 70d5711848
commit ce91bf9a34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 10 deletions

View File

@ -109,6 +109,7 @@ class GPTJConfig(PretrainedConfig):
use_cache=True, use_cache=True,
bos_token_id=50256, bos_token_id=50256,
eos_token_id=50256, eos_token_id=50256,
tie_word_embeddings=False,
**kwargs **kwargs
): ):
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -130,4 +131,6 @@ class GPTJConfig(PretrainedConfig):
self.bos_token_id = bos_token_id self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) super().__init__(
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
)

View File

@ -71,7 +71,7 @@ class GPTJAttention(nn.Module):
max_positions = config.max_position_embeddings max_positions = config.max_position_embeddings
self.register_buffer( self.register_buffer(
"bias", "bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions 1, 1, max_positions, max_positions
), ),
) )
@ -136,7 +136,7 @@ class GPTJAttention(nn.Module):
# compute causal mask from causal mask buffer # compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2) query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
# Keep the attention weights computation in fp32 to avoid overflow issues # Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32) query = query.to(torch.float32)
@ -674,7 +674,7 @@ class GPTJModel(GPTJPreTrainedModel):
GPTJ_START_DOCSTRING, GPTJ_START_DOCSTRING,
) )
class GPTJForCausalLM(GPTJPreTrainedModel): class GPTJForCausalLM(GPTJPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@ -707,10 +707,10 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_output_embeddings(self): def get_output_embeddings(self):
return None return self.lm_head
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
return self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) token_type_ids = kwargs.get("token_type_ids", None)
@ -847,13 +847,13 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
GPTJ_START_DOCSTRING, GPTJ_START_DOCSTRING,
) )
class GPTJForSequenceClassification(GPTJPreTrainedModel): class GPTJForSequenceClassification(GPTJPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.transformer = GPTJModel(config) self.transformer = GPTJModel(config)
self.score = nn.Linear(config.n_positions, self.num_labels, bias=False) self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
self.init_weights() self.init_weights()

View File

@ -21,7 +21,8 @@ from transformers import GPTJConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, tooslow, torch_device from transformers.testing_utils import require_torch, slow, tooslow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available(): if is_torch_available():
@ -350,7 +351,7 @@ class GPTJModelTester:
@require_torch @require_torch
class GPTJModelTest(unittest.TestCase): class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else () all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else ()
all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else () all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
@ -358,6 +359,7 @@ class GPTJModelTest(unittest.TestCase):
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
test_model_parallel = False test_model_parallel = False
test_head_masking = False
# special case for DoubleHeads model # special case for DoubleHeads model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):