
73 lines
2.4 KiB

import os
from pathlib import Path
from typing import Dict, List
import fire
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.utils.logging import get_logger
logger = get_logger(__name__)
def remove_prefix(text: str, prefix: str):
    """Return ``text`` without a leading ``prefix``.

    If ``text`` does not start with ``prefix`` it is returned unchanged.
    An empty ``prefix`` always matches and leaves ``text`` as-is.
    """
    starts_with_prefix = text.startswith(prefix)
    return text[len(prefix):] if starts_with_prefix else text
def sanitize(sd):
    """Return a new dict whose keys have a leading ``"model."`` stripped.

    Values are carried over unchanged. If two keys collapse to the same
    name, the one encountered later in ``sd`` wins.
    """
    cleaned = {}
    for name, value in sd.items():
        cleaned[remove_prefix(name, "model.")] = value
    return cleaned
def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]):
    """Average several state dicts key by key.

    The key set is taken from the first state dict. Every other state dict must
    contain each of those keys, otherwise ``KeyError`` is raised. Keys that
    appear only in later state dicts are ignored.

    ``/`` is true division, so integer-typed tensors come back as floating point.

    Args:
        state_dicts: Non-empty list of ``{name: tensor}`` mappings.

    Returns:
        A new dict mapping each key to the elementwise mean of its tensors.

    Raises:
        ValueError: if ``state_dicts`` is empty.
        TypeError: if averaging a key does not yield a ``torch.Tensor``.
    """
    if not state_dicts:
        raise ValueError("average_state_dicts needs at least one state dict")
    count = len(state_dicts)
    averaged = {}
    for key in state_dicts[0]:
        tensors = [sd[key] for sd in state_dicts]
        mean = sum(tensors) / count
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(mean, torch.Tensor):
            raise TypeError(f"averaging {key!r} produced {type(mean).__name__}, expected torch.Tensor")
        averaged[key] = mean
    return averaged
def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None:
    """Cleanup a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict.

    Extra pl keys (like ``teacher.``) are allowed; they are only logged. Puts all ckpt models into CPU RAM at once!

    Args:
        pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning or dir containing ckpt files.
            If a directory is passed, all .ckpt files inside it will be averaged!
        hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint
        save_path (:obj:`str`): Directory to save the new model

    Raises:
        FileNotFoundError: if ``pl_ckpt_path`` is neither a file nor a directory, or the
            directory contains no ``*.ckpt`` files.
        ValueError: if the (averaged) checkpoint lacks keys the HF model requires.
    """
    # Resolve checkpoint files first so a bad path fails before the (large) model is loaded.
    if os.path.isfile(pl_ckpt_path):
        ckpt_files = [pl_ckpt_path]
    elif os.path.isdir(pl_ckpt_path):
        # Sorted so the averaging order (and therefore float rounding) is reproducible.
        ckpt_files = sorted(Path(pl_ckpt_path).glob("*.ckpt"))
        if not ckpt_files:
            raise FileNotFoundError(f"could not find any ckpt files inside the {pl_ckpt_path} directory")
    else:
        raise FileNotFoundError(f"{pl_ckpt_path} is neither a .ckpt file nor a directory")

    hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir)

    if len(ckpt_files) > 1:
        logger.info("averaging the weights of %s", ckpt_files)

    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only convert checkpoints from trusted sources.
    state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files]
    state_dict = average_state_dicts(state_dicts)
    del state_dicts  # drop the per-checkpoint copies; only the averaged dict is needed from here on

    missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
    if missing:
        raise ValueError(f"missing keys: {missing}")
    if unexpected:
        logger.info("ignoring %d unexpected keys from the pl checkpoint: %s", len(unexpected), unexpected)
    hf_model.save_pretrained(save_path)

    try:
        tok = AutoTokenizer.from_pretrained(hf_src_model_dir)
        tok.save_pretrained(save_path)
    except Exception:
        # Best effort: the model is already saved, so a tokenizer that cannot be
        # loaded or saved is skipped instead of failing the whole conversion.
        logger.warning("could not copy tokenizer from %s; skipping", hf_src_model_dir, exc_info=True)
if __name__ == "__main__":
    # Expose convert_pl_to_hf as a command-line interface via python-fire.
    fire.Fire(convert_pl_to_hf)