Fix T5 incorrect weight decay in Trainer and official summarization example (#18002)
* Add ALL_LAYERNORM_LAYERS for LayerNorm * fix bug of appending layer norm
This commit is contained in:
parent
22edb68d49
commit
bf37e5c7f6
|
@ -526,7 +526,7 @@ def main():
|
|||
|
||||
# Optimizer
|
||||
# Split weights in two groups, one with weight decay and the other not.
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
|
|
|
@ -32,7 +32,8 @@ from ...modeling_outputs import (
|
|||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
DUMMY_INPUTS,
|
||||
DUMMY_MASK,
|
||||
|
@ -260,6 +261,8 @@ except Exception:
|
|||
logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm")
|
||||
pass
|
||||
|
||||
ALL_LAYERNORM_LAYERS.append(LongT5LayerNorm)
|
||||
|
||||
|
||||
# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5
|
||||
class LongT5DenseActDense(nn.Module):
|
||||
|
|
|
@ -34,7 +34,7 @@ from ...modeling_outputs import (
|
|||
Seq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
DUMMY_INPUTS,
|
||||
DUMMY_MASK,
|
||||
|
@ -275,6 +275,8 @@ except Exception:
|
|||
logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
|
||||
pass
|
||||
|
||||
ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
|
||||
|
||||
|
||||
class T5DenseActDense(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
|
|
|
@ -21,6 +21,8 @@ from torch import _softmax_backward_data, nn
|
|||
from .utils import logging
|
||||
|
||||
|
||||
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
is_torch_less_than_1_8 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.8.0")
|
||||
|
|
|
@ -71,6 +71,7 @@ from .modelcard import TrainingSummary
|
|||
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
||||
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
from .optimization import Adafactor, get_scheduler
|
||||
from .pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from .trainer_callback import (
|
||||
CallbackHandler,
|
||||
|
@ -967,7 +968,7 @@ class Trainer:
|
|||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
|
||||
if self.optimizer is None:
|
||||
decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
|
||||
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
||||
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue