预测使用gradio部署

This commit is contained in:
somunslotus 2024-07-15 17:14:15 +08:00
parent 46067e5ea3
commit 8121564b51
1 changed files with 40 additions and 12 deletions

View File

@ -1,4 +1,5 @@
import argparse import argparse
import json
import os import os
import zipfile import zipfile
import glob import glob
@ -14,6 +15,7 @@ import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement from dp.launching.report import Report, ReportSection, AutoReportElement
import sys import sys
from pathlib import Path from pathlib import Path
import gradio as gr
from dp.launching.typing import BaseModel, Field, Int,Optional from dp.launching.typing import BaseModel, Field, Int,Optional
from dp.launching.cli import to_runner, default_minimal_exception_handler from dp.launching.cli import to_runner, default_minimal_exception_handler
@ -314,15 +316,41 @@ def generate_report(save_path: Dict[str, str], output_dir: str) -> None:
report.save(output_dir) report.save(output_dir)
# def to_parser(): def zipdir(path, ziph):
# return to_runner( """
# GlobalOptions, Zip the contents of a directory, including all subdirectories.
# main, """
# version='0.1.0', # Iterate over all the files and directories
# exception_handler=default_minimal_exception_handler, for root, dirs, files in os.walk(path):
# ) for file in files:
# # Create the full filepath by combining root directory and file name
# full_path = os.path.join(root, file)
# if __name__ == '__main__': # Write the file to the zip archive
# to_parser()(sys.argv[1:]) ziph.write(full_path, os.path.relpath(full_path, path))
# def gradio_interface(data_path):
print("data path:", data_path)
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt")
print("model path:", model_path)
output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "result")
print("output_dir:", output_dir)
predict_and_plot(model_path, data_path, output_dir)
output_zip_path = os.path.join(output_dir, 'output.zip')
with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
zipdir(output_dir, zipf)
return output_zip_path
iface = gr.Interface(
fn=gradio_interface,
inputs=[
gr.File(label="Data Path (.zip)", type="filepath"),
],
outputs=gr.File(label="Output Zip File"),
title="原子位置缺陷预测",
description="Upload a zip file with images and specify the model path and output directory for prediction."
)
iface.launch(server_name="0.0.0.0", server_port=7860, share=False)