diff --git a/inference.py b/inference.py index 05fa2cb..2d6c4c1 100644 --- a/inference.py +++ b/inference.py @@ -48,7 +48,7 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql eos_token_id=eos_token_id, pad_token_id=eos_token_id, )[0]["generated_text"] - .split("```sql")[-1] + .split("# SQL\n")[-1] .split("```")[0] .split(";")[0] .strip() @@ -58,8 +58,9 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql if __name__ == "__main__": # Parse arguments + _default_question="Do we get more revenue from customers in New York compared to customers in San Francisco? Give me the total revenue for each city, and the difference between the two." parser = argparse.ArgumentParser(description="Run inference on a question") - parser.add_argument("-q","--question", type=str, help="Question to run inference on") + parser.add_argument("-q","--question", type=str, default=_default_question, help="Question to run inference on") args = parser.parse_args() question = args.question print("Loading a model and generating a SQL query for answering your question...")