650 lines
26 KiB
Python
650 lines
26 KiB
Python
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.distributed as torch_distrib
|
|
from pytorch_lightning.plugins.training_type import DDPPlugin
|
|
from torch.utils.data import DataLoader
|
|
|
|
from transformers import (
|
|
AutoConfig,
|
|
AutoTokenizer,
|
|
BartForConditionalGeneration,
|
|
BatchEncoding,
|
|
RagConfig,
|
|
RagSequenceForGeneration,
|
|
RagTokenForGeneration,
|
|
RagTokenizer,
|
|
T5ForConditionalGeneration,
|
|
)
|
|
from transformers import logging as transformers_logging
|
|
from transformers.integrations import is_ray_available
|
|
|
|
|
|
if is_ray_available():
|
|
import ray
|
|
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
|
|
|
|
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
|
get_checkpoint_callback,
|
|
get_early_stopping_callback,
|
|
Seq2SeqLoggingCallback,
|
|
)
|
|
|
|
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
|
from utils_rag import ( # noqa: E402 # isort:skip
|
|
calculate_exact_match,
|
|
flatten_list,
|
|
get_git_info,
|
|
is_rag_model,
|
|
lmap,
|
|
pickle_save,
|
|
save_git_info,
|
|
save_json,
|
|
set_extra_model_params,
|
|
Seq2SeqDataset,
|
|
)
|
|
|
|
# need the parent dir module
|
|
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
|
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
|
|
|
|
|
|
# Emit INFO-level log records from this script (root logger) and raise the
# transformers library's own verbosity to INFO as well.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

transformers_logging.set_verbosity_info()
|
|
|
|
|
|
class AttrDict(dict):
    """A ``dict`` whose entries are also reachable as attributes.

    The instance ``__dict__`` is pointed at the mapping itself, so ``d.key`` and
    ``d["key"]`` read and write the same storage. In this script it wraps
    hyper-parameters that arrive as a plain ``dict``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the dict contents.
        self.__dict__ = self
|
|
|
|
|
|
class CustomDDP(DDPPlugin):
    """DDP plugin that also initializes the RAG retriever after the process group is set up."""

    def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
        """Initialize torch.distributed (if needed), then the module's retriever.

        Args:
            global_rank: Rank of this process; taken from the cluster environment when ``None``.
            world_size: Number of processes; taken from the cluster environment when ``None``.
        """
        lightning_module = self.model
        if global_rank is None:
            global_rank = self.cluster_environment.global_rank()
        if world_size is None:
            world_size = self.cluster_environment.world_size()
        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

        # Only create the default process group once per process.
        if not torch.distributed.is_initialized():
            logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

        # Plain BART/T5 modules have no retriever to set up.
        if not lightning_module.is_rag_model:
            return

        self.distributed_port = lightning_module.hparams.distributed_port
        backend = lightning_module.distributed_retriever
        if backend == "pytorch":
            # Every rank takes part in the torch.distributed-based retriever.
            lightning_module.model.rag.retriever.init_retrieval(self.distributed_port)
        elif backend == "ray" and global_rank == 0:
            # The Ray retriever is initialized exactly once, from global rank 0.
            lightning_module.model.rag.retriever.init_retrieval()
|
|
|
|
|
|
class GenerativeQAModule(BaseTransformer):
    """PyTorch Lightning module for fine-tuning RAG, BART or T5 models on generative QA.

    The concrete model class is chosen from ``hparams.model_type``:
    ``"rag_sequence"`` -> ``RagSequenceForGeneration``, ``"rag_token"`` ->
    ``RagTokenForGeneration``, ``"bart"`` -> ``BartForConditionalGeneration``, and any
    other value -> ``T5ForConditionalGeneration``. Validation and test generate answers
    and score them with exact match (``"em"``), which is also ``val_metric``.
    """

    mode = "generative_qa"
    loss_names = ["loss"]
    metric_names = ["em"]
    val_metric = "em"

    def __init__(self, hparams, **kwargs):
        """Build config, retriever (RAG only), model and tokenizer from ``hparams``.

        Args:
            hparams: ``argparse.Namespace`` from this script's parser, or a plain ``dict``
                with the same fields (e.g. when restored from a Lightning checkpoint).
            **kwargs: Ignored.

        Raises:
            ValueError: if a RAG model is requested with an unknown ``distributed_retriever``.
        """
        # when loading from a pytorch lightning checkpoint, hparams are passed as dict
        if isinstance(hparams, dict):
            hparams = AttrDict(hparams)
        if hparams.model_type == "rag_sequence":
            self.model_class = RagSequenceForGeneration
        elif hparams.model_type == "rag_token":
            self.model_class = RagTokenForGeneration
        elif hparams.model_type == "bart":
            self.model_class = BartForConditionalGeneration
        else:
            self.model_class = T5ForConditionalGeneration
        self.is_rag_model = is_rag_model(hparams.model_type)

        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

        # set retriever parameters
        # getattr: non-RAG (BART/T5) configs presumably do not define these attributes, so a
        # plain attribute read would raise AttributeError when the CLI flag is omitted.
        config.index_name = hparams.index_name or getattr(config, "index_name", None)
        config.passages_path = hparams.passages_path or getattr(config, "passages_path", None)
        config.index_path = hparams.index_path or getattr(config, "index_path", None)
        config.use_dummy_dataset = hparams.use_dummy_dataset

        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
        if self.is_rag_model:
            if hparams.prefix is not None:
                config.generator.prefix = hparams.prefix
            config.label_smoothing = hparams.label_smoothing
            hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
            if hparams.distributed_retriever == "pytorch":
                retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
            elif hparams.distributed_retriever == "ray":
                # The Ray retriever needs the handles to the retriever actors.
                retriever = RagRayDistributedRetriever.from_pretrained(
                    hparams.model_name_or_path, hparams.actor_handles, config=config
                )
            else:
                # Previously `retriever` was left unbound here, failing later with a NameError.
                raise ValueError(
                    f"Unsupported distributed_retriever {hparams.distributed_retriever!r}; "
                    "expected 'pytorch' or 'ray'."
                )
            model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
            prefix = config.question_encoder.prefix
        else:
            if hparams.prefix is not None:
                config.prefix = hparams.prefix
            hparams, config = set_extra_model_params(extra_model_params, hparams, config)
            model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
            prefix = config.prefix

        tokenizer = (
            RagTokenizer.from_pretrained(hparams.model_name_or_path)
            if self.is_rag_model
            else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
        )

        super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)

        save_git_info(self.hparams.output_dir)
        self.output_dir = Path(self.hparams.output_dir)
        self.metrics_save_path = self.output_dir / "metrics.json"
        self.hparams_save_path = self.output_dir / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        # Maps split prefix ("val"/"test") -> list of metric dicts, one per evaluation.
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = {
            "data_dir": self.hparams.data_dir,
            "max_source_length": self.hparams.max_source_length,
            "prefix": prefix or "",
        }
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        # A negative count means "use every example" (None is passed to the dataset).
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.distributed_port = self.hparams.distributed_port

        # For single GPU training, init_ddp_connection is not called, so the retriever is
        # initialized here. Only RAG models own a retriever (CustomDDP applies the same
        # is_rag_model guard); BART/T5 previously crashed here with AttributeError.
        # `--gpus` defaults to None when omitted; treat that as zero GPUs instead of a TypeError.
        if self.is_rag_model and (hparams.gpus or 0) <= 1:
            if hparams.distributed_retriever == "ray":
                self.model.retriever.init_retrieval()
            elif hparams.distributed_retriever == "pytorch":
                self.model.retriever.init_retrieval(self.distributed_port)

        self.distributed_retriever = hparams.distributed_retriever

    def forward(self, input_ids, **kwargs):
        """Run the wrapped model; keyword arguments are passed through unchanged."""
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: torch.Tensor) -> List[str]:
        """Decode a batch of token-id sequences into whitespace-stripped strings.

        Special tokens are skipped and tokenization spaces cleaned up.
        """
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        """Compute the training loss for one batch.

        Decoder inputs and labels are derived from ``batch["decoder_input_ids"]`` in the
        layout each model family expects, then a forward pass is run.

        Returns:
            A 1-tuple ``(loss,)``, ordered like ``self.loss_names``.

        Raises:
            ValueError: if a RAG model's generator is neither T5 nor BART.
        """
        source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]

        rag_kwargs = {}
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(target_ids)
            lm_labels = target_ids
        elif isinstance(self.model, BartForConditionalGeneration):
            # Teacher forcing: predict token t+1 from tokens up to t.
            decoder_input_ids = target_ids[:, :-1].contiguous()
            lm_labels = target_ids[:, 1:].clone()
        else:
            assert self.is_rag_model
            generator = self.model.rag.generator
            if isinstance(generator, T5ForConditionalGeneration):
                decoder_start_token_id = generator.config.decoder_start_token_id
                batch_size, target_len = target_ids.shape[0], target_ids.shape[1]
                # Prepend the start token while the sequence is shorter than the training target
                # length; otherwise shift right (dropping the last token) so it does not grow.
                # This compares the sequence dimension; the old code compared the batch size.
                if target_len < self.target_lens["train"]:
                    start_tokens = torch.tensor([[decoder_start_token_id]] * batch_size).to(target_ids)
                    decoder_input_ids = torch.cat([start_tokens, target_ids], dim=1)
                else:
                    decoder_input_ids = generator._shift_right(target_ids)
            elif isinstance(generator, BartForConditionalGeneration):
                decoder_input_ids = target_ids
            else:
                # Previously `decoder_input_ids` was left unbound here (UnboundLocalError).
                raise ValueError(
                    f"Unsupported RAG generator {type(generator).__name__}; expected "
                    "T5ForConditionalGeneration or BartForConditionalGeneration."
                )
            lm_labels = decoder_input_ids
            rag_kwargs["reduce_loss"] = True

        outputs = self(
            source_ids,
            attention_mask=source_mask,
            decoder_input_ids=decoder_input_ids,
            use_cache=False,
            labels=lm_labels,
            **rag_kwargs,
        )

        loss = outputs["loss"]
        return (loss,)

    @property
    def pad(self) -> int:
        """Not supported by this module; always raises NotImplementedError."""
        raise NotImplementedError("pad not implemented")

    def training_step(self, batch, batch_idx) -> Dict:
        """Return the training loss plus a ``log`` dict with the loss and tokens-per-batch."""
        loss_tensors = self._step(batch)

        logs = {name: loss.detach() for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch: non-pad source tokens + non-pad target tokens. A RagTokenizer
        # wraps separate question-encoder and generator tokenizers with their own pad ids.
        tgt_pad_token_id = (
            self.tokenizer.generator.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        src_pad_token_id = (
            self.tokenizer.question_encoder.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        logs["tpb"] = (
            batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
        )

        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        """Generate answers for a validation batch and compute loss and generation metrics."""
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        """Aggregate per-batch ``_generative_step`` outputs into epoch-level metrics.

        Losses and generation metrics are averaged over batches. When torch.distributed
        is initialized, the ``val_metric`` value is also averaged across processes.
        Metrics are keyed ``f"{prefix}_avg_{name}"`` and appended to ``metrics.json``.
        """
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        gen_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
        gen_metrics.update({k: v.item() for k, v in losses.items()})

        # fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424
        if dist.is_initialized():
            dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
            metrics_tensor = metrics_tensor / dist.get_world_size()
            gen_metrics.update({self.val_metric: metrics_tensor.item()})

        losses.update(gen_metrics)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metrics_tensor}

    def save_metrics(self, latest_metrics, type_path) -> None:
        """Append ``latest_metrics`` under ``type_path`` and rewrite the metrics JSON file."""
        self.metrics[type_path].append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def calc_generative_metrics(self, preds, target) -> Dict:
        """Score predicted answers against references with exact match."""
        return calculate_exact_match(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        """Generate answers for ``batch`` and return loss, metrics, predictions and targets.

        ``gen_time`` is wall-clock generation time per example; ``gen_len`` is the mean
        length of the generated id sequences. Generation uses ``target_lens["val"]`` as
        ``max_length`` for both validation and test.
        """
        start_time = time.time()
        batch = BatchEncoding(batch).to(device=self.model.device)
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            do_deduplication=False,  # rag specific parameter
            use_cache=True,
            min_length=1,
            max_length=self.target_lens["val"],
        )

        gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
        loss_tensors = self._step(batch)
        base_metrics = dict(zip(self.loss_names, loss_tensors))
        gen_metrics: Dict = self.calc_generative_metrics(preds, target)

        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
        return base_metrics

    def test_step(self, batch, batch_idx):
        """Same as ``validation_step``, for the test split."""
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        """Aggregate test outputs, reporting metrics with the ``test`` prefix."""
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        """Build the ``Seq2SeqDataset`` for split ``type_path`` ("train", "val" or "test")."""
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = Seq2SeqDataset(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Wrap the split's dataset in a DataLoader that uses the dataset's own collate_fn."""
        dataset = self.get_dataset(type_path)

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader using ``hparams.train_batch_size``."""
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Validation loader using ``hparams.eval_batch_size``."""
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        """Test loader using ``hparams.eval_batch_size``."""
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Also save model and tokenizer via ``save_pretrained`` to ``output_dir/checkpoint{step_count}``.

        Runs on rank zero only.
        """
        save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register base-transformer, generic and seq2seq/RAG training options on ``parser``."""
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=128,
            type=int,
            help=(
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--max_target_length",
            default=25,
            type=int,
            help=(
                "The maximum total target sequence length after tokenization for training. Sequences "
                "longer than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--val_max_target_length",
            default=25,
            type=int,
            help=(
                "The maximum total target sequence length after tokenization for validation (also the "
                "generation max_length). Sequences longer than this will be truncated, sequences shorter "
                "will be padded."
            ),
        )
        parser.add_argument(
            "--test_max_target_length",
            default=25,
            type=int,
            help=(
                "The maximum total target sequence length after tokenization for the test split. Sequences "
                "longer than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument(
            "--prefix",
            type=str,
            default=None,
            help="Prefix added at the beginning of each text, typically used with T5-based models.",
        )
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help=(
                "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
                " val_check_interval will affect it."
            ),
        )
        parser.add_argument(
            "--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
        )
        parser.add_argument(
            "--model_type",
            choices=["rag_sequence", "rag_token", "bart", "t5"],
            type=str,
            help=(
                "Type of model to fine-tune: 'rag_sequence' or 'rag_token' for RAG, or 'bart' / 't5' for a"
                " plain seq2seq model."
            ),
        )
        return parser

    @staticmethod
    def add_retriever_specific_args(parser):
        """Register retrieval-index and distributed-retriever options on ``parser``."""

        def _str_to_bool(value):
            # argparse's `type=bool` turns any non-empty string (including "False") into True.
            if isinstance(value, bool):
                return value
            lowered = value.strip().lower()
            if lowered in ("1", "true", "t", "yes", "y"):
                return True
            if lowered in ("0", "false", "f", "no", "n", ""):
                return False
            raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")

        parser.add_argument(
            "--index_name",
            type=str,
            default=None,
            help=(
                "Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom'"
                " for a local index, or 'legacy' for the original one."
            ),
        )
        parser.add_argument(
            "--passages_path",
            type=str,
            default=None,
            help=(
                "Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever"
                " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
            ),
        )
        parser.add_argument(
            "--index_path",
            type=str,
            default=None,
            help=(
                "Path to the faiss index for custom index. More info about custom indexes in the RagRetriever"
                " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
            ),
        )
        parser.add_argument(
            "--distributed_retriever",
            choices=["ray", "pytorch"],
            type=str,
            default="pytorch",
            help=(
                "What implementation to use for distributed retriever? If "
                "pytorch is selected, the index is loaded on training "
                "worker 0, and torch.distributed is used to handle "
                "communication between training worker 0, and the other "
                "training workers. If ray is selected, the Ray library is "
                "used to load the index on separate processes, "
                "and Ray handles the communication between the training "
                "workers and the retrieval actors."
            ),
        )
        parser.add_argument(
            "--use_dummy_dataset",
            type=_str_to_bool,
            default=False,
            help=(
                "Whether to use the dummy version of the dataset index. More info about custom indexes in the"
                " RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
            ),
        )
        return parser

    @staticmethod
    def add_ray_specific_args(parser):
        """Register Ray cluster options (used only with ``--distributed_retriever ray``)."""
        # Ray cluster address.
        parser.add_argument(
            "--ray-address",
            default="auto",
            type=str,
            help=(
                "The address of the Ray cluster to connect to. If not "
                "specified, Ray will attempt to automatically detect the "
                "cluster. Has no effect if pytorch is used as the distributed "
                "retriever."
            ),
        )
        parser.add_argument(
            "--num_retrieval_workers",
            type=int,
            default=1,
            help=(
                "The number of retrieval actors to use when Ray is selected "
                "for the distributed retriever. Has no effect when "
                "distributed_retriever is set to pytorch."
            ),
        )
        return parser
|
|
|
|
|
|
def main(args=None, model=None) -> GenerativeQAModule:
    """Train (and optionally test) a ``GenerativeQAModule``.

    Args:
        args: Parsed command-line namespace. When ``None``, arguments are parsed from the
            command line with the same options as the ``__main__`` entry point.
        model: Optional pre-built module; when ``None`` one is constructed from ``args``.

    Returns:
        The module. If ``args.do_predict`` is set, ``trainer.test()`` runs before returning.

    Raises:
        RuntimeError: if the Ray retriever is requested for multi-GPU training but Ray is
            not installed.
    """
    if args is None:
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
        parser = GenerativeQAModule.add_retriever_specific_args(parser)
        # These were previously missing here, so main() without explicit args crashed on
        # `args.profile` and could not accept the Ray options.
        parser = GenerativeQAModule.add_ray_specific_args(parser)
        parser.add_argument(
            "--profile",
            action="store_true",
            help="If True, use pytorch_lightning.profiler.AdvancedProfiler to profile the Trainer.",
        )
        args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # `--gpus` defaults to None when omitted; treat that as zero GPUs instead of a TypeError.
    num_gpus = args.gpus or 0

    named_actors = []
    if args.distributed_retriever == "ray" and num_gpus > 1:
        if not is_ray_available():
            raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
        # Connect to an existing Ray cluster.
        try:
            ray.init(address=args.ray_address, namespace="rag")
        except (ConnectionError, ValueError):
            logger.warning(
                "Connection to Ray cluster failed. Make sure a Ray "
                "cluster is running by either using Ray's cluster "
                "launcher (`ray up`) or by manually starting Ray on "
                "each node via `ray start --head` for the head node "
                "and `ray start --address='<ip address>:6379'` for "
                "additional nodes. See "
                "https://docs.ray.io/en/master/cluster/index.html "
                "for more info."
            )
            raise

        # A missing rank variable counts as rank 0. Reading them with a default also fixes a
        # KeyError in the logging below when only one of the two variables is set.
        local_rank = os.environ.get("LOCAL_RANK", "0")
        node_rank = os.environ.get("NODE_RANK", "0")
        # Create Ray actors only for rank 0; other ranks look them up by name.
        if int(local_rank) == 0 and int(node_rank) == 0:
            remote_cls = ray.remote(RayRetriever)
            named_actors = [
                remote_cls.options(name="retrieval_worker_{}".format(i)).remote()
                for i in range(args.num_retrieval_workers)
            ]
        else:
            logger.info("Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(node_rank, local_rank))
            named_actors = [ray.get_actor("retrieval_worker_{}".format(i)) for i in range(args.num_retrieval_workers)]
    args.actor_handles = named_actors

    if model is None:
        model: GenerativeQAModule = GenerativeQAModule(args)

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        training_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        training_logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    es_callback = (
        get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
        if args.early_stopping_patience >= 0
        else False
    )

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
        early_stopping_callback=es_callback,
        logger=training_logger,
        custom_ddp_plugin=CustomDDP() if num_gpus > 1 else None,
        # getattr: callers may pass a namespace built without the `--profile` flag.
        profiler=pl.profiler.AdvancedProfiler() if getattr(args, "profile", False) else None,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")

    if not args.do_predict:
        return model

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
|
|
|
|
|
|
if __name__ == "__main__":
    # Assemble the full CLI: Lightning Trainer flags, then model/data, retriever and Ray options.
    cli_parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser())
    cli_parser = GenerativeQAModule.add_model_specific_args(cli_parser, os.getcwd())
    cli_parser = GenerativeQAModule.add_retriever_specific_args(cli_parser)
    cli_parser = GenerativeQAModule.add_ray_specific_args(cli_parser)

    # Optional PyTorch Lightning AdvancedProfiler switch.
    cli_parser.add_argument(
        "--profile",
        action="store_true",
        help="If True, use pytorch_lightning.profiler.AdvancedProfiler to profile the Trainer.",
    )

    main(cli_parser.parse_args())
|