290 lines
12 KiB
Python
290 lines
12 KiB
Python
import os
|
|
import torch
|
|
import argparse
|
|
from functools import partial
|
|
import sys
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from sat import mpu, get_args, get_tokenizer
|
|
from sat.training.deepspeed_training import training_main
|
|
from sat.helpers import print_rank0
|
|
from utils.models import FineTuneTrainCogAgentModel
|
|
from utils.utils import llama2_text_processor, llama2_text_processor_inference, get_image_processor
|
|
|
|
def disable_untrainable_params(self):
|
|
total_trainable = 0
|
|
# enable = ['vit']
|
|
enable = ["encoder", "cross_attention", "linear_proj", 'mlp.vision', 'rotary.vision', 'eoi', 'boi', 'vit']
|
|
if self.args.use_ptuning:
|
|
enable.extend(['ptuning'])
|
|
if self.args.use_lora or self.args.use_qlora:
|
|
enable.extend(['matrix_A', 'matrix_B'])
|
|
for n, p in self.named_parameters():
|
|
flag = False
|
|
for e in enable:
|
|
if type(e) is tuple:
|
|
if e[0].lower() in n.lower() and e[1].lower() in n.lower() and 55 > int(n[:n.find('.mlp')].split('.')[-1]) > 45:
|
|
flag = True
|
|
break
|
|
else:
|
|
if e.lower() in n.lower():
|
|
flag = True
|
|
break
|
|
if not flag:
|
|
p.requires_grad_(False)
|
|
else:
|
|
total_trainable += p.numel()
|
|
if 'encoder' in n or 'vit' in n:
|
|
p.lr_scale = 0.1
|
|
print_rank0(n)
|
|
print_rank0("***** Total trainable parameters: "+str(total_trainable)+" *****")
|
|
|
|
FineTuneTrainCogAgentModel.disable_untrainable_params = disable_untrainable_params
|
|
|
|
def data_collator(examples, cross_image_processor=None):
|
|
def to_tensor(value):
|
|
"""Converts lists or numpy arrays to tensors."""
|
|
if isinstance(value, list):
|
|
return torch.tensor(value)
|
|
elif isinstance(value, np.ndarray):
|
|
return torch.from_numpy(value)
|
|
return value
|
|
|
|
def concatenate_tensors(attribute, key):
|
|
"""Concatenates tensors for a specific attribute and key."""
|
|
if attribute is None:
|
|
return torch.cat([ex[key] for ex in examples if isinstance(ex[key], torch.Tensor)])
|
|
else:
|
|
return torch.cat([ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)])
|
|
|
|
# Convert all lists and numpy arrays in examples to tensors
|
|
for example in examples:
|
|
for key, value in example.items():
|
|
example[key] = to_tensor(value)
|
|
|
|
# Extract and concatenate attributes from examples
|
|
img_args = {}
|
|
for attribute in ['vision', 'cross']:
|
|
if attribute == 'cross' and cross_image_processor is None:
|
|
continue
|
|
|
|
if attribute in examples[-1]: # Using the last example as reference
|
|
for key in examples[-1][attribute]:
|
|
tensor_key = f"{attribute}_{key}"
|
|
tensors_to_concatenate = [ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)]
|
|
if tensors_to_concatenate:
|
|
img_args[tensor_key] = concatenate_tensors(attribute, key)
|
|
else:
|
|
img_args[tensor_key] = examples[-1][attribute][key]
|
|
|
|
# Remove 'vision' and 'cross' keys from examples
|
|
for example in examples:
|
|
example.pop('vision', None)
|
|
example.pop('cross', None)
|
|
|
|
# Create model_args by concatenating tensors and copying other attributes
|
|
model_args = {key: concatenate_tensors(None, key)
|
|
if isinstance(examples[-1][key], torch.Tensor) else examples[-1][key]
|
|
for key in examples[-1]
|
|
}
|
|
|
|
# Merge img_args into model_args
|
|
model_args.update(img_args)
|
|
return model_args
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
def broadcast_auto(data_dict):
|
|
type2list = defaultdict(list)
|
|
other = []
|
|
for k in data_dict:
|
|
if type(data_dict[k]) is torch.Tensor:
|
|
type2list[data_dict[k].dtype].append(k)
|
|
else:
|
|
other.append(k)
|
|
new_data = {}
|
|
for k in type2list:
|
|
new_data.update(mpu.broadcast_data(type2list[k], data_dict, k))
|
|
for k in other:
|
|
new_data[k] = data_dict[k]
|
|
return new_data
|
|
|
|
def get_batch(data_iterator, args, timers):
|
|
# Broadcast data.
|
|
timers('data loader').start()
|
|
if data_iterator is not None:
|
|
data = next(data_iterator)
|
|
else:
|
|
data = None
|
|
timers('data loader').stop()
|
|
data_b = broadcast_auto(data)
|
|
for k in data_b:
|
|
if type(data_b[k]) is torch.Tensor and data_b[k].dtype is not torch.int32 and data_b[k].dtype is not torch.long:
|
|
if args.fp16:
|
|
data_b[k] = data_b[k].half()
|
|
elif args.bf16:
|
|
data_b[k] = data_b[k].bfloat16()
|
|
return data_b
|
|
|
|
from torch.nn import CrossEntropyLoss
|
|
import numpy as np
|
|
|
|
from sat.model.mixins import CachedAutoregressiveMixin
|
|
from sat.generation.autoregressive_sampling import filling_sequence
|
|
from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy
|
|
|
|
|
|
def chat(model, tokenizer, tokens,
|
|
max_length: int = 1800, num_beams=5, top_p=0.95, top_k=0, temperature=0.8, **kwargs):
|
|
inputs = tokens.to(model.parameters().__next__().device)[0]
|
|
seq = torch.cat(
|
|
[inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0
|
|
)
|
|
strategy = BaseStrategy(temperature=temperature, top_p=0.4, top_k=1, end_tokens=[tokenizer.eos_token_id])
|
|
# strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id],
|
|
# num_beams=num_beams, consider_end=True)
|
|
get_func = llama2_text_processor_inference.get_func(None, None, image_rope_mask=kwargs['image_rope_mask'])
|
|
output = filling_sequence(
|
|
model, seq,
|
|
batch_size=1,
|
|
strategy=strategy,
|
|
get_masks_and_position_ids=get_func,
|
|
**kwargs
|
|
)[0] # drop memory
|
|
|
|
return output
|
|
|
|
|
|
def forward_step_eval(data_iterator, model, args, timers):
|
|
def compute_metrics(eval_preds):
|
|
preds, labels, device = eval_preds
|
|
preds = preds.unsqueeze(0)
|
|
if isinstance(preds, tuple):
|
|
preds = preds[0]
|
|
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
|
if args.ignore_pad_token_for_loss:
|
|
# Replace -100 in the labels as we can't decode them.
|
|
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
|
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
|
|
|
score_dict = {
|
|
"acc": [],
|
|
"acc_w/o_case": [],
|
|
}
|
|
for pred, label in zip(decoded_preds, decoded_labels):
|
|
if args.rank == 0:
|
|
print('pred', pred, 'label', label, flush=True)
|
|
if pred == label:
|
|
score_dict['acc'].append(1.)
|
|
else:
|
|
score_dict['acc'].append(0.)
|
|
if pred.lower() == label.lower():
|
|
score_dict['acc_w/o_case'].append(1.)
|
|
else:
|
|
score_dict['acc_w/o_case'].append(0.)
|
|
|
|
|
|
for k, v in score_dict.items():
|
|
score_dict[k] = float(np.mean(v))
|
|
return score_dict
|
|
|
|
# Get the batch.
|
|
timers('batch generator').start()
|
|
data_b = get_batch(
|
|
data_iterator, args, timers)
|
|
timers('batch generator').stop()
|
|
|
|
context_len = int(data_b['context_length'][0])
|
|
tokens = data_b['input_ids'][:, :context_len]
|
|
data_b['vision_expert_mask'] = data_b['vision_expert_mask'][:, :context_len]
|
|
data_b['image_embed_mask'] = data_b['image_embed_mask'][:, :context_len]
|
|
data_b['image_rope_mask'] = data_b['image_rope_mask'][:, :context_len]
|
|
|
|
data_b.pop('input_ids')
|
|
data_b.pop('attention_mask')
|
|
data_b.pop('position_ids')
|
|
labels = data_b.pop('labels')
|
|
qid = data_b.pop('question_id')
|
|
|
|
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
|
|
outputs = chat(model, tokenizer, tokens, **data_b)[0][context_len:]
|
|
# print(outputs)
|
|
model.del_mixin('auto-regressive')
|
|
|
|
return torch.tensor(0, device=outputs.device), {k: torch.tensor(v, device=outputs.device) for k, v in
|
|
compute_metrics(
|
|
(outputs.cpu(), labels.cpu(), outputs.device)).items()}
|
|
|
|
|
|
from torch.nn import CrossEntropyLoss
|
|
def forward_step(data_iterator, model, args, timers):
|
|
"""Forward step."""
|
|
|
|
# Get the batch.
|
|
timers('batch generator').start()
|
|
data_b = get_batch(
|
|
data_iterator, args, timers)
|
|
labels = data_b.pop('labels')
|
|
timers('batch generator').stop()
|
|
logits = model(**data_b)[0]
|
|
lm_logits = logits.to(torch.float32)
|
|
# Shift so that tokens < n predict n
|
|
shift_labels = labels[..., 1:].contiguous()
|
|
shift_logits = lm_logits[..., -1-shift_labels.size(-1):-1, :].contiguous()
|
|
# Flatten the tokens
|
|
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
|
loss = loss.to(torch.float32)
|
|
|
|
return loss, {'loss': loss}
|
|
|
|
from utils.utils import ItemDataset
|
|
def create_dataset_function(image_processor, text_processor, cross_image_processor, path, args):
|
|
dataset = ItemDataset(image_processor, text_processor, args, path, cross_image_processor=cross_image_processor)
|
|
return dataset
|
|
|
|
from sat.model.finetune.lora2 import LoraMixin
|
|
from sat.model.finetune.prompt_tuning import PTuningV2Mixin
|
|
|
|
if __name__ == '__main__':
|
|
py_parser = argparse.ArgumentParser(add_help=False)
|
|
py_parser.add_argument('--max_length', type=int)
|
|
py_parser.add_argument('--ignore_pad_token_for_loss', action='store_false')
|
|
py_parser.add_argument("--version", type=str, default="chat", choices=["chat", "vqa"], help='version to interact with')
|
|
py_parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt')
|
|
py_parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
|
|
py_parser.add_argument("--vit_checkpoint_activations", action='store_true')
|
|
py_parser = FineTuneTrainCogAgentModel.add_model_specific_args(py_parser)
|
|
known, args_list = py_parser.parse_known_args()
|
|
args = get_args(args_list)
|
|
args = argparse.Namespace(**vars(args), **vars(known))
|
|
if args.use_qlora:
|
|
args.device = 'cpu'
|
|
|
|
model, args = FineTuneTrainCogAgentModel.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {})
|
|
if args.use_ptuning: # TODO: wait for SAT updating
|
|
model.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len))
|
|
|
|
if args.use_lora:
|
|
model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True)
|
|
model.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True)
|
|
elif args.use_qlora:
|
|
model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True)
|
|
|
|
if args.use_qlora and torch.cuda.is_available():
|
|
model = model.to('cuda')
|
|
from utils.utils import llama2_tokenizer
|
|
tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
|
|
image_processor = get_image_processor(args.eva_args["image_size"][0])
|
|
cross_image_processor = get_image_processor(args.cross_image_pix)
|
|
text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)
|
|
|
|
model = training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor, cross_image_processor), collate_fn=partial(data_collator, cross_image_processor=cross_image_processor), forward_step_eval=forward_step_eval)
|
|
if args.use_lora:
|
|
model.get_mixin("lora").merge_lora()
|
|
model.get_mixin("eva").vit_model.get_mixin("lora").merge_lora()
|
|
args.use_lora = False
|
|
args.save = "checkpoints/merged_lora_cogagent"
|
|
from sat.training.model_io import save_checkpoint
|
|
save_checkpoint(1, model, None, None, args) |