[`GPTNeoX`] Fix BC issue with 4.36 (#28602)

* fix dtype issue

* add a test

* update copied from mentions

* nits

* fixup

* fix copies

* Apply suggestions from code review
Authored by Arthur on 2024-01-21 18:01:19 +01:00; committed by Amy Roberts
parent 344943b88a
commit 8e3e145b42
3 changed files with 26 additions and 15 deletions
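
As the diff below shows, the fix removes the casts that put the cached rotary cos/sin tables (and the values returned by `forward`) into the model dtype, so the cache now stays in the dtype of `inv_freq` (float32 by default), as it did before 4.36. The snippet below is a minimal standalone sketch of why that matters for fp16 checkpoints such as Pythia; it is not the transformers code, and all tensors in it are arbitrary.

```python
import torch

# Minimal sketch (arbitrary values, not the transformers implementation):
# multiplying a half-precision query by a float32 cos table promotes the math
# to float32, whereas a cos table pre-cast to float16 keeps it in float16.
torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float16)    # query slice in the model dtype
cos_fp32 = torch.rand(8)                   # cache kept in float32 (after this fix)
cos_fp16 = cos_fp32.to(torch.float16)      # cache cast to the model dtype (4.36 behaviour)

out_fp32 = q * cos_fp32                    # computed in float32
out_fp16 = q * cos_fp16                    # computed in float16

print(out_fp32.dtype, out_fp16.dtype)             # torch.float32 torch.float16
print((out_fp32 - out_fp16.float()).abs().max())  # typically a small nonzero rounding difference
```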

src/transformers/models/gpt_neox/modeling_gpt_neox.py

@@ -526,8 +526,8 @@ def attention_mask_func(attention_scores, ltor_mask):
     return attention_scores


-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with LlamaRotary->GPTNeoXRotary
 class GPTNeoXRotaryEmbedding(nn.Module):
+    # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -549,8 +549,8 @@ class GPTNeoXRotaryEmbedding(nn.Module):
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -558,15 +558,15 @@ class GPTNeoXRotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
+            self.cos_cached[:seq_len],
+            self.sin_cached[:seq_len],
         )


-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX
 class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
     """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

+    # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)
@@ -579,14 +579,14 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)


-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
 class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
     """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

+    # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)
@@ -606,8 +606,8 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)


 def rotate_half(x):
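
A quick way to check the effect of the change above (a sketch only; `GPTNeoXRotaryEmbedding` is imported from an internal module that is not part of the public API and may move in later releases):

```python
import torch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding

# After this patch the cached tables are no longer cast to the model dtype,
# so the module returns float32 cos/sin even when the inputs are float16.
rope = GPTNeoXRotaryEmbedding(dim=64, max_position_embeddings=2048)
x = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # [bs, num_heads, seq_len, head_size]
cos, sin = rope(x, seq_len=16)
print(cos.dtype, sin.dtype)  # expected: torch.float32 torch.float32 with this fix
```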

src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py

@@ -235,6 +235,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
 # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
 class RotaryEmbedding(nn.Module):
+    # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -256,8 +257,8 @@ class RotaryEmbedding(nn.Module):
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -265,8 +266,8 @@ class RotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
+            self.cos_cached[:seq_len],
+            self.sin_cached[:seq_len],
         )
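
The GPT-NeoX-Japanese `RotaryEmbedding` above is marked `# Copied from ... GPTNeoXRotaryEmbedding`, so the same dtype change is propagated here by the repository's `make fix-copies` check (the "fix copies" bullet in the commit message). The same sanity check applies (again an internal import path, assumed valid at this commit):

```python
import torch
from transformers.models.gpt_neox_japanese.modeling_gpt_neox_japanese import RotaryEmbedding

# Mirrored class: cos/sin should also come back in float32 after this change.
rope = RotaryEmbedding(dim=64)
cos, sin = rope(torch.randn(1, 8, 16, 64, dtype=torch.float16), seq_len=16)
print(cos.dtype, sin.dtype)  # expected: torch.float32 torch.float32
```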

tests/models/gpt_neox/test_modeling_gpt_neox.py

@@ -355,3 +355,13 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase):
         output_str = tokenizer.batch_decode(output_ids)[0]

         self.assertEqual(output_str, expected_output)
+
+    def pythia_integration_test(self):
+        model_name_or_path = "EleutherAI/pythia-70m"
+        model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)
+        EXPECTED_LOGITS = torch.tensor([1069.0000, 228.7500, 1072.0000, 1072.0000, 1069.0000, 1068.0000, 1068.0000, 1071.0000, 1071.0000, 1071.0000, 1073.0000, 1070.0000, 1071.0000, 1075.0000, 1073.0000, 1075.0000, 1074.0000, 1069.0000, 1072.0000, 1071.0000, 1071.0000, 1071.0000, 1070.0000, 1069.0000, 1069.0000, 1069.0000, 1070.0000, 1075.0000, 1073.0000, 1074.0000]) # fmt: skip
+        input_ids = [29, 93, 303, 64, 5478, 49651, 10394, 187, 34, 12939, 875]
+        # alternative: tokenizer('<|im_start|>system\nA chat between')
+        input_ids = torch.as_tensor(input_ids)[None].to(torch_device)
+        outputs = model(input_ids)["logits"][:, -1][0, :30]
+        self.assertTrue(torch.allclose(EXPECTED_LOGITS, outputs, atol=1e-5))
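
Note that the new method is not prefixed with `test_`, so the standard unittest/pytest collectors will not pick it up automatically; it has to be invoked explicitly. The hard-coded `input_ids` correspond to the prompt named in the inline comment, and can be regenerated with a sketch like the following (assuming the pythia-70m tokenizer is reachable):

```python
from transformers import AutoTokenizer

# Convenience sketch, not part of the committed test: rebuild the hard-coded ids
# from the prompt mentioned in the test's inline comment.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
ids = tokenizer("<|im_start|>system\nA chat between").input_ids
print(ids)  # per the test's comment, this should give [29, 93, 303, 64, 5478, 49651, 10394, 187, 34, 12939, 875]
```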