预测使用gradio部署

This commit is contained in:
somunslotus 2024-07-15 17:14:15 +08:00
parent 46067e5ea3
commit 8121564b51
1 changed files with 40 additions and 12 deletions

View File

@ -1,4 +1,5 @@
import argparse
import json
import os
import zipfile
import glob
@ -14,6 +15,7 @@ import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement
import sys
from pathlib import Path
import gradio as gr
from dp.launching.typing import BaseModel, Field, Int,Optional
from dp.launching.cli import to_runner, default_minimal_exception_handler
@ -314,15 +316,41 @@ def generate_report(save_path: Dict[str, str], output_dir: str) -> None:
report.save(output_dir)
# def to_parser():
# return to_runner(
# GlobalOptions,
# main,
# version='0.1.0',
# exception_handler=default_minimal_exception_handler,
# )
#
#
# if __name__ == '__main__':
# to_parser()(sys.argv[1:])
#
def zipdir(path, ziph):
"""
Zip the contents of a directory, including all subdirectories.
"""
# Iterate over all the files and directories
for root, dirs, files in os.walk(path):
for file in files:
# Create the full filepath by combining root directory and file name
full_path = os.path.join(root, file)
# Write the file to the zip archive
ziph.write(full_path, os.path.relpath(full_path, path))
def gradio_interface(data_path):
print("data path:", data_path)
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt")
print("model path:", model_path)
output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "result")
print("output_dir:", output_dir)
predict_and_plot(model_path, data_path, output_dir)
output_zip_path = os.path.join(output_dir, 'output.zip')
with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
zipdir(output_dir, zipf)
return output_zip_path
iface = gr.Interface(
fn=gradio_interface,
inputs=[
gr.File(label="Data Path (.zip)", type="filepath"),
],
outputs=gr.File(label="Output Zip File"),
title="原子位置缺陷预测",
description="Upload a zip file with images and specify the model path and output directory for prediction."
)
iface.launch(server_name="0.0.0.0", server_port=7860, share=False)