import gradio as gr
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from PIL import Image
import base64
import json
import requests
import hashlib
import torch
import time
import re
import argparse

from sat.model.mixins import CachedAutoregressiveMixin
from sat.mpu import get_model_parallel_world_size
from sat.model import AutoModel

from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor, parse_response
from utils.models import CogAgentModel, CogVLMModel


DESCRIPTION = '''<h1 style='text-align: center'> <a href="https://github.com/THUDM/CogVLM">CogVLM / CogAgent</a> </h1>'''

NOTES = '<h3> This app is adapted from <a href="https://github.com/THUDM/CogVLM">https://github.com/THUDM/CogVLM</a>. We recommend checking out the repo for details about our models, CogVLM & CogAgent. </h3>'

MAINTENANCE_NOTICE1 = 'Hint 1: If the app reports "Something went wrong, connection error out", please turn off your proxy and retry.<br>Hint 2: If you upload a large image (e.g. 10 MB), uploading and processing may take a while. Please be patient.'

AGENT_NOTICE = 'Hint 1: To use the <strong>Agent</strong> function, please use the <a href="https://github.com/THUDM/CogVLM/blob/main/utils/utils/template.py#L761">prompts for agents</a>.'

GROUNDING_NOTICE = 'Hint 2: To use the <strong>Grounding</strong> function, please use the <a href="https://github.com/THUDM/CogVLM/blob/main/utils/utils/template.py#L344">prompts for grounding</a>.'


default_chatbox = [("", "Hi, What do you want to know about this image?")]

# Populated in main(); kept as module-level globals so the Gradio callbacks can reach them.
model = image_processor = cross_image_processor = text_processor_infer = None
is_grounding = False


def process_image_without_resize(image_prompt):
    # Load the uploaded image at its original resolution and build a path where
    # the grounding visualization can later be saved.
    image = Image.open(image_prompt)
    # print(f"height:{image.height}, width:{image.width}")
    timestamp = int(time.time())
    file_ext = os.path.splitext(image_prompt)[1]
    filename_grounding = f"examples/{timestamp}_grounding{file_ext}"
    return image, filename_grounding


from sat.quantization.kernels import quantize


def load_model(args):
    # `world_size` is set at module level in the __main__ block below.
    model, model_args = AutoModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
            deepspeed=None,
            local_rank=0,
            rank=0,
            world_size=world_size,
            model_parallel_size=world_size,
            mode='inference',
            fp16=args.fp16,
            bf16=args.bf16,
            skip_init=True,
            use_gpu_initialization=torch.cuda.is_available() and args.quant is None,
            device='cpu' if args.quant else 'cuda'),
        overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {}
    )
    model = model.eval()
    assert world_size == get_model_parallel_world_size(), "world size must equal the model parallel size for this demo!"

    language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version
    tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=language_processor_version)
    image_processor = get_image_processor(model_args.eva_args["image_size"][0])
    cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None

    if args.quant:
        quantize(model, args.quant)
        if torch.cuda.is_available():
            model = model.cuda()

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

    text_processor_infer = llama2_text_processor_inference(tokenizer, args.max_length, model.image_length)

    return model, image_processor, cross_image_processor, text_processor_infer

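# A minimal usage sketch for load_model (illustrative only; it assumes the
# module-level `world_size` has already been set, as done in the __main__ block
# below, and the flag values here are examples rather than recommended settings):
#
#   world_size = 1
#   args = argparse.Namespace(from_pretrained="cogagent-chat", version="chat",
#                             local_tokenizer="lmsys/vicuna-7b-v1.5",
#                             max_length=2048, fp16=False, bf16=True, quant=None)
#   model, image_processor, cross_image_processor, text_processor_infer = load_model(args)

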
def post(
    input_text,
    temperature,
    top_p,
    top_k,
    image_prompt,
    result_previous,
    hidden_image,
    state
):
    # Drop placeholder turns (e.g. the default greeting) from the chat history
    # before passing it to the model.
    result_text = [(ele[0], ele[1]) for ele in result_previous]
    for i in range(len(result_text)-1, -1, -1):
        if result_text[i][0] == "" or result_text[i][0] is None:
            del result_text[i]
    print(f"history {result_text}")

    global model, image_processor, cross_image_processor, text_processor_infer, is_grounding

    try:
        with torch.no_grad():
            pil_img, image_path_grounding = process_image_without_resize(image_prompt)
            response, _, cache_image = chat(
                image_path="",
                model=model,
                text_processor=text_processor_infer,
                img_processor=image_processor,
                query=input_text,
                history=result_text,
                cross_img_processor=cross_image_processor,
                image=pil_img,
                max_length=2048,
                top_p=top_p,
                temperature=temperature,
                top_k=top_k,
                invalid_slices=text_processor_infer.invalid_slices if hasattr(text_processor_infer, "invalid_slices") else [],
                no_prompt=False,
                args=state['args']
            )
    except Exception as e:
        print("error message", e)
        result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
        return "", result_text, hidden_image

    answer = response
    if is_grounding:
        # Draw the predicted boxes on the image and show it as an extra chat turn.
        parse_response(pil_img, answer, image_path_grounding)
        new_answer = answer.replace(input_text, "")
        result_text.append((input_text, new_answer))
        result_text.append((None, (image_path_grounding,)))
    else:
        result_text.append((input_text, answer))
    print(result_text)
    print('finished')
    return "", result_text, hidden_image


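# Note on the history format (an explanatory sketch, not part of the original app
# logic): the Gradio Chatbot value above is a list of (user, assistant) message
# pairs; appending a pair whose assistant side is a one-element (filepath,) tuple
# makes the Chatbot render that file, which is how the grounding visualization
# saved by parse_response() is displayed as an extra turn.

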
def clear_fn(value):
    return "", default_chatbox, None


def clear_fn2(value):
    return default_chatbox


def main(args):
    global model, image_processor, cross_image_processor, text_processor_infer, is_grounding
    model, image_processor, cross_image_processor, text_processor_infer = load_model(args)
    is_grounding = 'grounding' in args.from_pretrained

    gr.close_all()

    with gr.Blocks(css='style.css') as demo:
        state = gr.State({'args': args})

        gr.Markdown(DESCRIPTION)
        gr.Markdown(NOTES)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Group():
                    gr.Markdown(AGENT_NOTICE)
                    gr.Markdown(GROUNDING_NOTICE)
                    input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')

                    with gr.Row():
                        run_button = gr.Button('Generate')
                        clear_button = gr.Button('Clear')

                image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)

                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
                    top_k = gr.Slider(maximum=100, value=10, minimum=1, step=1, label='Top K')

            with gr.Column(scale=5):
                result_text = gr.components.Chatbot(label='Multi-round Conversation History', value=default_chatbox, height=600)
                hidden_image_hash = gr.Textbox(visible=False)

        gr.Markdown(MAINTENANCE_NOTICE1)

        print(gr.__version__)
        # The Generate button and pressing ENTER in the textbox trigger the same handler.
        run_button.click(fn=post, inputs=[input_text, temperature, top_p, top_k, image_prompt, result_text, hidden_image_hash, state],
                         outputs=[input_text, result_text, hidden_image_hash])
        input_text.submit(fn=post, inputs=[input_text, temperature, top_p, top_k, image_prompt, result_text, hidden_image_hash, state],
                          outputs=[input_text, result_text, hidden_image_hash])
        clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
        image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
        image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

    # demo.queue(concurrency_count=10)
    demo.launch()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
    parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
    parser.add_argument("--top_k", type=int, default=1, help='top k for top k sampling')
    parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
    parser.add_argument("--version", type=str, default="chat", choices=['chat', 'vqa', 'chat_old', 'base'], help='version of the language processor; if "text_processor_version" is present in model_config.json, this option is overwritten')
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
    parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt')
    parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--stream_chat", action="store_true")
    args = parser.parse_args()

    # Distributed settings come from the environment; load_model() reads the
    # module-level `world_size` set here.
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    main(args)
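# Example invocations (a hedged sketch: the script filename is assumed, and the
# flag combinations simply mirror the argparse defaults above rather than being
# verified commands from the original repo):
#   python web_demo.py --from_pretrained cogagent-chat --version chat --bf16
#   python web_demo.py --from_pretrained cogagent-chat --version chat --fp16 --quant 4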