[Flax] Example scripts - correct weight decay (#12409)

* fix_torch_device_generate_test

* remove @

* finish

* finish

* correct style
This commit is contained in:
Patrick von Platen 2021-06-29 12:01:08 +01:00 committed by GitHub
parent aecae53377
commit 813328682e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 21 additions and 3 deletions

View File

@ -477,9 +477,15 @@ def main():
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
# Note that this mask is specifically adapted for FlaxGPT2.
# For other models, one should correct the layer norm parameter naming
# accordingly.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
flat_mask = {
path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
for path in flat_params
}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer

View File

@ -508,6 +508,9 @@ if __name__ == "__main__":
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
# Note that this mask is specifically adapted for FlaxBERT-like models.
# For other models, one should correct the layer norm parameter naming
# accordingly.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}

View File

@ -626,7 +626,10 @@ if __name__ == "__main__":
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
flat_mask = {
path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
for path in flat_params
}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer

View File

@ -578,9 +578,15 @@ def main():
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
# Note that this mask is specifically adapted for FlaxBart.
# For FlaxT5, one should correct the layer norm parameter naming
# accordingly - see `run_t5_mlm_flax.py` e.g.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
layer_norm_params = [
(name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
]
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer