transformers/examples/research_projects/rag/_test_finetune_rag.py

112 lines
3.9 KiB
Python

import json
import logging
import os
import sys
from pathlib import Path
import finetune_rag
from transformers.file_utils import is_apex_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
require_ray,
require_torch_gpu,
require_torch_multi_gpu,
)
# Logging setup for the test run: ask basicConfig for DEBUG-level output and
# additionally attach a handler that writes records to stdout on the root logger.
stream_handler = logging.StreamHandler(stream=sys.stdout)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
logger.addHandler(stream_handler)
class RagFinetuneExampleTests(TestCasePlus):
    """Smoke tests that run the ``finetune_rag`` script in a subprocess.

    Each test writes a tiny dummy dataset, launches the script with a fixed set
    of hyper-parameters, and checks the ``test_avg_em`` value found in the
    ``metrics.json`` file written to the output directory.
    """

    def _create_dummy_data(self, data_dir):
        """Populate ``data_dir`` with ``{split}.source`` / ``{split}.target`` files.

        Every line of a ``source`` file holds the same question and every line
        of a ``target`` file holds the same answer; the number of lines depends
        on the split (12 for train, 2 for val and test).

        Args:
            data_dir: Directory to create (if missing) and fill with the files.
        """
        os.makedirs(data_dir, exist_ok=True)
        contents = {"source": "What is love ?", "target": "life"}
        n_lines = {"train": 12, "val": 2, "test": 2}
        for split, line_count in n_lines.items():
            for field, text in contents.items():
                file_path = os.path.join(data_dir, f"{split}.{field}")
                # Explicit encoding so the files do not depend on the platform default.
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write("\n".join([text] * line_count))

    def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
        """Run ``finetune_rag`` in a subprocess and return its saved metrics.

        Args:
            gpus: Number of GPUs to request via ``--gpus``. When it is not
                positive, the script is launched with ``--gpus=0`` on the
                ``ddp_cpu`` backend with two processes instead.
            distributed_retriever: Value forwarded to ``--distributed_retriever``.

        Returns:
            The object parsed from ``<output_dir>/metrics.json``.
        """
        tmp_dir = self.get_auto_remove_tmp_dir()
        output_dir = os.path.join(tmp_dir, "output")
        data_dir = os.path.join(tmp_dir, "data")
        self._create_dummy_data(data_dir=data_dir)

        # Only the constant, whitespace-free options go through str.split().
        fixed_args = """
            --model_name_or_path facebook/rag-sequence-base
            --model_type rag_sequence
            --do_train
            --do_predict
            --n_val -1
            --val_check_interval 1.0
            --train_batch_size 2
            --eval_batch_size 1
            --max_source_length 25
            --max_target_length 25
            --val_max_target_length 25
            --test_max_target_length 25
            --label_smoothing 0.1
            --dropout 0.1
            --attention_dropout 0.1
            --weight_decay 0.001
            --adam_epsilon 1e-08
            --max_grad_norm 0.1
            --lr_scheduler polynomial
            --learning_rate 3e-04
            --num_train_epochs 1
            --warmup_steps 4
            --gradient_accumulation_steps 1
            --distributed-port 8787
            --use_dummy_dataset 1
        """.split()

        # Paths and caller-supplied values are passed as discrete argv entries so
        # that a temporary directory containing whitespace is not split apart.
        testargs = (
            ["--data_dir", data_dir, "--output_dir", output_dir]
            + fixed_args
            + ["--distributed_retriever", distributed_retriever]
        )

        if gpus > 0:
            testargs.append(f"--gpus={gpus}")
            # Mixed precision is only requested when apex is installed.
            if is_apex_available():
                testargs.append("--fp16")
        else:
            testargs.append("--gpus=0")
            testargs.append("--distributed_backend=ddp_cpu")
            testargs.append("--num_processes=2")

        cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
        execute_subprocess_async(cmd, env=self.get_env())

        metrics_save_path = os.path.join(output_dir, "metrics.json")
        with open(metrics_save_path) as f:
            result = json.load(f)
        return result

    @require_torch_gpu
    def test_finetune_gpu(self):
        """Single-GPU run with the default retriever."""
        result = self._run_finetune(gpus=1)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_multi_gpu
    def test_finetune_multigpu(self):
        """Two-GPU run with the default retriever."""
        result = self._run_finetune(gpus=2)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_gpu
    @require_ray
    def test_finetune_gpu_ray_retrieval(self):
        """Single-GPU run with the ``ray`` distributed retriever."""
        result = self._run_finetune(gpus=1, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_multi_gpu
    @require_ray
    def test_finetune_multigpu_ray_retrieval(self):
        """Two-GPU run with the ``ray`` distributed retriever."""
        # Previously requested gpus=1, which duplicated the single-GPU ray test
        # and left the multi-GPU ray path untested.
        result = self._run_finetune(gpus=2, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)