[`GPTNeoX`] Fix BC issue with 4.36 (#28602)
* fix dtype issue * add a test * update copied from mentions * nits * fixup * fix copies * Apply suggestions from code review
This commit is contained in:
parent
344943b88a
commit
8e3e145b42
|
@ -526,8 +526,8 @@ def attention_mask_func(attention_scores, ltor_mask):
|
|||
return attention_scores
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with LlamaRotary->GPTNeoXRotary
|
||||
class GPTNeoXRotaryEmbedding(nn.Module):
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
|
||||
|
@ -549,8 +549,8 @@ class GPTNeoXRotaryEmbedding(nn.Module):
|
|||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
|
@ -558,15 +558,15 @@ class GPTNeoXRotaryEmbedding(nn.Module):
|
|||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.cos_cached[:seq_len],
|
||||
self.sin_cached[:seq_len],
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX
|
||||
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
@ -579,14 +579,14 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
|||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
|
||||
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
@ -606,8 +606,8 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
|||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
|
|
|
@ -235,6 +235,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
|||
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
|
||||
class RotaryEmbedding(nn.Module):
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
|
||||
|
@ -256,8 +257,8 @@ class RotaryEmbedding(nn.Module):
|
|||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
|
@ -265,8 +266,8 @@ class RotaryEmbedding(nn.Module):
|
|||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.cos_cached[:seq_len],
|
||||
self.sin_cached[:seq_len],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -355,3 +355,13 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase):
|
|||
output_str = tokenizer.batch_decode(output_ids)[0]
|
||||
|
||||
self.assertEqual(output_str, expected_output)
|
||||
|
||||
def pythia_integration_test(self):
|
||||
model_name_or_path = "EleutherAI/pythia-70m"
|
||||
model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)
|
||||
EXPECTED_LOGITS = torch.tensor([1069.0000, 228.7500, 1072.0000, 1072.0000, 1069.0000, 1068.0000, 1068.0000, 1071.0000, 1071.0000, 1071.0000, 1073.0000, 1070.0000, 1071.0000, 1075.0000, 1073.0000, 1075.0000, 1074.0000, 1069.0000, 1072.0000, 1071.0000, 1071.0000, 1071.0000, 1070.0000, 1069.0000, 1069.0000, 1069.0000, 1070.0000, 1075.0000, 1073.0000, 1074.0000]) # fmt: skip
|
||||
input_ids = [29, 93, 303, 64, 5478, 49651, 10394, 187, 34, 12939, 875]
|
||||
# alternative: tokenizer('<|im_start|>system\nA chat between')
|
||||
input_ids = torch.as_tensor(input_ids)[None].to(torch_device)
|
||||
outputs = model(input_ids)["logits"][:, -1][0, :30]
|
||||
self.assertTrue(torch.allclose(EXPECTED_LOGITS, outputs, atol=1e-5))
|
||||
|
|
Loading…
Reference in New Issue