[lightning_base] fix s2s logging, only make train_loader once (#6404)
commit 84c265ffcc (parent 72add6c98f)
@@ -150,15 +150,20 @@ class BaseTransformer(pl.LightningModule):
     def test_epoch_end(self, outputs):
         return self.validation_end(outputs)
 
-    def setup(self, step):
-        train_batch_size = self.hparams.train_batch_size
-        dataloader = self.get_dataloader("train", train_batch_size)
-        self.train_loader = dataloader
-        self.total_steps = (
-            (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.accumulate_grad_batches
-            * float(self.hparams.max_epochs)
-        )
+    @property
+    def total_steps(self) -> int:
+        """The number of total training steps that will be run. Used for lr scheduler purposes."""
+        num_devices = max(1, self.hparams.gpus)  # TODO: consider num_tpu_cores
+        effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
+        dataset_size = len(self.train_loader.dataset)
+        return (dataset_size / effective_batch_size) * self.hparams.max_epochs
+
+    def setup(self, mode):
+        if mode == "fit":
+            self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
 
     def get_dataloader(self, type_path, batch_size, shuffle=False):
         raise NotImplementedError("You must implement this for your task")
+
+    def train_dataloader(self):
+        return self.train_loader
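Note that the old setup() computed total_steps with floor division, while the new property uses true division, so the value can now be fractional. A standalone sketch of the new arithmetic (function name and numbers are illustrative, not part of the patch):

def total_steps(dataset_size, train_batch_size, accumulate_grad_batches, gpus, max_epochs):
    # One optimizer step consumes train_batch_size * accumulate_grad_batches
    # samples per device, across max(1, gpus) devices.
    num_devices = max(1, gpus)
    effective_batch_size = train_batch_size * accumulate_grad_batches * num_devices
    return (dataset_size / effective_batch_size) * max_epochs

# 10_000 samples, per-device batch 8, accumulation 4, 2 GPUs, 3 epochs:
# 10_000 / (8 * 4 * 2) * 3 = 468.75
assert total_steps(10_000, 8, 4, 2, 3) == 468.75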
@@ -304,6 +309,13 @@ def add_generic_args(parser, root_dir) -> None:
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
+    parser.add_argument(
+        "--data_dir",
+        default=None,
+        type=str,
+        required=True,
+        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
+    )
 
 
 def generic_train(
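With --data_dir registered centrally, the example scripts inherit the flag instead of each redeclaring it (the per-script copies are deleted in the hunks below). A hypothetical usage sketch, assuming add_generic_args registers no other required flags:

import argparse
from lightning_base import add_generic_args  # signature as in the hunk above

parser = argparse.ArgumentParser()
add_generic_args(parser, ".")
args = parser.parse_args(["--data_dir", "./my_data"])
print(args.data_dir)  # ./my_data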
@@ -10,14 +10,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from lightning_base import generic_train
-from transformers import (
-    AdamW,
-    BartConfig,
-    BartForConditionalGeneration,
-    MBartTokenizer,
-    T5Config,
-    T5ForConditionalGeneration,
-)
+from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
 
 
 try:
@@ -158,24 +151,6 @@ class BartSummarizationDistiller(SummarizationModule):
         )
         return loss_ce, s_logits_slct, t_logits_slct
 
-    def configure_optimizers(self):
-        "Prepare optimizer and schedule (linear warmup and decay)"
-        model = self.model
-        no_decay = ["bias", "LayerNorm.weight"]
-        optimizer_grouped_parameters = [
-            {
-                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-                "weight_decay": self.hparams.weight_decay,
-            },
-            {
-                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-                "weight_decay": 0.0,
-            },
-        ]
-        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
-        self.opt = optimizer
-        return [optimizer]
-
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         SummarizationModule.add_model_specific_args(parser, root_dir)
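The deleted override duplicated the optimizer construction presumably inherited from BaseTransformer. For reference, the no_decay substring matching it used behaves like this (toy module, illustrative only):

import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.LayerNorm = nn.LayerNorm(4)  # attribute name mirrors HF modules

model = Tiny()
no_decay = ["bias", "LayerNorm.weight"]
decay = [n for n, _ in model.named_parameters() if not any(nd in n for nd in no_decay)]
skipped = [n for n, _ in model.named_parameters() if any(nd in n for nd in no_decay)]
print(decay)    # ['linear.weight']
print(skipped)  # ['linear.bias', 'LayerNorm.weight', 'LayerNorm.bias']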
@@ -3,7 +3,6 @@ import glob
 import logging
 import os
 import time
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -14,7 +13,7 @@ import torch
 from torch.utils.data import DataLoader
 
 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
+from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
 
 
 try:
@@ -252,17 +251,6 @@ class SummarizationModule(BaseTransformer):
 
     def train_dataloader(self) -> DataLoader:
         dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
-        t_total = (
-            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.accumulate_grad_batches
-            * float(self.hparams.max_epochs)
-        )
-        scheduler = get_linear_schedule_with_warmup(
-            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
-        )
-        if max(scheduler.get_last_lr()) > 0:
-            warnings.warn("All learning rates are 0")
-        self.lr_scheduler = scheduler
         return dataloader
 
     def val_dataloader(self) -> DataLoader:
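Two things are dropped here: the duplicated step count (now the total_steps property) and a misfiring warning, since max(scheduler.get_last_lr()) > 0 fired when learning rates were nonzero rather than zero. A sketch of the scheduler call in isolation (model, lr, and step counts are placeholders):

import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=500, num_training_steps=10_000
)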
@@ -303,12 +291,6 @@ class SummarizationModule(BaseTransformer):
             help="The maximum total input sequence length after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded.",
         )
-        parser.add_argument(
-            "--data_dir",
-            type=str,
-            required=True,
-            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
-        )
         parser.add_argument("--freeze_encoder", action="store_true")
         parser.add_argument("--freeze_embeds", action="store_true")
         parser.add_argument("--sortish_sampler", action="store_true", default=False)
@@ -75,7 +75,7 @@ class GLUETransformer(BaseTransformer):
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
 
-    def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader:
+    def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         "Load datasets. Called after prepare data."
 
         # We test on dev set to compare to benchmarks without having to submit to GLUE server
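The annotation fix matches how the hook is actually called, with string mode names such as "train". A hypothetical minimal implementation of the hook for a toy task (the real GLUETransformer loads cached GLUE features instead):

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning_base import BaseTransformer

class ToyTask(BaseTransformer):
    def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        # Stand-in dataset; a real task would build it from self.hparams.data_dir.
        dataset = TensorDataset(torch.zeros(128, 8))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)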
@@ -161,13 +161,6 @@ class GLUETransformer(BaseTransformer):
             type=int,
             help="The number of GPUs allocated for this, it is by default 0 meaning none",
         )
-        parser.add_argument(
-            "--data_dir",
-            default=None,
-            type=str,
-            required=True,
-            help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
-        )
 
         parser.add_argument(
             "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
@@ -104,8 +104,7 @@ class NERTransformer(BaseTransformer):
         )
 
     def validation_step(self, batch, batch_nb):
-        """Compute validation""" ""
-
+        "Compute validation"
         inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
         if self.config.model_type != "distilbert":
             inputs["token_type_ids"] = (
@@ -191,14 +190,6 @@ class NERTransformer(BaseTransformer):
             help="The number of GPUs allocated for this, it is by default 0 meaning none",
         )
-
-        parser.add_argument(
-            "--data_dir",
-            default=None,
-            type=str,
-            required=True,
-            help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
-        )
 
         parser.add_argument(
             "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
         )
@@ -4,6 +4,7 @@ import unittest
 from unittest.mock import patch
 
 import run_ner
+from transformers.testing_utils import slow
 
 
 logging.basicConfig(level=logging.INFO)
@@ -12,6 +13,7 @@ logger = logging.getLogger()
 
 
 class ExamplesTests(unittest.TestCase):
+    @slow
     def test_run_ner(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
@@ -31,3 +33,23 @@ class ExamplesTests(unittest.TestCase):
         with patch.object(sys, "argv", ["run.py"] + testargs):
             result = run_ner.main()
             self.assertLess(result["eval_loss"], 1.5)
+
+    def test_run_ner_pl(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        testargs = """
+            --model_name distilbert-base-german-cased
+            --output_dir ./tests/fixtures/tests_samples/temp_dir
+            --overwrite_output_dir
+            --data_dir ./tests/fixtures/tests_samples/GermEval
+            --labels ./tests/fixtures/tests_samples/GermEval/labels.txt
+            --max_seq_length 128
+            --num_train_epochs 6
+            --logging_steps 1
+            --do_train
+            --do_eval
+        """.split()
+        with patch.object(sys, "argv", ["run.py"] + testargs):
+            result = run_ner.main()
+            self.assertLess(result["eval_loss"], 1.5)
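The new test reuses the argv-patching idiom from test_run_ner. The idiom in isolation (toy main() standing in for run_ner.main()):

import sys
from unittest.mock import patch

def main():
    # Stand-in: the real scripts parse sys.argv[1:] with argparse.
    return {"eval_loss": float(sys.argv[1])}

with patch.object(sys, "argv", ["run.py", "0.7"]):
    assert main()["eval_loss"] == 0.7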