Fix BLOOM dtype (#17995)
* Add fp16 option
* Fix BLOOM dtype
* Formatting
* Remove torch_dtype arg
* Revert formatting
* Apply formatting
* Add n_embed backward compat
This commit is contained in:
parent 981714efe1
commit bc34c21191
@@ -78,9 +78,6 @@ class BloomConfig(PretrainedConfig):
             Dropout rate applied to the attention probs
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
-        dtype (`str`, *optional*, defaults to `"bfloat16"`):
-            Precision that has been used for the model's training in Megatron. Please load the model in the correct
-            precision by doing `model = BloomModel.from_pretrained(model_name, torch_dtype="auto")`.
         pretraining_tp (`int`, *optional*, defaults to `1`):
             Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
             document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
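As a side note on the removed docstring: the recommended way to get the weights back in the precision recorded for the checkpoint is `torch_dtype="auto"`. A minimal sketch of that call, with a placeholder checkpoint name:

```python
from transformers import BloomModel

# "auto" makes from_pretrained read the dtype stored in the checkpoint's config
# (config.torch_dtype) instead of defaulting to float32.
# The checkpoint name below is only an example.
model = BloomModel.from_pretrained("bigscience/bloom-560m", torch_dtype="auto")
print(next(model.parameters()).dtype)  # e.g. torch.bfloat16 or torch.float16
```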
@@ -114,9 +111,7 @@ class BloomConfig(PretrainedConfig):
     keys_to_ignore_at_inference = ["past_key_values"]
     attribute_map = {
         "num_hidden_layers": "n_layer",
-        "n_head": "num_attention_heads",
-        "hidden_size": "n_embed",
-        "dtype": "torch_dtype",
+        "num_attention_heads": "n_head",
     }
 
     def __init__(
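The `attribute_map` above is interpreted by `PretrainedConfig`: each key becomes an alias for the attribute it points to, so the generic names keep working even though BLOOM stores `n_layer` and `n_head` internally. A rough illustration (the numbers are arbitrary):

```python
from transformers import BloomConfig

# "num_attention_heads" and "num_hidden_layers" are aliases resolved
# through attribute_map onto n_head and n_layer.
config = BloomConfig(n_head=16, n_layer=24)
assert config.num_attention_heads == 16
assert config.num_hidden_layers == 24
```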
@@ -134,12 +129,13 @@ class BloomConfig(PretrainedConfig):
         hidden_dropout=0.0,
         attention_dropout=0.0,
         pretraining_tp=1,  # TP rank used when training with megatron
-        dtype="bfloat16",
         slow_but_exact=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
+        # Backward compatibility with n_embed kwarg
+        n_embed = kwargs.pop("n_embed", None)
+        self.hidden_size = hidden_size if n_embed is None else n_embed
         self.n_layer = n_layer
         self.n_head = n_head
         self.layer_norm_epsilon = layer_norm_epsilon
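The `n_embed` handling added here keeps old configs and call sites working: if `n_embed` is still passed as a kwarg it takes precedence, otherwise `hidden_size` is used as-is. Roughly:

```python
from transformers import BloomConfig

# Legacy kwarg is popped from **kwargs and mapped onto hidden_size...
legacy = BloomConfig(n_embed=1024)
assert legacy.hidden_size == 1024

# ...while the current kwarg behaves as before.
current = BloomConfig(hidden_size=1024)
assert current.hidden_size == 1024
```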
@@ -152,7 +148,6 @@ class BloomConfig(PretrainedConfig):
 
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
-        self.dtype = dtype
         self.slow_but_exact = slow_but_exact
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
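With the custom `dtype` attribute gone, precision is carried only by the standard `torch_dtype` field, which the `PretrainedConfig` base class already parses from kwargs (including the string form). A small sketch, assuming current `transformers` behaviour:

```python
import torch
from transformers import BloomConfig

# torch_dtype is handled by PretrainedConfig, not by BloomConfig itself;
# the string "float16" is converted to torch.float16.
config = BloomConfig(torch_dtype="float16")
assert config.torch_dtype == torch.float16
```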
@@ -203,7 +203,8 @@ def convert_bloom_checkpoint_to_pytorch(
         os.makedirs(pytorch_dump_folder_path, exist_ok=True)
         pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
         pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
-        print(f"Save PyTorch model to {pytorch_weights_dump_path}")
+        print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
+        model = model.to(config.torch_dtype)
         torch.save(model.state_dict(), pytorch_weights_dump_path)
         print(f"Save configuration file to {pytorch_config_dump_path}")
         with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
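Since the conversion script now casts the model to `config.torch_dtype` before saving, the dumped `state_dict` carries the intended precision rather than silently defaulting to float32. A quick, informal way to sanity-check a dump (the path is a placeholder):

```python
import torch

# Point this at <pytorch_dump_folder_path>/pytorch_model.bin from the conversion run.
state_dict = torch.load("bloom_dump/pytorch_model.bin", map_location="cpu")

# Every tensor should report the dtype recorded in the config, e.g. torch.bfloat16.
print({name: t.dtype for name, t in list(state_dict.items())[:3]})
```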