[GPTJ] enable common tests and few fixes (#14190)
* enable common tests, small fixes
* don't tie word embeds
* don't ignore lm_head
parent 70d5711848
commit ce91bf9a34
@@ -109,6 +109,7 @@ class GPTJConfig(PretrainedConfig):
         use_cache=True,
         bos_token_id=50256,
         eos_token_id=50256,
+        tie_word_embeddings=False,
         **kwargs
     ):
         self.vocab_size = vocab_size
@@ -130,4 +131,6 @@ class GPTJConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id

-        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        super().__init__(
+            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+        )
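With the new argument forwarded to PretrainedConfig, GPT-J now defaults to an untied output projection: lm_head gets its own weights instead of reusing transformer.wte. A minimal sketch of the effect; the tiny sizes and the rotary_dim value are illustrative assumptions, not part of this commit:

from transformers import GPTJConfig, GPTJForCausalLM

# Illustrative tiny config; real GPT-J uses much larger values.
config = GPTJConfig(n_embd=64, n_layer=2, n_head=4, vocab_size=128, rotary_dim=16)
model = GPTJForCausalLM(config)

# tie_word_embeddings defaults to False for GPT-J, so the output projection
# does not share storage with the input embedding matrix.
assert config.tie_word_embeddings is False
assert model.lm_head.weight.data_ptr() != model.transformer.wte.weight.data_ptr()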
@@ -71,7 +71,7 @@ class GPTJAttention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                 1, 1, max_positions, max_positions
             ),
         )
@@ -136,7 +136,7 @@ class GPTJAttention(nn.Module):

         # compute causal mask from causal mask buffer
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()

         # Keep the attention weights computation in fp32 to avoid overflow issues
         query = query.to(torch.float32)
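The causal-mask buffer is now stored as uint8 and cast back to bool at the point of use. A standalone sketch of the same pattern outside the model class; variable names mirror the diff, the small sizes are illustrative:

import torch

max_positions = 8
# Register-buffer pattern: keep the mask as uint8 in the module state ...
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
    1, 1, max_positions, max_positions
)

# ... and convert to bool only when slicing it for the current query/key lengths.
query_length, key_length = 4, 8
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].bool()
print(causal_mask.shape)  # torch.Size([1, 1, 4, 8])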
@@ -674,7 +674,7 @@ class GPTJModel(GPTJPreTrainedModel):
     GPTJ_START_DOCSTRING,
 )
 class GPTJForCausalLM(GPTJPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]

     def __init__(self, config):
         super().__init__(config)
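Since lm_head is now a real, untied parameter, a checkpoint that lacks lm_head.weight should be reported rather than silently accepted, so that key is dropped from the ignore list. A rough sketch of how such ignore patterns behave during loading; the regexes come from the diff, the loop is a simplified stand-in for the library's loading code:

import re

keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]

missing_keys = ["transformer.h.0.attn.bias", "lm_head.weight"]
reported = [
    key
    for key in missing_keys
    if not any(re.search(pattern, key) for pattern in keys_to_ignore_on_load_missing)
]
print(reported)  # ['lm_head.weight'] -- now surfaced instead of silently ignored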
@@ -707,10 +707,10 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         torch.cuda.empty_cache()

     def get_output_embeddings(self):
-        return None
+        return self.lm_head

     def set_output_embeddings(self, new_embeddings):
-        return
+        self.lm_head = new_embeddings

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
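Exposing the real lm_head through get_output_embeddings/set_output_embeddings is what lets shared machinery such as resize_token_embeddings (and the common tests) find and manipulate the output projection. A small sketch; the tiny config values are illustrative assumptions:

from transformers import GPTJConfig, GPTJForCausalLM

config = GPTJConfig(n_embd=64, n_layer=2, n_head=4, vocab_size=128, rotary_dim=16)
model = GPTJForCausalLM(config)

# resize_token_embeddings goes through get_output_embeddings()/set_output_embeddings(),
# so the untied lm_head grows together with the input embedding matrix.
model.resize_token_embeddings(130)
print(model.lm_head.weight.shape)  # torch.Size([130, 64])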
@@ -847,13 +847,13 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
     GPTJ_START_DOCSTRING,
 )
 class GPTJForSequenceClassification(GPTJPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]

     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.transformer = GPTJModel(config)
-        self.score = nn.Linear(config.n_positions, self.num_labels, bias=False)
+        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

         self.init_weights()

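The classification head consumes per-token hidden states of size config.n_embd, so the previous config.n_positions (the maximum sequence length) was simply the wrong input dimension. A shape-only sketch with illustrative sizes:

import torch
from torch import nn

n_embd, num_labels = 64, 3

hidden_states = torch.randn(2, 10, n_embd)  # (batch, seq_len, n_embd) as produced by GPTJModel
score = nn.Linear(n_embd, num_labels, bias=False)
logits = score(hidden_states)
print(logits.shape)  # torch.Size([2, 10, 3]); a Linear(n_positions, ...) head would fail here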
@@ -21,7 +21,8 @@ from transformers import GPTJConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, tooslow, torch_device

 from .test_configuration_common import ConfigTester
-from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
+from .test_generation_utils import GenerationTesterMixin
+from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


 if is_torch_available():
@@ -350,7 +351,7 @@ class GPTJModelTester:


 @require_torch
-class GPTJModelTest(unittest.TestCase):
+class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

     all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else ()
     all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
@@ -358,6 +359,7 @@ class GPTJModelTest(unittest.TestCase):
     test_pruning = False
     test_missing_keys = False
     test_model_parallel = False
+    test_head_masking = False

     # special case for DoubleHeads model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
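GPTJModelTest now inherits the shared suites from ModelTesterMixin and GenerationTesterMixin, and class-level flags such as test_head_masking = False switch off checks the model does not support. The snippet below is a generic toy illustration of that mixin-plus-flag pattern, not the actual transformers test mixins:

import unittest


class CommonTestsMixin:
    # Stand-in for a shared test suite; a concrete test class can flip
    # class-level flags to skip checks its model does not support.
    test_head_masking = True

    def test_forward(self):
        model = self.build_model()
        self.assertEqual(model("hello"), "hello")

    def test_head_masking_flag(self):
        if not self.test_head_masking:
            self.skipTest("head masking disabled for this model")
        self.assertEqual(self.build_model()("masked"), "masked")


class ToyModelTest(CommonTestsMixin, unittest.TestCase):
    test_head_masking = False  # mirrors GPTJModelTest opting out of a common check

    def build_model(self):
        return lambda x: x  # trivial identity "model" for illustration


if __name__ == "__main__":
    unittest.main()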