From 89c510d842a58e5d45e27129bec5c35c97951e1f Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Wed, 24 Apr 2024 10:11:19 +0200
Subject: [PATCH] Add llama3 (#30334)

* nuke

* add co-author

* add co-author

* update card

* fixup and fix copies to please our ci

* nit fixup

* super small nits

* remove tokenizer_path from call to `write_model`

* always safe serialize by default

---------

Co-authored-by: pcuenca
Co-authored-by: xenova
---
 docs/source/en/_toctree.yml                |   2 +
 docs/source/en/index.md                    |   1 +
 docs/source/en/model_doc/llama3.md         |  85 ++++++++++++
 src/transformers/convert_slow_tokenizer.py |  93 +++++++++++++
 .../models/auto/configuration_auto.py      |   1 +
 .../llama/convert_llama_weights_to_hf.py   | 124 ++++++++++++++----
 utils/check_table.py                       |   1 +
 7 files changed, 279 insertions(+), 28 deletions(-)
 create mode 100644 docs/source/en/model_doc/llama3.md

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 45f51886b7..47001a365e 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -394,6 +394,8 @@
         title: LLaMA
       - local: model_doc/llama2
         title: Llama2
+      - local: model_doc/llama3
+        title: Llama3
       - local: model_doc/longformer
         title: Longformer
       - local: model_doc/longt5
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 912bbad1d2..3c136ea465 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -177,6 +177,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [LiLT](model_doc/lilt)             | ✅ | ❌ | ❌ |
 | [LLaMA](model_doc/llama)           | ✅ | ❌ | ✅ |
 | [Llama2](model_doc/llama2)         | ✅ | ❌ | ✅ |
+| [Llama3](model_doc/llama3)         | ✅ | ❌ | ✅ |
 | [LLaVa](model_doc/llava)           | ✅ | ❌ | ❌ |
 | [LLaVA-NeXT](model_doc/llava_next) | ✅ | ❌ | ❌ |
 | [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ |
diff --git a/docs/source/en/model_doc/llama3.md b/docs/source/en/model_doc/llama3.md
new file mode 100644
index 0000000000..1a7546c7e6
--- /dev/null
+++ b/docs/source/en/model_doc/llama3.md
@@ -0,0 +1,85 @@
+# Llama3
+
+## Overview
+
+The Llama3 model was proposed in [Introducing Meta Llama 3: The most capable openly available LLM to date](https://ai.meta.com/blog/meta-llama-3/) by the Meta AI team.
+
+The abstract from the blogpost is the following:
+
+*Today, we’re excited to share the first two models of the next generation of Llama, Meta Llama 3, available for broad use. This release features pretrained and instruction-fine-tuned language models with 8B and 70B parameters that can support a broad range of use cases. This next generation of Llama demonstrates state-of-the-art performance on a wide range of industry benchmarks and offers new capabilities, including improved reasoning. We believe these are the best open source models of their class, period. In support of our longstanding open approach, we’re putting Llama 3 in the hands of the community. We want to kickstart the next wave of innovation in AI across the stack—from applications to developer tools to evals to inference optimizations and more. We can’t wait to see what you build and look forward to your feedback.*
+
+Check out all Llama3 model checkpoints [here](https://huggingface.co/models?search=llama3).
+The original code of the authors can be found [here](https://github.com/meta-llama/llama3).
+
+## Usage tips
+
+<Tip warning={true}>
+
+The `Llama3` models were trained using `bfloat16`, but the original inference uses `float16`.
+The checkpoints uploaded on the Hub use `torch_dtype = 'float16'`, which will be used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.
+
+The `dtype` of the online weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model with `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype="auto")`. The reason is that the model will first be downloaded (using the `dtype` of the checkpoints online), then cast to the default `dtype` of `torch` (which becomes `torch.float32`), and finally, if a `torch_dtype` is provided in the config, it will be used.
+
+Training the model in `float16` is not recommended and is known to produce `nan`; as such, the model should be trained in `bfloat16`.
+
+</Tip>
+
+Tips:
+
+- Weights for the Llama3 models can be obtained by filling out [this form](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).
+- The architecture is exactly the same as Llama2.
+- The tokenizer is a BPE model based on [tiktoken](https://github.com/openai/tiktoken) (vs the sentencepiece-based implementation used for Llama2). The main difference is that it ignores BPE merge rules when an input token is part of the vocab. This means that if `"hugging"` is part of the vocab and no merges exist to produce it, it is automatically returned as a single token instead of being split into the smallest units, like `["hug", "ging"]`.
+- The original model uses `pad_id = -1`, which means that there is no padding token. We can't have the same logic here; make sure to add a padding token using `tokenizer.add_special_tokens({"pad_token": "<pad>"})` and resize the token embeddings accordingly. You should also set `model.config.pad_token_id`. The `embed_tokens` layer of the model is initialized with `self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.config.padding_idx)`, which makes sure that encoding the padding token outputs zeros, so passing it when initializing is recommended (see the sketch after these tips).
+- The original checkpoint can be converted using the [conversion script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py). The script can be called with the following (example) command:
+
+```bash
+python src/transformers/models/llama/convert_llama_weights_to_hf.py \
+    --input_dir /path/to/downloaded/llama/weights --model_size 8B --output_dir /output/path --llama_version 3
+```
+
+- After conversion, the model and tokenizer can be loaded via:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("/output/path")
+model = AutoModelForCausalLM.from_pretrained("/output/path")
+```
+
+Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even though the biggest versions come in several checkpoints, each of them contains a part of each weight of the model, so we need to load them all in RAM). For the 70B model, 145GB of RAM are thus needed.
+
+- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, simply set either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because Flash Attention only supports the `fp16` and `bf16` data types.
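+
+As mentioned in the padding-token tip above, here is a minimal sketch of that setup. It is an illustration rather than part of the official release: it assumes the converted checkpoint lives in `/output/path` (as in the conversion example above), and the `"<pad>"` string is an arbitrary placeholder for whatever padding token you pick.
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("/output/path")
+model = AutoModelForCausalLM.from_pretrained("/output/path")
+
+# "<pad>" is a placeholder; any string not already in the vocab works.
+tokenizer.add_special_tokens({"pad_token": "<pad>"})
+# Resize the embedding matrix so the new padding token gets an entry.
+model.resize_token_embeddings(len(tokenizer))
+# Keep the model config in sync with the tokenizer.
+model.config.pad_token_id = tokenizer.pad_token_id
+```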
+
+## Quick usage
+
+```py3
+import transformers
+import torch
+
+model_id = "meta-llama/Meta-Llama-3-8B"
+
+pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
+pipeline("Hey how are you doing today?")
+```
+
+## Resources
+
+A ton of cool resources are already available on the documentation page of [Llama2](./llama2). Contributors are invited to add new resources curated for Llama3 here! 🤗
diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index 8cb80c22cd..39c239d145 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -1450,6 +1450,99 @@ class MarkupLMConverter(Converter):
         return tokenizer


+# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
+    characters the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+class TikTokenConverter:
+    """
+    A general tiktoken converter.
+    """
+
+    def __init__(
+        self,
+        vocab_file=None,
+        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
+        add_prefix_space=False,
+        *args,
+    ):
+        super().__init__(*args)
+        self.vocab_file = vocab_file
+        self.pattern = pattern
+        self.add_prefix_space = add_prefix_space
+
+    def extract_vocab_merges_from_model(self, tiktoken_url: str):
+        try:
+            from tiktoken.load import load_tiktoken_bpe
+        except Exception:
+            raise ValueError(
+                "`tiktoken` is required to read a `tiktoken` file. Install it with " "`pip install tiktoken`."
+ ) + + bpe_ranks = load_tiktoken_bpe(tiktoken_url) + byte_encoder = bytes_to_unicode() + + def token_bytes_to_string(b): + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + merges = [] + vocab = {} + for token, rank in bpe_ranks.items(): + vocab[token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + local = [] + for index in range(1, len(token)): + piece_l, piece_r = token[:index], token[index:] + if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks: + local.append((piece_l, piece_r, rank)) + local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False) + merges.extend(local) + merges = sorted(merges, key=lambda val: val[2], reverse=False) + merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges] + return vocab, merges + + def tokenizer(self): + vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file) + tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False)) + if hasattr(tokenizer.model, "ignore_merges"): + tokenizer.model.ignore_merges = True + return tokenizer + + def converted(self) -> Tokenizer: + tokenizer = self.tokenizer() + tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False), + pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False), + ] + ) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) + return tokenizer + + SLOW_TO_FAST_CONVERTERS = { "AlbertTokenizer": AlbertConverter, "BartTokenizer": RobertaConverter, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 29a52ba755..d6361ee791 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -412,6 +412,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("lilt", "LiLT"), ("llama", "LLaMA"), ("llama2", "Llama2"), + ("llama3", "Llama3"), ("llava", "LLaVa"), ("llava_next", "LLaVA-NeXT"), ("longformer", "Longformer"), diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py index f9bca1204a..a98d44b748 100644 --- a/src/transformers/models/llama/convert_llama_weights_to_hf.py +++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py @@ -20,7 +20,8 @@ import warnings import torch -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast +from transformers.convert_slow_tokenizer import TikTokenConverter try: @@ -51,10 +52,31 @@ tokenizer = LlamaTokenizer.from_pretrained("/output/path") Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
+
+If you want your tokenizer to add a bos token automatically, you should update `tokenizer._tokenizers.post_processor`:
+
+```py
+from tokenizers import processors
+bos = "<|begin_of_text|>"
+tokenizer._tokenizers.post_processor = processors.Sequence(
+    [
+        processors.ByteLevel(trim_offsets=False),
+        processors.TemplateProcessing(
+            single=f"{bos}:0 $A:0",
+            pair=f"{bos}:0 $A:0 {bos}:1 $B:1",
+            special_tokens=[
+                (bos, tokenizer.encode(bos)),
+            ],
+        ),
+    ]
+)
+```
 """

 NUM_SHARDS = {
     "7B": 1,
+    "8B": 1,
+    "8Bf": 1,
     "7Bf": 1,
     "13B": 2,
     "13Bf": 2,
@@ -81,7 +103,12 @@ def write_json(text, path):


 def write_model(
-    model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
+    model_path,
+    input_base_path,
+    model_size,
+    safe_serialization=True,
+    llama_version=1,
+    vocab_size=None,
 ):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
@@ -101,7 +128,7 @@ def write_model(
     dims_per_head = dim // n_heads
     base = params.get("rope_theta", 10000.0)
     inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-    if base > 10000.0:
+    if base > 10000.0 and llama_version != 3:
         max_position_embeddings = 16384
     else:
         # Depending on the Llama version, the default max_position_embeddings has different values.
         if llama_version == 1:
             max_position_embeddings = 2048
         elif llama_version == 2:
             max_position_embeddings = 4096
-        else:
-            raise NotImplementedError(
-                f"Version {llama_version} of llama is not supported yet. "
-                "Current supported versions of llama are [1, 2]."
-            )
-
-    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
-    if tokenizer_path is not None:
-        tokenizer = tokenizer_class(tokenizer_path)
-        tokenizer.save_pretrained(model_path)
-    vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
+        elif llama_version == 3:
+            max_position_embeddings = 8192

+    vocab_size = vocab_size if vocab_size is not None else 32000
     if params.get("n_kv_heads", None) is not None:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
         num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
@@ -131,7 +150,7 @@ def write_model(
         key_value_dim = dim

     # permute for sliced rotary
-    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
+    def permute(w, n_heads, dim1=dim, dim2=dim):
         return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

     print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
@@ -154,10 +173,12 @@ def write_model(
             # Unsharded
             state_dict = {
                 f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
-                    loaded[f"layers.{layer_i}.attention.wq.weight"]
+                    loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
                 ),
                 f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
-                    loaded[f"layers.{layer_i}.attention.wk.weight"]
+                    loaded[f"layers.{layer_i}.attention.wk.weight"],
+                    n_heads=num_key_value_heads,
+                    dim1=dim // num_local_key_value_heads,
                 ),
                 f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
                 f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
@@ -188,7 +209,8 @@ def write_model(
                         for i in range(num_shards)
                     ],
                     dim=0,
-                ).reshape(dim, dim)
+                ).reshape(dim, dim),
+                n_heads=n_heads,
             )
             state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
                 torch.cat(
@@ -242,10 +264,11 @@ def write_model(
"lm_head.weight": loaded["output.weight"], } else: + concat_dim = 0 if llama_version == 3 else 1 state_dict = { "model.norm.weight": loaded[0]["norm.weight"], "model.embed_tokens.weight": torch.cat( - [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 + [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=concat_dim ), "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), } @@ -270,6 +293,8 @@ def write_model( vocab_size=vocab_size, rope_theta=base, max_position_embeddings=max_position_embeddings, + bos_token_id=128000 if llama_version == 3 else 1, + eos_token_id=128001 if llama_version == 3 else 2, ) config.save_pretrained(tmp_model_path) @@ -288,12 +313,54 @@ def write_model( shutil.rmtree(tmp_model_path) -def write_tokenizer(tokenizer_path, input_tokenizer_path): - # Initialize the tokenizer based on the `spm` model +class Llama3Converter(TikTokenConverter): + def __init__(self, vocab_file, num_reserved_special_tokens=256, **kwargs): + super().__init__(vocab_file, **kwargs) + tokenizer = self.converted() + chat_template = ( + "{% set loop_messages = messages %}" + "{% for message in loop_messages %}" + "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}" + "{% if loop.index0 == 0 %}" + "{% set content = bos_token + content %}" + "{% endif %}" + "{{ content }}" + "{% endfor %}" + "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" + ) + num_reserved_special_tokens = 256 + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)] + tokenizer.add_special_tokens(special_tokens) + + self.tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + bos_token="<|begin_of_text|>", + eos_token="<|end_of_text|>", + chat_template=chat_template, + model_input_names=["input_ids", "attention_mask"], + ) + + +def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2): tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast + if llama_version == 3: + tokenizer = Llama3Converter(input_tokenizer_path).tokenizer + else: + tokenizer = tokenizer_class(input_tokenizer_path) print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") - tokenizer = tokenizer_class(input_tokenizer_path) tokenizer.save_pretrained(tokenizer_path) + return tokenizer def main(): @@ -304,35 +371,36 @@ def main(): ) parser.add_argument( "--model_size", - choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"], + choices=["7B", "8B", "8Bf", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"], help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", ) parser.add_argument( "--output_dir", help="Location to write HF model and tokenizer", ) - parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") + parser.add_argument( + "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`." 
+ ) # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. parser.add_argument( "--llama_version", - choices=[1, 2], + choices=[1, 2, 3], default=1, type=int, help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size", ) args = parser.parse_args() spm_path = os.path.join(args.input_dir, "tokenizer.model") + vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version)) if args.model_size != "tokenizer_only": write_model( model_path=args.output_dir, input_base_path=args.input_dir, model_size=args.model_size, safe_serialization=args.safe_serialization, - tokenizer_path=spm_path, llama_version=args.llama_version, + vocab_size=vocab_size, ) - else: - write_tokenizer(args.output_dir, spm_path) if __name__ == "__main__": diff --git a/utils/check_table.py b/utils/check_table.py index 99031f025c..9c9318ca85 100644 --- a/utils/check_table.py +++ b/utils/check_table.py @@ -155,6 +155,7 @@ MODEL_NAMES_WITH_SAME_CONFIG = { "HerBERT": "BERT", "LayoutXLM": "LayoutLMv2", "Llama2": "LLaMA", + "Llama3": "LLaMA", "MADLAD-400": "T5", "MatCha": "Pix2Struct", "mBART-50": "mBART",