[Phi2] Add support for phi2 models (#28211)

* modified script and added test for phi2

* changes
This commit is contained in:
Susnato Dhar 2024-01-07 12:49:14 +05:30 committed by GitHub
parent 4ab5fb8941
commit 3eddda1111
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 29 deletions

View File

@ -18,12 +18,15 @@ Weights conversion script for Phi
This script downloads both Phi-1 and Phi-1.5 checkpoints to "checkpoint_path" and then converts the weights to
HugfgingFace model's format and saves them in "pytorch_dump_folder_path".
Example : $python ./convert_phi_weights_to_hf.py --model_name "microsoft/phi-2" --pytorch_dump_folder ./dump_folder/ --checkpoint_path ./ckpt_path/
"""
import argparse
import gc
import os
import safetensors
import torch
from huggingface_hub import hf_hub_download
@ -31,18 +34,21 @@ from transformers import PhiConfig, PhiForCausalLM
_MODELS = {
"microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin",
"microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin",
"microsoft/phi-1": ["https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin"],
"microsoft/phi-1_5": ["https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin"],
"microsoft/phi-2": [
"https://huggingface.co/microsoft/phi-2/blob/main/model-00001-of-00002.safetensors",
"https://huggingface.co/microsoft/phi-2/blob/main/model-00002-of-00002.safetensors",
],
}
PHI_MAPPING = {
"layers.0.wte.weight": "model.embed_tokens.weight",
"layers.25.linear.bias": "lm_head.bias",
"layers.25.linear.weight": "lm_head.weight",
"layers.25.ln.bias": "model.final_layernorm.bias",
"layers.25.ln.weight": "model.final_layernorm.weight",
"transformer.embd.wte.weight": "model.embed_tokens.weight",
"lm_head.linear": "lm_head",
"lm_head.ln": "model.final_layernorm",
"layers": "model.layers",
"transformer": "model",
".h.": ".layers.",
"ln": "input_layernorm",
"mixer": "self_attn",
"Wqkv": "query_key_value",
@ -54,14 +60,6 @@ def convert_weights(original_weights, mapping, config):
converted_weights = {}
original_weights_keys = sorted(original_weights.keys())
# we change names (1-24) -> layers(0-23) for Phi model layers
range_change = {
f"layers.{k}.": f"layers.{v}."
for k, v in zip(range(1, config.num_hidden_layers + 1), range(0, config.num_hidden_layers))
}
mapping.update(**range_change)
for original_weights_key in original_weights_keys:
new_key = original_weights_key
@ -104,27 +102,48 @@ def _download(url: str, root: str):
)
def convert_phi_weights(checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly):
def convert_phi_weights(
model_name, checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly, _MODELS
):
_MODELS = _MODELS if model_name not in _MODELS.keys() else {model_name: _MODELS.get(model_name)}
device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
for each_model_name, each_model_url in _MODELS.items():
for model_name, model_url in _MODELS.items():
converted_checkpoint = {}
model_checkpoint = {}
model_path = os.path.join(checkpoint_path, each_model_name + "_" + each_model_url.split("/")[-1])
if not os.path.exists(model_path):
print(f"\n{each_model_name} was not found! Downloading it to {model_path}")
_download(url=each_model_url, root=model_path)
model_checkpoint = torch.load(model_path, map_location=device)
model_type = each_model_name.split("/")[1] # phi-1 or phi-1_5
config = PhiConfig.from_pretrained(f"susnato/{model_type}_dev")
# for phi-2 the weights are stored in 2 different safetensors file so we need to iterate over that list and download one at a time
for model_each_url in model_url:
model_path = os.path.join(checkpoint_path, model_name + "_" + model_each_url.split("/")[-1])
if not os.path.exists(model_path):
print(f"\n{model_name} was not found! Downloading it to {model_path}")
_download(url=model_each_url, root=model_path)
if model_path.endswith("safetensors"):
loaded_weights = safetensors.torch.load_file(model_path, device=device)
else:
loaded_weights = torch.load(model_path, map_location=device)
model_checkpoint.update(**loaded_weights)
model_type = model_name.split("/")[1] # phi-1 or phi-1_5 or phi-2
# init the config for phi-1 and phi-1.5
config = PhiConfig()
# if we are dealing with phi-2 then update the config
if model_type == "phi-2":
config.hidden_size = 2560
config.intermediate_size = 10240
config.num_hidden_layers = 32
config.resid_pdrop = 0.1
config.partial_rotary_factor = 0.4
config.num_hidden_layers = 32
config.torch_dtype = "float16"
# Converting the weights
converted_checkpoint.update(**convert_weights(model_checkpoint, PHI_MAPPING, config))
# Save either the whole model or the converted weights
if save_weights_directly:
save_weights_path = os.path.join(
pytorch_dump_folder_path, each_model_name.split("/")[-1] + "_" + each_model_url.split("/")[-1]
)
save_weights_path = os.path.join(pytorch_dump_folder_path, model_type + "_pytorch_model.bin")
torch.save(converted_checkpoint, save_weights_path)
print(f"Model weights saved at {save_weights_path}!")
@ -148,6 +167,12 @@ def convert_phi_weights(checkpoint_path, pytorch_dump_folder_path, use_cuda, sav
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# # Required parameters
parser.add_argument(
"--model_name",
type=str,
help="Name of the model to convert. (Please enter one of the following: phi-1, phi-1_5, phi-2). If nothing is provided, all models will be converted.",
default=None,
)
parser.add_argument(
"--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
)
@ -172,4 +197,11 @@ if __name__ == "__main__":
)
args = parser.parse_args()
convert_phi_weights(args.checkpoint_path, args.pytorch_dump_folder_path, args.use_cuda, args.save_weights_directly)
convert_phi_weights(
args.model_name,
args.checkpoint_path,
args.pytorch_dump_folder_path,
args.use_cuda,
args.save_weights_directly,
_MODELS,
)

View File

@ -432,3 +432,38 @@ class PhiIntegrationTest(unittest.TestCase):
EXPECTED_OUTPUT = torch.tensor([[12.2922, 13.3507, 8.6963, 9.1355, 9.3502, 9.2667, 14.2027, 13.1363, 13.5446, 11.1337, 9.9279, 16.7195, 13.0768, 14.9141, 11.9965, 8.0233, 10.3129, 10.6118, 10.0204, 9.3827, 8.8344, 8.2806, 8.0153, 8.0540, 7.0964, 16.5743, 11.1256, 9.6987, 11.4770, 10.5440], [12.3323, 14.6050, 8.9986, 8.1580, 9.5654, 6.6728, 12.5966, 12.6662, 12.2784, 11.7522, 8.2039, 16.3102, 11.2203, 13.6088, 12.0125, 9.1021, 9.8216, 10.0987, 9.0926, 8.4260, 8.8009, 7.6547, 6.8075, 7.7881, 7.4501, 15.7451, 10.5053, 8.3129, 10.0027, 9.2612]]).to(torch_device) # fmt: skip
self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-4, rtol=1e-4))
def test_model_phi_2_logits(self):
input_ids = {
"input_ids": torch.tensor(
[[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device
)
}
model = PhiForCausalLM.from_pretrained("susnato/phi-2").to(torch_device)
model.eval()
output = model(**input_ids).logits
EXPECTED_OUTPUT = torch.tensor([[6.4830, 6.1644, 3.4055, 2.2848, 5.4654, 2.8360, 5.5975, 5.5391, 7.3101, 4.2498, 2.5913, 10.3885, 6.4359, 8.7982, 5.6534, 0.5150, 2.7498, 3.1930, 2.4334, 1.7781, 1.5613, 1.3067, 0.8291, 0.5633, 0.6522, 9.8191, 5.5771, 2.7987, 4.2845, 3.7030], [6.0642, 7.8242, 3.4634, 1.9259, 4.3169, 2.0913, 6.0446, 3.6804, 6.6736, 4.0727, 2.1791, 11.4139, 5.6795, 7.5652, 6.2039, 2.7174, 4.3266, 3.6930, 2.8058, 2.6721, 2.3047, 2.0848, 2.0972, 2.0441, 1.3160, 9.2085, 4.5557, 3.0296, 2.6045, 2.4059]]).to(torch_device) # fmt: skip
self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-3, rtol=1e-3))
def test_phi_2_generation(self):
model = PhiForCausalLM.from_pretrained("susnato/phi-2")
tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2")
inputs = tokenizer(
"Can you help me write a formal email to a potential business partner proposing a joint venture?",
return_tensors="pt",
return_attention_mask=False,
)
outputs = model.generate(**inputs, max_new_tokens=30)
output_text = tokenizer.batch_decode(outputs)
EXPECTED_OUTPUT = [
"Can you help me write a formal email to a potential business partner proposing a joint venture?\nInput: Company A: ABC Inc.\nCompany B: XYZ Ltd.\nJoint Venture: A new online platform for e-commerce"
]
self.assertListEqual(output_text, EXPECTED_OUTPUT)