[GPTJ] enable common tests and a few fixes (#14190)

* enable common tests, small fixes

* don't tie word embeds

* don't ignore lm_head
Suraj Patil, 2021-11-01 22:38:52 +05:30, committed by GitHub
parent 70d5711848
commit ce91bf9a34
3 changed files with 15 additions and 10 deletions

src/transformers/models/gptj/configuration_gptj.py

@@ -109,6 +109,7 @@ class GPTJConfig(PretrainedConfig):
         use_cache=True,
         bos_token_id=50256,
         eos_token_id=50256,
+        tie_word_embeddings=False,
         **kwargs
     ):
         self.vocab_size = vocab_size
@@ -130,4 +131,6 @@ class GPTJConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
-        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        super().__init__(
+            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+        )
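Note: GPT-J keeps lm_head separate from the input token embedding, so the config now exposes tie_word_embeddings with a False default and forwards it to PretrainedConfig, which is where the shared weight-tying logic reads it. A minimal sketch of the resulting behaviour (purely illustrative; it only assumes a transformers install that includes this change):

from transformers import GPTJConfig

# After this change the flag defaults to False for GPT-J ...
config = GPTJConfig()
assert config.tie_word_embeddings is False

# ... but it can still be overridden explicitly, since it is forwarded to
# PretrainedConfig via super().__init__ and stored like any other kwarg.
tied = GPTJConfig(tie_word_embeddings=True)
assert tied.tie_word_embeddings is True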

src/transformers/models/gptj/modeling_gptj.py

@@ -71,7 +71,7 @@ class GPTJAttention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                 1, 1, max_positions, max_positions
             ),
         )
@@ -136,7 +136,7 @@ class GPTJAttention(nn.Module):
         # compute causal mask from causal mask buffer
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
 
         # Keep the attention weights computation in fp32 to avoid overflow issues
         query = query.to(torch.float32)
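Note: the causal-mask buffer is now registered as uint8 and only cast to bool at the point where it is sliced, presumably for parity with the other GPT-style models, which store the mask the same way. A standalone sketch of the pattern, outside the GPT-J classes (shapes and values are illustrative):

import torch

max_positions = 8

# Lower-triangular mask stored as uint8, mirroring the register_buffer call above.
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
    1, 1, max_positions, max_positions
)

# At attention time only the relevant window is sliced out and cast to bool
# before being used to mask the attention scores.
query_length, key_length = 4, 6
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].bool()

attn_weights = torch.randn(1, 1, query_length, key_length)
mask_value = torch.finfo(attn_weights.dtype).min
attn_weights = torch.where(causal_mask, attn_weights, torch.tensor(mask_value))
print(attn_weights.shape)  # torch.Size([1, 1, 4, 6])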
@@ -674,7 +674,7 @@ class GPTJModel(GPTJPreTrainedModel):
     GPTJ_START_DOCSTRING,
 )
 class GPTJForCausalLM(GPTJPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -707,10 +707,10 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
             torch.cuda.empty_cache()
 
     def get_output_embeddings(self):
-        return None
+        return self.lm_head
 
     def set_output_embeddings(self, new_embeddings):
-        return
+        self.lm_head = new_embeddings
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
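Note: exposing lm_head through get_output_embeddings / set_output_embeddings (instead of returning None) is what lets the shared machinery, e.g. resize_token_embeddings and the common tests, find the output projection; and because the weights are untied, lm_head.weight is no longer listed in _keys_to_ignore_on_load_missing above. A hedged sketch of how these accessors are used (the tiny config sizes below are made up so the model builds quickly):

from transformers import GPTJConfig, GPTJForCausalLM

# Small, made-up dimensions purely so the example is cheap to instantiate.
config = GPTJConfig(n_embd=64, n_layer=2, n_head=4, n_positions=128, vocab_size=100, rotary_dim=16)
model = GPTJForCausalLM(config)

# The output projection is now reachable through the standard accessor ...
assert model.get_output_embeddings() is model.lm_head

# ... which resize_token_embeddings relies on to grow lm_head together with
# the input embedding matrix when new tokens are added.
model.resize_token_embeddings(110)
assert model.get_input_embeddings().weight.shape[0] == 110
assert model.lm_head.out_features == 110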
@@ -847,13 +847,13 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
     GPTJ_START_DOCSTRING,
 )
 class GPTJForSequenceClassification(GPTJPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
 
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.transformer = GPTJModel(config)
-        self.score = nn.Linear(config.n_positions, self.num_labels, bias=False)
+        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
 
         self.init_weights()
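Note: the classification head scores the final hidden states, whose last dimension is the hidden size, so it must be built with config.n_embd rather than config.n_positions; the old code only worked when the two happened to match. A small shape-contract sketch (sizes are illustrative):

import torch
from torch import nn

n_embd, num_labels = 64, 3

# Transformer output: (batch, seq_len, n_embd); the score head consumes n_embd
# features per position, independent of the maximum sequence length.
hidden_states = torch.randn(2, 10, n_embd)
score = nn.Linear(n_embd, num_labels, bias=False)

logits = score(hidden_states)
print(logits.shape)  # torch.Size([2, 10, 3])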

tests/test_modeling_gptj.py

@@ -21,7 +21,8 @@ from transformers import GPTJConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, tooslow, torch_device
 
 from .test_configuration_common import ConfigTester
-from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
+from .test_generation_utils import GenerationTesterMixin
+from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
 
 
 if is_torch_available():
@@ -350,7 +351,7 @@ class GPTJModelTester:
 @require_torch
-class GPTJModelTest(unittest.TestCase):
+class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
 
     all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else ()
     all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
@@ -358,6 +359,7 @@ class GPTJModelTest(unittest.TestCase):
     test_pruning = False
     test_missing_keys = False
     test_model_parallel = False
+    test_head_masking = False
 
     # special case for DoubleHeads model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
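Note: inheriting ModelTesterMixin and GenerationTesterMixin is what actually switches on the shared test suite for GPT-J; the class-level flags above opt out of checks that do not apply. One example of the kind of behaviour that suite covers, and that the lm_head changes make pass, is a save/reload round trip; a hedged standalone sketch (tiny made-up config sizes, not the real tester):

import tempfile

import torch
from transformers import GPTJConfig, GPTJForCausalLM

# Tiny, made-up dimensions so the round trip runs in a few seconds.
config = GPTJConfig(n_embd=64, n_layer=2, n_head=4, n_positions=128, vocab_size=100, rotary_dim=16)
model = GPTJForCausalLM(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    before = model(input_ids).logits

# Save and reload; with lm_head no longer ignored on load, the reloaded model
# reproduces the original logits.
with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)
    reloaded = GPTJForCausalLM.from_pretrained(tmp)

reloaded.eval()
with torch.no_grad():
    after = reloaded(input_ids).logits

print(torch.allclose(before, after, atol=1e-5))  # expected: True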