This commit is contained in:
baocheny 2024-01-30 10:47:34 +08:00 committed by JP
parent f8d47c3eb7
commit a63fd30eda
1 changed files with 3 additions and 2 deletions

View File

@ -48,7 +48,7 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql
eos_token_id=eos_token_id,
pad_token_id=eos_token_id,
)[0]["generated_text"]
.split("```sql")[-1]
.split("# SQL\n")[-1]
.split("```")[0]
.split(";")[0]
.strip()
@ -58,8 +58,9 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql
if __name__ == "__main__":
# Parse arguments
_default_question="Do we get more revenue from customers in New York compared to customers in San Francisco? Give me the total revenue for each city, and the difference between the two."
parser = argparse.ArgumentParser(description="Run inference on a question")
parser.add_argument("-q","--question", type=str, help="Question to run inference on")
parser.add_argument("-q","--question", type=str, default=_default_question, help="Question to run inference on")
args = parser.parse_args()
question = args.question
print("Loading a model and generating a SQL query for answering your question...")