#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
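
# Example invocation (a minimal sketch; the script filename, model id, and data paths
# below are illustrative placeholders, not values defined in this file — all flags map
# to fields of the argument dataclasses declared further down):
#
#   torchrun --nproc_per_node=8 run_clm_with_peft.py \
#       --model_name_or_path meta-llama/Llama-2-7b-hf \
#       --tokenizer_name_or_path meta-llama/Llama-2-7b-hf \
#       --task_type sft \
#       --sft_dataset_dir data/sft \
#       --data_cache_dir cache/sft \
#       --use_lora True \
#       --per_device_train_batch_size 1 \
#       --output_dir output/sft \
#       --do_train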

import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path

import datasets
import torch
from build_dataset import fault_tolerance_data_collator, build_instruction_dataset, DataCollatorForSupervisedDataset
import transformers
from transformers import GPTQConfig, deepspeed
from transformers import (
    AutoConfig,
    BitsAndBytesConfig,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import send_example_telemetry
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The tokenizer for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )

    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    sft_dataset_dir: Optional[str] = field(
        default=None, metadata={"help": "Directory containing the .jsonl files used for supervised fine-tuning."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    data_cache_dir: Optional[str] = field(default=None, metadata={"help": "Directory where the processed datasets are stored."})
    max_train_samples: Optional[int] = field(default=None, metadata={"help": "Maximum number of training samples."})
    max_seq_length: Optional[int] = field(default=1024)


@dataclass
class TrainingArguments(TrainingArguments):
    load_in_kbits: Optional[int] = field(default=16)
    report_to: Optional[str] = field(default='none')
    run_name: Optional[str] = field(default='project_name')
    use_lora: Optional[bool] = False
    task_type: Optional[str] = field(default='pt')
    # The two fields below are read in the k-bit quantization branch of main();
    # the defaults here are assumptions and can be overridden on the command line.
    double_quant: Optional[bool] = field(default=True)
    quant_type: Optional[str] = field(default='nf4')
    # cache_dir: Optional[str] = field(default='model_cache')


@dataclass
class LoraArguments:
    lora_r: Optional[int] = 64
    lora_alpha: Optional[int] = 16
    lora_dropout: Optional[float] = 0.05
    lora_target_modules: Optional[List[str]] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    lora_weight_path: Optional[str] = ""
    lora_bias: Optional[str] = field(default='none')
    modules_to_save: Optional[List[str]] = field(
        default_factory=lambda: ["embed_tokens", "lm_head"]
    )
    use_q_lora: Optional[bool] = False
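

# `get_peft_state_maybe_zero_3` below calls `maybe_zero_3`, which is neither defined nor
# imported in this file. The helper below is a minimal sketch of the usual implementation
# found in peft/FastChat-style training scripts: it gathers a parameter that may be
# partitioned by DeepSpeed ZeRO-3 and returns a CPU copy. It assumes DeepSpeed is
# installed whenever a ZeRO-3 partitioned parameter is encountered.
def maybe_zero_3(param):
    if hasattr(param, "ds_id"):
        # Parameter is partitioned by ZeRO-3; gather it before copying.
        from deepspeed import zero

        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param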


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the bias tensors that belong to LoRA-adapted modules.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):
    """Collects the state dict and dumps it to disk."""
    # check if zero3 mode enabled
    if deepspeed.is_deepspeed_zero3_enabled():
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    else:
        if trainer.args.use_lora:
            state_dict = get_peft_state_maybe_zero_3(
                trainer.model.named_parameters(), bias
            )
        else:
            state_dict = trainer.model.state_dict()
    if trainer.args.should_save and trainer.args.local_rank == 0:
        trainer._save(output_dir, state_dict=state_dict)
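
# Loading the weights saved above (a sketch, not something this script does): for a full
# fine-tune the output directory can be loaded with `AutoModelForCausalLM.from_pretrained(output_dir)`;
# for a LoRA run the saved adapter is attached to the base model with
# `peft.PeftModel.from_pretrained(base_model, output_dir)`.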


logger = logging.getLogger(__name__)


def main():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.use_q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning(
                "FSDP or ZeRO-3 is incompatible with QLoRA."
            )

    send_example_telemetry("run_pt", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,  # if training_args.local_rank in [-1, 0] else logging.WARN,
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path, use_fast=False, trust_remote_code=True)

    # Make sure the tokenizer has a padding token; fall back to eos (or id 0).
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0

    eval_dataset = None
    train_dataset = None

    if training_args.do_train:
        if training_args.task_type == "pt":
            train_dataset = datasets.load_from_disk(data_args.data_cache_dir, keep_in_memory=False)["train"]
            if data_args.max_train_samples is not None:
                max_train_samples = min(len(train_dataset), data_args.max_train_samples)
                train_dataset = train_dataset.select(range(max_train_samples))
            logger.info("Start shuffling training dataset.")
            print("train_dataset", train_dataset)
            train_dataset = train_dataset.shuffle(seed=training_args.seed)
            logger.info("Shuffled successfully!")
            logger.info(f"Num train_samples {len(train_dataset)}")
            logger.info("Training example:")
            logger.info(tokenizer.decode(train_dataset[0]['input_ids']))

        elif training_args.task_type == "sft":
            path = Path(data_args.sft_dataset_dir)
            files = [os.path.join(path, jsonl_file.name) for jsonl_file in path.glob("*.jsonl")]
            logger.info(f"SFT Training files: {' '.join(files)}")
            train_dataset = build_instruction_dataset(
                data_path=files,
                tokenizer=tokenizer,
                max_seq_length=data_args.max_seq_length,
                data_cache_dir=data_args.data_cache_dir,
                preprocessing_num_workers=data_args.preprocessing_num_workers,
            )
            logger.info(f"Num train_samples {len(train_dataset)}")
            logger.info("Training example:")
            logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
        else:
            raise ValueError(f"task_type must be either sft or pt, but found {training_args.task_type}")

    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    print(f"torch_dtype: {torch_dtype}, compute_dtype: {compute_dtype}")

    if training_args.load_in_kbits in [4, 8]:
        load_in_4bit = training_args.load_in_kbits == 4
        load_in_8bit = training_args.load_in_kbits == 8
        # Modules listed in lora_args.modules_to_save are kept in full precision when loading in 8-bit.
        if lora_args.modules_to_save is not None:
            load_in_8bit_skip_modules = lora_args.modules_to_save
        else:
            load_in_8bit_skip_modules = None
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=training_args.load_in_kbits == 4,
            load_in_8bit=training_args.load_in_kbits == 8,
            llm_int8_threshold=6.0,
            load_in_8bit_skip_modules=load_in_8bit_skip_modules,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=training_args.double_quant,
            bnb_4bit_quant_type=training_args.quant_type,  # {'fp4', 'nf4'}
        )
    else:
        load_in_4bit = False
        load_in_8bit = False
        quantization_config = None
    if quantization_config is not None:
        logger.info(f"quantization_config: {quantization_config.to_dict()}")

    # Load the model (the tokenizer was loaded above).
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True,
        quantization_config=GPTQConfig(
            bits=4, disable_exllama=True
        )
        if training_args.use_lora and lora_args.use_q_lora
        else None,
    )

    model_vocab_size = model.get_input_embeddings().weight.shape[0]
    logger.info(f"Model vocab size: {model_vocab_size}")
    logger.info(f"len(tokenizer): {len(tokenizer)}")
    if model_vocab_size != len(tokenizer):
        logger.info(f"Resize model vocab size to {len(tokenizer)}")
        model.resize_token_embeddings(len(tokenizer))

    if training_args.use_lora:
        if lora_args.use_q_lora or 'chat' in model_args.model_name_or_path.lower():
            modules_to_save = None
        else:
            modules_to_save = lora_args.modules_to_save
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
            modules_to_save=modules_to_save,  # This argument serves for adding new tokens.
        )
        if lora_args.use_q_lora:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )

        model = get_peft_model(model, lora_config)

        # Print peft trainable params
        model.print_trainable_parameters()

        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()

    model.config.use_cache = False
    data_collator = (
        fault_tolerance_data_collator
        if training_args.task_type == "pt"
        else DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        metrics = train_result.metrics
        metrics["train_samples"] = len(train_dataset)

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(eval_dataset)
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    main()