Skywork/train/train.py

455 lines
17 KiB
Python

#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path
import datasets
import torch
from build_dataset import fault_tolerance_data_collator, build_instruction_dataset, DataCollatorForSupervisedDataset
import transformers
from transformers import Trainer, GPTQConfig, deepspeed
from transformers import (
CONFIG_MAPPING,
AutoConfig,
AutoModelForCausalLM,
BitsAndBytesConfig,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import sys
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
)
},
)
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
)
},
)
config_overrides: Optional[str] = field(
default=None,
metadata={
"help": (
"Override some existing default config settings when a model is trained from scratch. Example: "
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
)
},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
)
},
)
torch_dtype: Optional[str] = field(
default=None,
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
),
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
raise ValueError(
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
)
@dataclass
class DataArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
sft_dataset_dir: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
data_cache_dir: Optional[str] = field(default=None, metadata={"help": "The datasets processed stored"})
max_train_samples: Optional[int] = field(default=None, metadata={"help": "maximum number of train samples"})
max_seq_length: Optional[int] = field(default=1024)
@dataclass
class TrainingArguments(TrainingArguments):
load_in_kbits: Optional[int] = field(default=16)
report_to: Optional[str] = field(default='none')
run_name: Optional[str] = field(default='project_name')
use_lora: Optional[bool] = False
task_type: Optional[str] = field(default='pt')
# cache_dir: Optional[str] = field(default='model_cache')
@dataclass
class LoraArguments:
lora_r: Optional[int] = 64
lora_alpha: Optional[int] = 16
lora_dropout: Optional[float] = 0.05
lora_target_modules: Optional[List[str]] = field(
default_factory=lambda: ["q_proj", "v_proj"]
)
lora_weight_path: Optional[str] = ""
lora_bias: Optional[str] = field(default='none')
modules_to_save: Optional[str] = field(
default_factory=lambda: ["embed_tokens", "lm_head"]
)
use_q_lora: Optional[bool] = False
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias:
if bias_name in lora_bias_names:
to_return[bias_name] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
return to_return
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):
"""Collects the state dict and dump to disk."""
# check if zero3 mode enabled
if deepspeed.is_deepspeed_zero3_enabled():
state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
else:
if trainer.args.use_lora:
state_dict = get_peft_state_maybe_zero_3(
trainer.model.named_parameters(), bias
)
else:
state_dict = trainer.model.state_dict()
if trainer.args.should_save and trainer.args.local_rank == 0:
trainer._save(output_dir, state_dict=state_dict)
logger = logging.getLogger(__name__)
def main():
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments, LoraArguments)
)
(
model_args,
data_args,
training_args,
lora_args,
) = parser.parse_args_into_dataclasses()
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if lora_args.use_q_lora:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
logging.warning(
"FSDP or ZeRO3 are not incompatible with QLoRA."
)
send_example_telemetry("run_pt", model_args, data_args)
# Setup logging
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO, # if training_args.local_rank in [-1, 0] else logging.WARN,
handlers=[logging.StreamHandler(sys.stdout)],)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(training_args.seed)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path, use_fast=False, trust_remote_code=True)
if tokenizer.pad_token_id is None:
if tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.pad_token_id = 0
eval_dataset=None
train_dataset = None
if training_args.do_train:
if training_args.task_type == "pt":
train_dataset = datasets.load_from_disk(data_args.data_cache_dir, keep_in_memory=False)["train"]
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
logger.info(f"Start shuffling training dataset.")
print("train_dataset", train_dataset)
train_dataset = train_dataset.shuffle(seed=training_args.seed)
logger.info(f"shuffle successively!")
logger.info(f"Num train_samples {len(train_dataset)}")
logger.info("Training example:")
logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
elif training_args.task_type == "sft":
path = Path(data_args.sft_dataset_dir)
files = [os.path.join(path, jsonl_file.name) for jsonl_file in path.glob("*.jsonl")]
logger.info(f"SFT Training files: {' '.join(files)}")
train_dataset = build_instruction_dataset(
data_path=files,
tokenizer=tokenizer,
max_seq_length=data_args.max_seq_length,
data_cache_dir = data_args.data_cache_dir,
preprocessing_num_workers = data_args.preprocessing_num_workers)
logger.info(f"Num train_samples {len(train_dataset)}")
logger.info("Training example:")
logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
else:
raise ValueError(f"task_type must be either sft or pt, but found{training_args.task_type}")
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
print(f"torch_dtype: {torch_dtype}, compute_dtype: {compute_dtype}")
if training_args.load_in_kbits in [4, 8]:
load_in_4bit = training_args.load_in_kbits == 4
load_in_8bit = training_args.load_in_kbits == 8
if training_args.modules_to_save is not None:
load_in_8bit_skip_modules = training_args.modules_to_save.split(',')
else:
load_in_8bit_skip_modules = None
quantization_config = BitsAndBytesConfig(
load_in_4bit=training_args.load_in_kbits == 4,
load_in_8bit=training_args.load_in_kbits == 8,
llm_int8_threshold=6.0,
load_in_8bit_skip_modules=load_in_8bit_skip_modules,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
)
else:
load_in_4bit = False
load_in_8bit = False
quantization_config = None
if quantization_config is not None:
logger.info(f"quantization_config:{quantization_config.to_dict()}")
# Load model and tokenizer
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=model_args.cache_dir,
trust_remote_code=True,
quantization_config=GPTQConfig(
bits=4, disable_exllama=True
)
if training_args.use_lora and lora_args.use_q_lora
else None,
)
model_vocab_size = model.get_input_embeddings().weight.shape[0]
logger.info(f"Model vocab size: {model_vocab_size}")
logger.info(f"len(tokenizer):{len(tokenizer)}")
if model_vocab_size != len(tokenizer):
logger.info(f"Resize model vocab size to {len(tokenizer)}")
model.resize_token_embeddings(len(tokenizer))
if training_args.use_lora:
if lora_args.use_q_lora or 'chat' in model_args.model_name_or_path.lower():
modules_to_save = None
else:
modules_to_save = lora_args.modules_to_save
lora_config = LoraConfig(
r=lora_args.lora_r,
lora_alpha=lora_args.lora_alpha,
target_modules=lora_args.lora_target_modules,
lora_dropout=lora_args.lora_dropout,
bias=lora_args.lora_bias,
task_type="CAUSAL_LM",
modules_to_save=modules_to_save # This argument serves for adding new tokens.
)
if lora_args.use_q_lora:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=training_args.gradient_checkpointing
)
model = get_peft_model(model, lora_config)
# Print peft trainable params
model.print_trainable_parameters()
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
model.config.use_cache = False
data_collator = fault_tolerance_data_collator if training_args.task_type == "pt" else DataCollatorForSupervisedDataset(tokenizer=tokenizer)
# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=data_collator
)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)
# Evaluation
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
metrics["eval_samples"] =len(eval_dataset)
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if __name__ == "__main__":
main()