[GPTJ] enable common tests and few fixes (#14190)
* enable common tests, small fixes * don't tie word embeds * don't ignore lm_head
This commit is contained in:
parent
70d5711848
commit
ce91bf9a34
|
@ -109,6 +109,7 @@ class GPTJConfig(PretrainedConfig):
|
|||
use_cache=True,
|
||||
bos_token_id=50256,
|
||||
eos_token_id=50256,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
|
@ -130,4 +131,6 @@ class GPTJConfig(PretrainedConfig):
|
|||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
|
||||
)
|
||||
|
|
|
@ -71,7 +71,7 @@ class GPTJAttention(nn.Module):
|
|||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
|
@ -136,7 +136,7 @@ class GPTJAttention(nn.Module):
|
|||
|
||||
# compute causal mask from causal mask buffer
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
|
||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||
query = query.to(torch.float32)
|
||||
|
@ -674,7 +674,7 @@ class GPTJModel(GPTJPreTrainedModel):
|
|||
GPTJ_START_DOCSTRING,
|
||||
)
|
||||
class GPTJForCausalLM(GPTJPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
|
||||
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -707,10 +707,10 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return None
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
return
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
|
@ -847,13 +847,13 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
|
|||
GPTJ_START_DOCSTRING,
|
||||
)
|
||||
class GPTJForSequenceClassification(GPTJPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
||||
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.transformer = GPTJModel(config)
|
||||
self.score = nn.Linear(config.n_positions, self.num_labels, bias=False)
|
||||
self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
|
|
|
@ -21,7 +21,8 @@ from transformers import GPTJConfig, is_torch_available
|
|||
from transformers.testing_utils import require_torch, slow, tooslow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
@ -350,7 +351,7 @@ class GPTJModelTester:
|
|||
|
||||
|
||||
@require_torch
|
||||
class GPTJModelTest(unittest.TestCase):
|
||||
class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else ()
|
||||
all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
|
||||
|
@ -358,6 +359,7 @@ class GPTJModelTest(unittest.TestCase):
|
|||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
test_model_parallel = False
|
||||
test_head_masking = False
|
||||
|
||||
# special case for DoubleHeads model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
|
Loading…
Reference in New Issue