Fix BLOOM dtype (#17995)

* Add fp16 option

* Fix BLOOM dtype

* Formatting

* Remove torch_dtype arg

* Revert formatting

* Apply formatting

* Add n_embed backward compat
Author: Niklas Muennighoff
Date:   2022-07-12 16:36:08 +02:00 (committed by GitHub)
parent 981714efe1
commit bc34c21191
2 changed files with 6 additions and 10 deletions
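
Note on the user-facing effect (a sketch, not part of the diff): precision for BLOOM now goes through the standard torch_dtype machinery instead of a BLOOM-specific dtype config field, so the usual from_pretrained pattern picks up the recorded precision. The model id below is only an example checkpoint, not something this commit prescribes.

# Sketch only; "bigscience/bloom-560m" is an example BLOOM checkpoint.
from transformers import BloomModel

model = BloomModel.from_pretrained("bigscience/bloom-560m", torch_dtype="auto")
print(next(model.parameters()).dtype)  # the precision recorded in the checkpoint's config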

src/transformers/models/bloom/configuration_bloom.py

@@ -78,9 +78,6 @@ class BloomConfig(PretrainedConfig):
             Dropout rate applied to the attention probs
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
-        dtype (`str`, *optional*, defaults to `"bfloat16"`):
-            Precision that has been used for the model's training in Megatron. Please load the model in the correct
-            precision by doing `model = BloomModel.from_pretrained(model_name, torch_dtype="auto")`.`
         pretraining_tp (`int`, *optional*, defaults to `1`):
             Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
             document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
@@ -114,9 +111,7 @@ class BloomConfig(PretrainedConfig):
     keys_to_ignore_at_inference = ["past_key_values"]
     attribute_map = {
         "num_hidden_layers": "n_layer",
-        "n_head": "num_attention_heads",
-        "hidden_size": "n_embed",
-        "dtype": "torch_dtype",
+        "num_attention_heads": "n_head",
     }
 
     def __init__(
@@ -134,12 +129,13 @@ class BloomConfig(PretrainedConfig):
         hidden_dropout=0.0,
         attention_dropout=0.0,
         pretraining_tp=1,  # TP rank used when training with megatron
-        dtype="bfloat16",
         slow_but_exact=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
+        # Backward compatibility with n_embed kwarg
+        n_embed = kwargs.pop("n_embed", None)
+        self.hidden_size = hidden_size if n_embed is None else n_embed
         self.n_layer = n_layer
         self.n_head = n_head
         self.layer_norm_epsilon = layer_norm_epsilon
@@ -152,7 +148,6 @@ class BloomConfig(PretrainedConfig):
 
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
-        self.dtype = dtype
         self.slow_but_exact = slow_but_exact
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
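
A quick sanity sketch of the n_embed backward-compatibility path added above, assuming the patched BloomConfig: the legacy n_embed keyword is popped from kwargs and used as hidden_size when present.

# Sketch only, not part of the diff.
from transformers import BloomConfig

new_style = BloomConfig(hidden_size=1024)
old_style = BloomConfig(n_embed=1024)  # legacy kwarg still accepted
assert new_style.hidden_size == old_style.hidden_size == 1024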

src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py

@@ -203,7 +203,8 @@ def convert_bloom_checkpoint_to_pytorch(
     os.makedirs(pytorch_dump_folder_path, exist_ok=True)
     pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
     pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
-    print(f"Save PyTorch model to {pytorch_weights_dump_path}")
+    print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
+    model = model.to(config.torch_dtype)
     torch.save(model.state_dict(), pytorch_weights_dump_path)
     print(f"Save configuration file to {pytorch_config_dump_path}")
     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
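
The behavioural change in the conversion script is the cast before serialization: model.to(config.torch_dtype) runs before torch.save, so the weights on disk match the dtype recorded in the config. A standalone sketch of the same pattern in plain PyTorch (a stand-in module, not the script itself):

# Sketch only; torch.nn.Linear stands in for the converted BLOOM model.
import torch

model = torch.nn.Linear(8, 8)
model = model.to(torch.bfloat16)  # plays the role of model.to(config.torch_dtype)
torch.save(model.state_dict(), "pytorch_model.bin")

state_dict = torch.load("pytorch_model.bin")
print(state_dict["weight"].dtype)  # torch.bfloat16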