[lightning_base] fix s2s logging, only make train_loader once (#6404)

Sam Shleifer 2020-08-16 22:49:41 -04:00 committed by GitHub
parent 72add6c98f
commit 84c265ffcc
6 changed files with 47 additions and 72 deletions

View File

@@ -150,15 +150,20 @@ class BaseTransformer(pl.LightningModule):
def test_epoch_end(self, outputs):
return self.validation_end(outputs)
def setup(self, step):
train_batch_size = self.hparams.train_batch_size
dataloader = self.get_dataloader("train", train_batch_size)
self.train_loader = dataloader
self.total_steps = (
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
@property
def total_steps(self) -> int:
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
dataset_size = len(self.train_loader.dataset)
return (dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, mode):
if mode == "fit":
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
def get_dataloader(self, type_path, batch_size, shuffle=False):
raise NotImplementedError("You must implement this for your task")
def train_dataloader(self):
return self.train_loader
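For reference, here is a minimal standalone sketch of the arithmetic behind the new `total_steps` property, using hypothetical hyperparameter values: the dataset size is divided by the effective batch size (per-device batch size times gradient accumulation times device count) and scaled by the epoch count.

```python
# Sketch of the total_steps arithmetic with hypothetical values (not part of the commit).
dataset_size = 10_000        # len(self.train_loader.dataset)
train_batch_size = 8         # hparams.train_batch_size
accumulate_grad_batches = 2  # hparams.accumulate_grad_batches
gpus = 4                     # hparams.gpus
max_epochs = 3               # hparams.max_epochs

num_devices = max(1, gpus)
effective_batch_size = train_batch_size * accumulate_grad_batches * num_devices  # 64
total_steps = (dataset_size / effective_batch_size) * max_epochs                 # 468.75
print(total_steps)
```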
@@ -304,6 +309,13 @@ def add_generic_args(parser, root_dir) -> None:
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
)
def generic_train(
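Since `--data_dir` now lives in the shared `add_generic_args` parser, the per-script copies of the same argument are removed further down in the GLUE, NER, and summarization files. A quick sketch of inspecting the shared arguments, assuming `lightning_base` is importable from the examples directory:

```python
import argparse

from lightning_base import add_generic_args  # assumes examples/ is on sys.path

# --data_dir (like --seed) now comes from the shared generic parser,
# so task scripts no longer need to define their own copies.
parser = argparse.ArgumentParser()
add_generic_args(parser, root_dir=".")
print(parser.format_help())  # lists --data_dir alongside the other generic arguments
```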

View File

@@ -10,14 +10,7 @@ from torch import nn
from torch.nn import functional as F
from lightning_base import generic_train
from transformers import (
AdamW,
BartConfig,
BartForConditionalGeneration,
MBartTokenizer,
T5Config,
T5ForConditionalGeneration,
)
from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
try:
@@ -158,24 +151,6 @@ class BartSummarizationDistiller(SummarizationModule):
)
return loss_ce, s_logits_slct, t_logits_slct
def configure_optimizers(self):
"Prepare optimizer and schedule (linear warmup and decay)"
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": self.hparams.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
self.opt = optimizer
return [optimizer]
@staticmethod
def add_model_specific_args(parser, root_dir):
SummarizationModule.add_model_specific_args(parser, root_dir)
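The deleted `configure_optimizers` duplicated optimizer setup that `BaseTransformer` already provides. For readers new to the pattern it used, here is a hedged, self-contained sketch of the usual no-weight-decay grouping; the stand-in module and hyperparameter values are illustrative, and torch's `AdamW` stands in for the `transformers` one used in the deleted code:

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Illustrative stand-in module with HF-style parameter names
# ("LayerNorm.weight", "dense.bias", ...); any nn.Module works.
model = nn.Sequential(OrderedDict([
    ("dense", nn.Linear(16, 16)),
    ("LayerNorm", nn.LayerNorm(16)),
]))

no_decay = ["bias", "LayerNorm.weight"]
grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,  # hypothetical hparams.weight_decay
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(grouped_parameters, lr=3e-5, eps=1e-8)  # hypothetical lr/eps
```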

View File

@@ -3,7 +3,6 @@ import glob
import logging
import os
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
@@ -14,7 +13,7 @@ import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
try:
@@ -252,17 +251,6 @@ class SummarizationModule(BaseTransformer):
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
)
if max(scheduler.get_last_lr()) > 0:
warnings.warn("All learning rates are 0")
self.lr_scheduler = scheduler
return dataloader
def val_dataloader(self) -> DataLoader:
@@ -303,12 +291,6 @@ class SummarizationModule(BaseTransformer):
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--data_dir",
type=str,
required=True,
help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
)
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)
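With the scheduler construction removed from `train_dataloader`, the linear warmup schedule is meant to be driven from the base class using `total_steps`. As a reference, a minimal sketch of how `get_linear_schedule_with_warmup` is typically wired to an optimizer (the dummy parameter, warmup, and step counts are hypothetical):

```python
import torch
from transformers import get_linear_schedule_with_warmup

# Hypothetical optimizer over a dummy parameter, just to exercise the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-5)

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=500, num_training_steps=10_000
)

for _ in range(3):  # one scheduler.step() after each optimizer.step()
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # learning rate a few steps into warmup
```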

View File

@@ -75,7 +75,7 @@ class GLUETransformer(BaseTransformer):
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader:
def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader:
"Load datasets. Called after prepare data."
# We test on dev set to compare to benchmarks without having to submit to GLUE server
@@ -161,13 +161,6 @@
type=int,
help="The number of GPUs allocated for this, it is by default 0 meaning none",
)
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
)
parser.add_argument(
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"

View File

@@ -104,8 +104,7 @@ class NERTransformer(BaseTransformer):
)
def validation_step(self, batch, batch_nb):
"Compute validation"
"""Compute validation""" ""
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
if self.config.model_type != "distilbert":
inputs["token_type_ids"] = (
@@ -191,14 +190,6 @@
help="The number of GPUs allocated for this, it is by default 0 meaning none",
)
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
)
parser.add_argument(
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
)

View File

@@ -4,6 +4,7 @@ import unittest
from unittest.mock import patch
import run_ner
from transformers.testing_utils import slow
logging.basicConfig(level=logging.INFO)
@@ -12,6 +13,7 @@ logger = logging.getLogger()
class ExamplesTests(unittest.TestCase):
@slow
def test_run_ner(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@@ -31,3 +33,23 @@ class ExamplesTests(unittest.TestCase):
with patch.object(sys, "argv", ["run.py"] + testargs):
result = run_ner.main()
self.assertLess(result["eval_loss"], 1.5)
def test_run_ner_pl(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
testargs = """
--model_name distilbert-base-german-cased
--output_dir ./tests/fixtures/tests_samples/temp_dir
--overwrite_output_dir
--data_dir ./tests/fixtures/tests_samples/GermEval
--labels ./tests/fixtures/tests_samples/GermEval/labels.txt
--max_seq_length 128
--num_train_epochs 6
--logging_steps 1
--do_train
--do_eval
""".split()
with patch.object(sys, "argv", ["run.py"] + testargs):
result = run_ner.main()
self.assertLess(result["eval_loss"], 1.5)
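To run only the new Lightning NER test directly, something like the sketch below works; the module name `test_ner_examples` is an assumption about where `ExamplesTests` lives, so adjust the import to the actual test file:

```python
# Hypothetical direct invocation of the new test; adjust the module name
# to wherever ExamplesTests is actually defined.
import unittest

from test_ner_examples import ExamplesTests

suite = unittest.TestSuite([ExamplesTests("test_run_ner_pl")])
unittest.TextTestRunner(verbosity=2).run(suite)
```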